Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in handling multiple incoming bodies piped through a TransFormStream to an outgoing body #171

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions builtins/web/fetch/fetch_event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ bool FetchEvent::init_incoming_request(JSContext *cx, JS::HandleObject self,
JS::RootedObject request(
cx, &JS::GetReservedSlot(self, static_cast<uint32_t>(Slots::Request)).toObject());

MOZ_ASSERT(!Request::request_handle(request));

MOZ_ASSERT(!RequestOrResponse::maybe_handle(request));
JS::SetReservedSlot(request, static_cast<uint32_t>(Request::Slots::Request),
JS::PrivateValue(req));

Expand Down Expand Up @@ -175,7 +174,7 @@ bool start_response(JSContext *cx, JS::HandleObject response_obj) {
host_api::HttpOutgoingResponse* response =
host_api::HttpOutgoingResponse::make(status, std::move(headers));

auto existing_handle = Response::response_handle(response_obj);
auto existing_handle = Response::maybe_response_handle(response_obj);
if (existing_handle) {
MOZ_ASSERT(existing_handle->is_incoming());
} else {
Expand Down
215 changes: 123 additions & 92 deletions builtins/web/fetch/request-response.cpp

Large diffs are not rendered by default.

11 changes: 5 additions & 6 deletions builtins/web/fetch/request-response.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class RequestOrResponse final {
static bool is_instance(JSObject *obj);
static bool is_incoming(JSObject *obj);
static host_api::HttpRequestResponseBase *handle(JSObject *obj);
static host_api::HttpHeadersReadOnly *headers_handle(JSObject *obj);
static host_api::HttpRequestResponseBase *maybe_handle(JSObject *obj);
static host_api::HttpHeadersReadOnly *maybe_headers_handle(JSObject *obj);
static bool has_body(JSObject *obj);
static host_api::HttpIncomingBody *incoming_body_handle(JSObject *obj);
static host_api::HttpOutgoingBody *outgoing_body_handle(JSObject *obj);
Expand Down Expand Up @@ -66,7 +67,8 @@ class RequestOrResponse final {
*/
static JSObject *headers(JSContext *cx, JS::HandleObject obj);

static bool append_body(JSContext *cx, JS::HandleObject self, JS::HandleObject source);
static bool append_body(JSContext *cx, JS::HandleObject self, JS::HandleObject source,
api::TaskCompletionCallback callback, HandleObject callback_receiver);

using ParseBodyCB = bool(JSContext *cx, JS::HandleObject self, JS::UniqueChars buf, size_t len);

Expand Down Expand Up @@ -142,9 +144,6 @@ class Request final : public BuiltinImpl<Request> {

static JSObject *response_promise(JSObject *obj);
static JSString *method(JS::HandleObject obj);
static host_api::HttpRequest *request_handle(JSObject *obj);
static host_api::HttpOutgoingRequest *outgoing_handle(JSObject *obj);
static host_api::HttpIncomingRequest *incoming_handle(JSObject *obj);

static const JSFunctionSpec static_methods[];
static const JSPropertySpec static_properties[];
Expand Down Expand Up @@ -209,7 +208,7 @@ class Response final : public BuiltinImpl<Response> {
static JSObject *init_slots(HandleObject response);
static JSObject *create_incoming(JSContext *cx, host_api::HttpIncomingResponse *response);

static host_api::HttpResponse *response_handle(JSObject *obj);
static host_api::HttpResponse *maybe_response_handle(JSObject *obj);
static uint16_t status(JSObject *obj);
static JSString *status_message(JSObject *obj);
static void set_status_message_from_code(JSContext *cx, JSObject *obj, uint16_t code);
Expand Down
83 changes: 58 additions & 25 deletions host-apis/wasi-0.2.0/host_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,31 +574,69 @@ void HttpOutgoingBody::write(const uint8_t *bytes, size_t len) {
MOZ_RELEASE_ASSERT(write_to_outgoing_body(borrow, bytes, len));
}

Result<Void> HttpOutgoingBody::write_all(const uint8_t *bytes, size_t len) {
if (!valid()) {
// TODO: proper error handling for all 154 error codes.
return Result<Void>::err(154);
}
class BodyWriteAllTask final : public api::AsyncTask {
HttpOutgoingBody *outgoing_body_;
PollableHandle outgoing_pollable_;

auto *state = static_cast<OutgoingBodyHandle *>(handle_state_.get());
Borrow<OutputStream> borrow(state->stream_handle_);
api::TaskCompletionCallback cb_;
Heap<JSObject *> cb_receiver_;
HostBytes bytes_;
size_t offset_ = 0;

while (len > 0) {
auto capacity_res = capacity();
if (capacity_res.is_err()) {
// TODO: proper error handling for all 154 error codes.
return Result<Void>::err(154);
public:
explicit BodyWriteAllTask(HttpOutgoingBody *outgoing_body, HostBytes bytes,
api::TaskCompletionCallback completion_callback,
HandleObject callback_receiver)
: outgoing_body_(outgoing_body), cb_(completion_callback),
cb_receiver_(callback_receiver), bytes_(std::move(bytes)) {
outgoing_pollable_ = outgoing_body_->subscribe().unwrap();
}

[[nodiscard]] bool run(api::Engine *engine) override {
auto res = outgoing_body_->capacity();
if (res.is_err()) {
return false;
}
auto capacity = capacity_res.unwrap();
auto bytes_to_write = std::min(len, static_cast<size_t>(capacity));
if (!write_to_outgoing_body(borrow, bytes, len)) {
return Result<Void>::err(154);
uint64_t capacity = res.unwrap();
MOZ_ASSERT(capacity >= 0);
auto bytes_to_write = std::min(bytes_.len - offset_, static_cast<size_t>(capacity));
outgoing_body_->write(bytes_.ptr.get() + offset_, bytes_to_write);
offset_ += bytes_to_write;
if (offset_ < bytes_.len) {
engine->queue_async_task(this);
tschneidereit marked this conversation as resolved.
Show resolved Hide resolved
} else {
bytes_.ptr.reset();
RootedObject receiver(engine->cx(), cb_receiver_);
bool result = cb_(engine->cx(), receiver);
cb_ = nullptr;
cb_receiver_ = nullptr;
return result;
}

bytes += bytes_to_write;
len -= bytes_to_write;
return true;
}

[[nodiscard]] bool cancel(api::Engine *engine) override {
MOZ_ASSERT_UNREACHABLE("BodyWriteAllTask's semantics don't allow for cancellation");
return true;
}

[[nodiscard]] int32_t id() override {
return outgoing_pollable_;
}

void trace(JSTracer *trc) override {
JS::TraceEdge(trc, &cb_receiver_, "BodyWriteAllTask completion callback receiver");
}
};

Result<Void> HttpOutgoingBody::write_all(api::Engine *engine, HostBytes bytes,
api::TaskCompletionCallback callback, HandleObject cb_receiver) {
if (!valid()) {
// TODO: proper error handling for all 154 error codes.
return Result<Void>::err(154);
}
engine->queue_async_task(new BodyWriteAllTask(this, std::move(bytes), callback, cb_receiver));
return {};
}

Expand Down Expand Up @@ -638,13 +676,8 @@ class BodyAppendTask final : public api::AsyncTask {
HandleObject callback_receiver)
: incoming_body_(incoming_body), outgoing_body_(outgoing_body), cb_(completion_callback),
cb_receiver_(callback_receiver), state_(State::BlockedOnBoth) {
auto res = incoming_body_->subscribe();
MOZ_ASSERT(!res.is_err());
incoming_pollable_ = res.unwrap();

res = outgoing_body_->subscribe();
MOZ_ASSERT(!res.is_err());
outgoing_pollable_ = res.unwrap();
incoming_pollable_ = incoming_body_->subscribe().unwrap();
outgoing_pollable_ = outgoing_body_->subscribe().unwrap();
}

[[nodiscard]] bool run(api::Engine *engine) override {
Expand Down
3 changes: 2 additions & 1 deletion include/host_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ class HttpOutgoingBody final : public Pollable {
/// The host doesn't necessarily write all bytes in any particular call to
/// `write`, so to ensure all bytes are written, we call it in a loop.
/// TODO: turn into an async task that writes chunks of the passed buffer until done.
Result<Void> write_all(const uint8_t *bytes, size_t len);
Result<Void> write_all(api::Engine *engine, HostBytes bytes, api::TaskCompletionCallback callback,
HandleObject cb_receiver);

/// Append an HttpIncomingBody to this one.
Result<Void> append(api::Engine *engine, HttpIncomingBody *other,
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/multi-stream-forwarding/expect_serve_body.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This sentence will be streamed in chunks.
58 changes: 58 additions & 0 deletions tests/e2e/multi-stream-forwarding/multi-stream-forwarding.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
addEventListener('fetch', async (event) => {
try {
if (!event.request.url.includes('/nested')) {
event.respondWith(main(event));
return;
}

let encoder = new TextEncoder();
let body = new TransformStream({
start(controller) {
},
transform(chunk, controller) {
controller.enqueue(encoder.encode(chunk));
},
flush(controller) {
}
});
let writer = body.writable.getWriter();
event.respondWith(new Response(body.readable));
let word = new URL(event.request.url).searchParams.get('word');
console.log(`streaming word: ${word}`);
for (let letter of word) {
console.log(`Writing letter ${letter}`);
await writer.write(letter);
}
if (word.endsWith(".")) {
await writer.write("\n");
}
await writer.close();
} catch (e) {
console.error(e);
}
});
async function main(event) {
let fullBody = "This sentence will be streamed in chunks.";
let responses = [];
for (let word of fullBody.split(" ").join("+ ").split(" ")) {
responses.push((await fetch(`${event.request.url}/nested?word=${word}`)).body);
}
return new Response(concatStreams(responses));
}

function concatStreams(streams) {
let { readable, writable } = new TransformStream();
async function iter() {
for (let stream of streams) {
try {
await stream.pipeTo(writable, {preventClose: true});
} catch (e) {
console.error(`error during pipeline execution: ${e}`);
}
}
console.log("closing writable");
await writable.close();
}
iter();
return readable;
}
2 changes: 2 additions & 0 deletions tests/tests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ test_e2e(syntax-err)
test_e2e(tla-err)
test_e2e(tla-runtime-resolve)
test_e2e(tla)
test_e2e(stream-forwarding)
test_e2e(multi-stream-forwarding)

test_integration(btoa)
test_integration(crypto)
Expand Down
Loading