From 4a3c5c5a9c25813738f8e9764dc35571484a2101 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Mon, 23 Sep 2024 19:03:23 +0200 Subject: [PATCH] feat(go): add basic `westhttp` Signed-off-by: Roman Volosatovs --- examples/go/http/http.go | 2 +- examples/go/http/http_test.go | 100 +++---------- tests/go/wasi/wasi_test.go | 135 +++-------------- westhttp/http.go | 271 ++++++++++++++++++++++++++++++++++ 4 files changed, 316 insertions(+), 192 deletions(-) create mode 100644 westhttp/http.go diff --git a/examples/go/http/http.go b/examples/go/http/http.go index 4a82908..f723f80 100644 --- a/examples/go/http/http.go +++ b/examples/go/http/http.go @@ -45,7 +45,7 @@ func handle(req types.IncomingRequest, out types.ResponseOutparam) *types.ErrorC slog.Debug("setting response outparam") types.ResponseOutparamSet(out, cm.OK[cm.Result[types.ErrorCodeShape, types.OutgoingResponse, types.ErrorCode]](resp)) stream := bodyWrite.OK() - s := "foo bar baz" + s := "hello world" writeRes := stream.BlockingWriteAndFlush(cm.NewList(unsafe.StringData(s), uint(len(s)))) if writeRes.IsErr() { slog.Error("failed to write to stream", "err", writeRes.Err()) diff --git a/examples/go/http/http_test.go b/examples/go/http/http_test.go index 58c92fb..57009f8 100644 --- a/examples/go/http/http_test.go +++ b/examples/go/http/http_test.go @@ -3,100 +3,44 @@ package wasi_test import ( + "io" + "net/http" "testing" - "unsafe" - "github.com/bytecodealliance/wasm-tools-go/cm" west "github.com/rvolosatovs/west" _ "github.com/rvolosatovs/west/bindings" - testtypes "github.com/rvolosatovs/west/bindings/wasi/http/types" - httpext "github.com/rvolosatovs/west/bindings/wasiext/http/ext" incominghandler "github.com/rvolosatovs/west/tests/go/wasi/bindings/wasi/http/incoming-handler" - "github.com/rvolosatovs/west/tests/go/wasi/bindings/wasi/http/types" + "github.com/rvolosatovs/west/westhttp" "github.com/stretchr/testify/assert" ) func TestIncomingHandler(t *testing.T) { west.RunTest(t, func() { - headers := testtypes.NewFields() - headers.Append( - testtypes.FieldKey("foo"), - testtypes.FieldValue(cm.NewList( - unsafe.SliceData([]byte("bar")), - 3, - )), - ) - headers.Append( - testtypes.FieldKey("foo"), - testtypes.FieldValue(cm.NewList( - unsafe.SliceData([]byte("baz")), - 3, - )), - ) - headers.Set( - testtypes.FieldKey("key"), - cm.NewList( - unsafe.SliceData( - []testtypes.FieldValue{ - testtypes.FieldValue(cm.NewList( - unsafe.SliceData([]byte("value")), - 5, - )), - }, - ), - 1, - ), - ) - req := testtypes.NewOutgoingRequest(headers) - req.SetPathWithQuery(cm.Some("test")) - req.SetMethod(testtypes.MethodGet()) - out := httpext.NewResponseOutparam() - incominghandler.Exports.Handle( - types.IncomingRequest(httpext.NewIncomingRequest(req)), - types.ResponseOutparam(out.F0), - ) - out.F1.Subscribe().Block() - respOptResRes := out.F1.Get() - respResRes := respOptResRes.Some() - if !assert.NotNil(t, respResRes) { - t.FailNow() + req, err := http.NewRequest("", "test", nil) + if err != nil { + t.Fatalf("failed to create new HTTP request: %s", err) } - respRes := respResRes.OK() - if !assert.NotNil(t, respRes) || !assert.Nil(t, respRes.Err()) { - t.FailNow() + req.Header.Add("foo", "bar") + req.Header.Add("foo", "baz") + req.Header.Add("key", "value") + resp, err := westhttp.HandleIncomingRequest(incominghandler.Exports.Handle, req) + if err != nil { + t.Fatalf("failed to handle incoming HTTP request: %s", err) } - resp := respRes.OK() - assert.Equal(t, testtypes.StatusCode(200), resp.Status()) - hs := map[string][][]byte{} - for _, h := range resp.Headers().Entries().Slice() { - k := string(h.F0) - hs[k] = append(hs[k], h.F1.Slice()) - } - assert.Equal(t, map[string][][]byte{ + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, http.Header{ "foo": { - []byte("bar"), - []byte("baz"), + "bar", + "baz", }, "key": { - []byte("value"), + "value", }, - }, hs) - bodyRes := resp.Consume() - body := bodyRes.OK() - if !assert.NotNil(t, body) { - t.FailNow() - } - bodyStreamRes := body.Stream() - bodyStream := bodyStreamRes.OK() - if !assert.NotNil(t, bodyStream) { - t.FailNow() - } - bufRes := bodyStream.BlockingRead(4096) - buf := bufRes.OK() - if !assert.NotNil(t, buf) { - t.FailNow() + }, resp.Header) + buf, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read HTTP response body: %s", err) } - assert.Equal(t, []byte("foo bar baz"), buf.Slice()) - bodyStream.ResourceDrop() + assert.Equal(t, []byte("hello world"), buf) }) } diff --git a/tests/go/wasi/wasi_test.go b/tests/go/wasi/wasi_test.go index 1ac371a..66a99fe 100644 --- a/tests/go/wasi/wasi_test.go +++ b/tests/go/wasi/wasi_test.go @@ -5,21 +5,19 @@ package wasi_test import ( + "bytes" _ "embed" + "io" "log" "log/slog" + "net/http" "os" "testing" - "unsafe" - "github.com/bytecodealliance/wasm-tools-go/cm" west "github.com/rvolosatovs/west" _ "github.com/rvolosatovs/west/bindings" - testtypes "github.com/rvolosatovs/west/bindings/wasi/http/types" - teststreams "github.com/rvolosatovs/west/bindings/wasi/io/streams" - httpext "github.com/rvolosatovs/west/bindings/wasiext/http/ext" incominghandler "github.com/rvolosatovs/west/tests/go/wasi/bindings/wasi/http/incoming-handler" - "github.com/rvolosatovs/west/tests/go/wasi/bindings/wasi/http/types" + "github.com/rvolosatovs/west/westhttp" "github.com/stretchr/testify/assert" ) @@ -48,120 +46,31 @@ func init() { func TestIncomingHandler(t *testing.T) { west.RunTest(t, func() { - headers := testtypes.NewFields() - headers.Append( - testtypes.FieldKey("foo"), - testtypes.FieldValue(cm.NewList( - unsafe.SliceData([]byte("bar")), - 3, - )), - ) - headers.Append( - testtypes.FieldKey("foo"), - testtypes.FieldValue(cm.NewList( - unsafe.SliceData([]byte("baz")), - 3, - )), - ) - headers.Set( - testtypes.FieldKey("key"), - cm.NewList( - unsafe.SliceData( - []testtypes.FieldValue{ - testtypes.FieldValue(cm.NewList( - unsafe.SliceData([]byte("value")), - 5, - )), - }, - ), - 1, - ), - ) - req := testtypes.NewOutgoingRequest(headers) - req.SetPathWithQuery(cm.Some("5")) - req.SetMethod(testtypes.MethodPost()) - reqBodyRes := req.Body() - if !assert.Nil(t, reqBodyRes.Err()) { - t.FailNow() + req, err := http.NewRequest(http.MethodPost, "5", bytes.NewReader([]byte("foo bar baz"))) + if err != nil { + t.Fatalf("failed to create new HTTP request: %s", err) } - reqBody := reqBodyRes.OK() - reqStreamRes := reqBody.Write() - if !assert.Nil(t, reqStreamRes.Err()) { - t.FailNow() + req.Header.Add("foo", "bar") + req.Header.Add("foo", "baz") + req.Header.Add("key", "value") + resp, err := westhttp.HandleIncomingRequest(incominghandler.Exports.Handle, req) + if err != nil { + t.Fatalf("failed to handle incoming HTTP request: %s", err) } - reqStream := reqStreamRes.OK() - writeRes := reqStream.BlockingWriteAndFlush(cm.NewList( - unsafe.SliceData([]byte("foo bar baz")), - 11, - )) - if !assert.Nil(t, writeRes.Err()) { - t.FailNow() - } - reqStream.ResourceDrop() - reqBodyFinishRes := testtypes.OutgoingBodyFinish(*reqBody, cm.None[testtypes.Fields]()) - if !assert.Nil(t, reqBodyFinishRes.Err()) { - t.FailNow() - } - - out := httpext.NewResponseOutparam() - incominghandler.Exports.Handle( - types.IncomingRequest(httpext.NewIncomingRequest(req)), - types.ResponseOutparam(out.F0), - ) - out.F1.Subscribe().Block() - respOptResRes := out.F1.Get() - respResRes := respOptResRes.Some() - if !assert.NotNil(t, respResRes) || !assert.Nil(t, respResRes.Err()) { - t.FailNow() - } - respRes := respResRes.OK() - if !assert.Nil(t, respRes.Err()) { - t.Fatal(*respRes.Err()) - } - resp := respRes.OK() - assert.Equal(t, testtypes.StatusCode(200), resp.Status()) - hs := map[string][][]byte{} - for _, h := range resp.Headers().Entries().Slice() { - k := string(h.F0) - hs[k] = append(hs[k], h.F1.Slice()) - } - assert.Equal(t, map[string][][]byte{ + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, http.Header{ "foo": { - []byte("bar"), - []byte("baz"), + "bar", + "baz", }, "key": { - []byte("value"), + "value", }, - }, hs) - bodyRes := resp.Consume() - if !assert.Nil(t, bodyRes.Err()) { - t.FailNow() - } - - body := bodyRes.OK() - bodyStreamRes := body.Stream() - if !assert.Nil(t, bodyStreamRes.Err()) { - t.FailNow() - } - - bodyStream := bodyStreamRes.OK() - var buf []byte - for { - bufRes := bodyStream.BlockingRead(4096) - if bufRes.IsErr() && *bufRes.Err() == teststreams.StreamErrorClosed() { - break - } - if !assert.Nil(t, bufRes.Err()) { - if !assert.False(t, bufRes.Err().Closed()) { - t.FailNow() - } else { - t.Fatal(*bufRes.Err().LastOperationFailed()) - } - } - buf = append(buf, bufRes.OK().Slice()...) + }, resp.Header) + buf, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read HTTP response body: %s", err) } assert.Equal(t, []byte("🧭🧭🧭🧭🧭foo bar baz"), buf) - bodyStream.ResourceDrop() }) } diff --git a/westhttp/http.go b/westhttp/http.go new file mode 100644 index 0000000..f300b77 --- /dev/null +++ b/westhttp/http.go @@ -0,0 +1,271 @@ +package westhttp + +import ( + "bytes" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "unsafe" + + "github.com/bytecodealliance/wasm-tools-go/cm" + "github.com/rvolosatovs/west/bindings/wasi/http/types" + "github.com/rvolosatovs/west/bindings/wasi/io/poll" + "github.com/rvolosatovs/west/bindings/wasiext/http/ext" +) + +func NewFields(h http.Header) types.Fields { + headers := types.NewFields() + for name, values := range h { + for _, v := range values { + headers.Append( + types.FieldKey(name), + types.FieldValue(cm.NewList( + unsafe.SliceData([]byte(v)), + uint(len(v)), + )), + ) + } + } + return headers +} + +func NewOutgoingRequest(req *http.Request) (types.OutgoingRequest, func(func(poll.Pollable)) error, error) { + if req.TLS != nil { + return 0, nil, errors.New("`http.Request.TLS` is not currently supported") + } + res := types.NewOutgoingRequest(NewFields(req.Header)) + if s := req.URL.RequestURI(); s != "" { + if res.SetPathWithQuery(cm.Some(s)) { + return 0, nil, fmt.Errorf("failed to set path with query to `%s`", s) + } + } + if s := req.URL.Hostname(); s != "" { + if res.SetAuthority(cm.Some(s)) { + return 0, nil, fmt.Errorf("failed to set authority to `%s`", s) + } + } + if s := req.URL.Scheme; s != "" { + switch s { + case "http": + if res.SetScheme(cm.Some(types.SchemeHTTP())) { + return 0, nil, errors.New("failed to set scheme to HTTP") + } + case "https": + if res.SetScheme(cm.Some(types.SchemeHTTPS())) { + return 0, nil, errors.New("failed to set scheme to HTTPS") + } + default: + if res.SetScheme(cm.Some(types.SchemeOther(s))) { + return 0, nil, fmt.Errorf("failed to set scheme to `%s`", s) + } + } + } + + switch req.Method { + case "", http.MethodGet: + if res.SetMethod(types.MethodGet()) { + return 0, nil, errors.New("failed to set method to GET") + } + case http.MethodHead: + if res.SetMethod(types.MethodHead()) { + return 0, nil, errors.New("failed to set method to HEAD") + } + case http.MethodPost: + if res.SetMethod(types.MethodPost()) { + return 0, nil, errors.New("failed to set method to POST") + } + case http.MethodPut: + if res.SetMethod(types.MethodPut()) { + return 0, nil, errors.New("failed to set method to PUT") + } + case http.MethodPatch: + if res.SetMethod(types.MethodPatch()) { + return 0, nil, errors.New("failed to set method to PATCH") + } + case http.MethodDelete: + if res.SetMethod(types.MethodDelete()) { + return 0, nil, errors.New("failed to set method to DELETE") + } + case http.MethodConnect: + if res.SetMethod(types.MethodConnect()) { + return 0, nil, errors.New("failed to set method to CONNECT") + } + case http.MethodOptions: + if res.SetMethod(types.MethodOptions()) { + return 0, nil, errors.New("failed to set method to OPTIONS") + } + case http.MethodTrace: + if res.SetMethod(types.MethodTrace()) { + return 0, nil, errors.New("failed to set method to TRACE") + } + default: + if res.SetMethod(types.MethodOther(req.Method)) { + return 0, nil, fmt.Errorf("failed to set method to `%s`", req.Method) + } + } + if req.Body == nil { + return res, nil, nil + } + + resBodyRes := res.Body() + if err := resBodyRes.Err(); err != nil { + return 0, nil, errors.New("failed to take outgoing request body") + } + resBody := resBodyRes.OK() + resStreamRes := resBody.Write() + if err := resStreamRes.Err(); err != nil { + return 0, nil, errors.New("failed to take outgoing request body stream") + } + resStream := resStreamRes.OK() + return res, func(poll func(poll.Pollable)) error { + for { + checkWriteRes := resStream.CheckWrite() + if err := checkWriteRes.Err(); err != nil { + if err.Closed() { + slog.Debug("write stream closed") + return io.EOF + } + slog.Debug("failed to check write buffer capacity") + return fmt.Errorf("failed to check write buffer capacity: %s", err.LastOperationFailed().ToDebugString()) + } + wn := *checkWriteRes.OK() + if wn == 0 { + p := resStream.Subscribe() + poll(p) + p.ResourceDrop() + continue + } + if wn > 4096 { + wn = 4096 + } + buf := make([]byte, wn) + slog.Debug("reading buffer from body stream", "n", wn) + rn, err := req.Body.Read(buf[:]) + if rn > 0 { + slog.Debug("writing body stream chunk", "buf", buf[:rn]) + writeRes := resStream.Write(cm.NewList(unsafe.SliceData(buf), uint(rn))) + if err := writeRes.Err(); err != nil { + if err.Closed() { + slog.Debug("write stream closed") + return io.EOF + } + slog.Debug("failed to write to buffer to stream") + return fmt.Errorf("failed to write buffer: %s", err.LastOperationFailed().ToDebugString()) + } + } + if err == nil { + continue + } + if err != io.EOF { + slog.Debug("failed to read buffer from body stream") + return err + } + flushRes := resStream.Flush() + if err := flushRes.Err(); err != nil { + if err.Closed() { + slog.Debug("write stream closed") + return io.EOF + } + slog.Debug("failed to flush body stream") + return fmt.Errorf("failed to flush body stream: %s", err.LastOperationFailed().ToDebugString()) + } + resStream.ResourceDrop() + + trailers := cm.None[types.Fields]() + if len(req.Trailer) > 0 { + trailers = cm.Some(NewFields(req.Trailer)) + } + slog.Debug("finishing outgoing body") + resBodyFinishRes := types.OutgoingBodyFinish(*resBody, trailers) + if err := resBodyFinishRes.Err(); err != nil { + return fmt.Errorf("failed to finish outgoing body: %v", err) + } + return nil + } + }, nil +} + +func NewIncomingRequest(req *http.Request) (types.IncomingRequest, func(func(poll.Pollable)) error, error) { + res, write, err := NewOutgoingRequest(req) + return ext.NewIncomingRequest(res), write, err +} + +func NewIncomingResponse(resp types.IncomingResponse) (*http.Response, error) { + header := make(http.Header, len(resp.Headers().Entries().Slice())) + for _, h := range resp.Headers().Entries().Slice() { + k := string(h.F0) + header[k] = append(header[k], string(h.F1.Slice())) + } + bodyRes := resp.Consume() + if err := bodyRes.Err(); err != nil { + return nil, errors.New("failed to get response body") + } + body := bodyRes.OK() + bodyStreamRes := body.Stream() + if err := bodyStreamRes.Err(); err != nil { + return nil, errors.New("failed to get response body stream") + } + bodyStream := bodyStreamRes.OK() + var buf []byte + for { + bufRes := bodyStream.BlockingRead(4096) + if err := bufRes.Err(); err != nil { + if err.Closed() { + break + } + return nil, fmt.Errorf("failed to read response body stream: %s", err.LastOperationFailed().ToDebugString()) + } + buf = append(buf, bufRes.OK().Slice()...) + } + bodyStream.ResourceDrop() + body.ResourceDrop() + + var trailer http.Header + return &http.Response{ + Status: http.StatusText(int(resp.Status())), + StatusCode: int(resp.Status()), + Body: io.NopCloser(bytes.NewReader(buf)), + Header: header, + Trailer: trailer, + }, nil +} + +func HandleIncomingRequest[I, O ~uint32](f func(I, O), req *http.Request) (*http.Response, error) { + wr, write, err := NewIncomingRequest(req) + if err != nil { + return nil, fmt.Errorf("failed to create new outgoing HTTP request: %w", err) + } + if write != nil { + if err := write(poll.Pollable.Block); err != nil { + return nil, fmt.Errorf("failed to write body: %w", err) + } + } + + out := ext.NewResponseOutparam() + f( + I(wr), + O(out.F0), + ) + out.F1.Subscribe().Block() + respOptResRes := out.F1.Get() + respResRes := respOptResRes.Some() + if respResRes == nil { + return nil, errors.New("response missing") + } + if err := respResRes.Err(); err != nil { + return nil, errors.New("failed to get response") + } + + respRes := respResRes.OK() + if err := respRes.Err(); err != nil { + return nil, fmt.Errorf("failed to receive response: %v", err) + } + resp, err := NewIncomingResponse(*respRes.OK()) + if err != nil { + return nil, fmt.Errorf("failed to parse response: %v", err) + } + resp.Request = req + return resp, nil +}