diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index ca1a38de0f86..f438b9941993 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -348,6 +348,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) st.ht.HandleStreams( context.Background(), func(s *ServerStream) { go handleStream(s) }, ) + wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, @@ -379,6 +380,15 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { if err != nil { t.Fatal(err) } + // rst flag setting to verify the noop function: signalDeadlineExceeded + ch := make(chan struct{}, 1) + origSignalDeadlineExceeded := signalDeadlineExceeded + signalDeadlineExceeded = func() { + ch <- struct{}{} + } + defer func() { + signalDeadlineExceeded = origSignalDeadlineExceeded + }() runStream := func(s *ServerStream) { defer bodyw.Close() select { @@ -392,7 +402,9 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded) return } + s.WriteStatus(status.New(codes.DeadlineExceeded, "too slow")) + } ht.HandleStreams( context.Background(), func(s *ServerStream) { go runStream(s) }, @@ -407,6 +419,13 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { "Grpc-Message": {encodeGrpcMessage("too slow")}, } checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer) + select { + case <-ch: // Signal received, continue with the test + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for signalDeadlineExceeded") + return + } + } // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 0055fddd7ecf..4007dce97445 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -1100,6 +1100,9 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error { } // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive + if rst { + signalDeadlineExceeded() + } t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) for _, sh := range t.stats { // Note: The trailer fields are compressed with hpack after this call returns. @@ -1111,6 +1114,8 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error { return nil } +var signalDeadlineExceeded = func() {} + // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). func (t *http2Server) write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *WriteOptions) error {