Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpc: fix receiving empty messages when compression is enabled and maxReceiveMessageSize is maxInt64 #7753 #7914

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,14 +874,35 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMes
return nil, 0, err
}

out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool)
if err != nil {
out.Free()
return nil, 0, err
}
if err = checkReceiveMessageOverflow(int64(out.Len()), int64(maxReceiveMessageSize), dcReader); err != nil {
return nil, out.Len() + 1, err
}
return out, out.Len(), nil
}

// checkReceiveMessageOverflow checks if the number of bytes read from the stream exceeds
// the maximum receive message size allowed by the client. If the `readBytes` equals
// `maxReceiveMessageSize`, the function attempts to read one more byte from the `dcReader`
// to detect if there's an overflow.
//
// If additional data is read, or an error other than `io.EOF` is encountered, the function
// returns an error indicating that the message size has exceeded the permissible limit.
func checkReceiveMessageOverflow(readBytes, maxReceiveMessageSize int64, dcReader io.Reader) error {
if readBytes == maxReceiveMessageSize {
b := make([]byte, 1)
if n, err := dcReader.Read(b); n > 0 || err != io.EOF {
return fmt.Errorf("overflow: received message size is larger than the allowed maxReceiveMessageSize (%d bytes)",
maxReceiveMessageSize)
}
}
return nil
}

type recvCompressor interface {
RecvCompress() string
}
Expand Down
83 changes: 83 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ package grpc
import (
"bytes"
"compress/gzip"
"errors"
"io"
"math"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/gzip"
protoenc "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
Expand Down Expand Up @@ -294,3 +298,82 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) {
func BenchmarkGZIPCompressor1MiB(b *testing.B) {
bmCompressor(b, 1024*1024, NewGZIPCompressor())
}

// compressData compresses data using gzip and returns the compressed bytes.
func compressData(data []byte) []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write(data)
_ = gz.Close()
return buf.Bytes()
}

// TestDecompress tests the decompress function with various scenarios, including
// successful decompression, error handling, and edge cases like overflow or
// premature data end. It ensures that the function behaves correctly with different
// inputs, buffer sizes, and error conditions, using the "gzip" compressor for testing.

func TestDecompress(t *testing.T) {
c := encoding.GetCompressor("gzip")

compressInput := func(input []byte) mem.BufferSlice {
compressedData := compressData(input)
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
}

tests := []struct {
name string
compressor encoding.Compressor
input []byte
maxReceiveMessageSize int
want []byte
error error
}{
{
name: "Decompresses successfully with sufficient buffer size",
compressor: c,
input: []byte("decompressed data"),
maxReceiveMessageSize: 50,
want: []byte("decompressed data"),
error: nil,
},
{
name: "failure, empty receive message",
compressor: c,
input: []byte{},
maxReceiveMessageSize: 10,
want: nil,
error: nil,
},
{
name: "overflow failure, receive message exceeds maxReceiveMessageSize",
compressor: c,
input: []byte("small message"),
maxReceiveMessageSize: 5,
want: nil,
error: errors.New("overflow: received message size is larger than the allowed maxReceiveMessageSize (5 bytes)"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressedMsg := compressInput(tt.input)
output, numSliceInBuf, err := decompress(tt.compressor, compressedMsg, tt.maxReceiveMessageSize, mem.DefaultBufferPool())

var wantMsg mem.BufferSlice
if tt.want != nil {
wantMsg = mem.BufferSlice{mem.NewBuffer(&tt.want, nil)}
}
if tt.error != nil && err == nil {
t.Fatalf("decompress() error, got err=%v, want err=%v", err, tt.error)
}
if tt.error == nil && numSliceInBuf != wantMsg.Len() {
t.Fatalf("decompress() number of slices mismatch, got = %d, want = %d", numSliceInBuf, wantMsg.Len())
}
if diff := cmp.Diff(wantMsg.Materialize(), output.Materialize()); diff != "" {
t.Fatalf("Mismatch in output:\n%s", diff)
}

})
}
}