Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpc: fix receiving empty messages when compression is enabled and maxReceiveMessageSize is maxInt64 #7753 #7918

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,14 +874,35 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMes
return nil, 0, err
}

out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool)
if err != nil {
out.Free()
return nil, 0, err
}
if err = checkReceiveMessageOverflow(int64(out.Len()), int64(maxReceiveMessageSize), dcReader); err != nil {
arjan-bal marked this conversation as resolved.
Show resolved Hide resolved
return nil, out.Len() + 1, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding 1 to out.Len() can result in an overflow on 32 bit systems. The following part of the function seems strange to me:

Optionally, if data will be over maxReceiveMessageSize, just return the size

I suggest making the following change to avoid this:

  1. Declare a global error for indicating that the max receive size is exceeded:
var	errMaxMessageSizeExceeded = errors.New("max message size exceeded")
  1. When the check here fails, nil, 0, errMaxMessageSizeExceeded. Update the godoc to mention the same.
  2. In the caller, i.e. recvAndDecompress, instead of checking if size > maxReceiveMessageSize, check if err == errMaxMessageSizeExceeded. Also update the error message to not mention the actual size, because we didn't read the entire message anyways: grpc: received message after decompression larger than max %d.

This ensures we're using the returned error to indicate the failure instead of relying on special values of the bytes read count.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please let the reviewer resolve comments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not attempt to add out.Len() + 1, instead return errMaxMessageSizeExceeded to signal the same.

}
return out, out.Len(), nil
}

// checkReceiveMessageOverflow checks if the number of bytes read from the stream exceeds
// the maximum receive message size allowed by the client. If the `readBytes` equals
// `maxReceiveMessageSize`, the function attempts to read one more byte from the `dcReader`
// to detect if there's an overflow.
//
// If additional data is read, or an error other than `io.EOF` is encountered, the function
// returns an error indicating that the message size has exceeded the permissible limit.
func checkReceiveMessageOverflow(readBytes, maxReceiveMessageSize int64, dcReader io.Reader) error {
vinothkumarr227 marked this conversation as resolved.
Show resolved Hide resolved
if readBytes == maxReceiveMessageSize {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can invert this check and returns early to reduce indentation.

if readBytes < maxReceiveMessageSize {
   return nil
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed all the comments.

b := make([]byte, 1)
if n, err := dcReader.Read(b); n > 0 || err != io.EOF {
return fmt.Errorf("overflow: received message size is larger than the allowed maxReceiveMessageSize (%d bytes)",
maxReceiveMessageSize)
}
}
return nil
}

type recvCompressor interface {
RecvCompress() string
}
Expand Down
82 changes: 82 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ package grpc
import (
"bytes"
"compress/gzip"
"errors"
"io"
"math"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/gzip"
protoenc "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
Expand Down Expand Up @@ -294,3 +298,81 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) {
func BenchmarkGZIPCompressor1MiB(b *testing.B) {
bmCompressor(b, 1024*1024, NewGZIPCompressor())
}

// compressData compresses data using gzip and returns the compressed bytes.
func compressData(data []byte) []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write(data)
_ = gz.Close()
return buf.Bytes()
}

// TestDecompress tests the decompress function with various scenarios, including
// successful decompression, error handling, and edge cases like overflow or
// premature data end. It ensures that the function behaves correctly with different
// inputs, buffer sizes, and error conditions, using the "gzip" compressor for testing.

func TestDecompress(t *testing.T) {
c := encoding.GetCompressor("gzip")

compressInput := func(input []byte) mem.BufferSlice {
compressedData := compressData(input)
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
}

tests := []struct {
name string
compressor encoding.Compressor
input []byte
maxReceiveMessageSize int
want []byte
error error
}{
{
name: "Decompresses successfully with sufficient buffer size",
compressor: c,
input: []byte("decompressed data"),
maxReceiveMessageSize: 50,
want: []byte("decompressed data"),
error: nil,
},
{
name: "failure, empty receive message",
compressor: c,
input: []byte{},
maxReceiveMessageSize: 10,
want: nil,
error: nil,
},
{
name: "overflow failure, receive message exceeds maxReceiveMessageSize",
compressor: c,
input: []byte("small message"),
maxReceiveMessageSize: 5,
want: nil,
error: errors.New("overflow: received message size is larger than the allowed maxReceiveMessageSize (5 bytes)"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressedMsg := compressInput(tt.input)
output, numSliceInBuf, err := decompress(tt.compressor, compressedMsg, tt.maxReceiveMessageSize, mem.DefaultBufferPool())
var wantMsg mem.BufferSlice
if tt.want != nil {
wantMsg = mem.BufferSlice{mem.NewBuffer(&tt.want, nil)}
}
if tt.error != nil && err == nil {
t.Fatalf("decompress() error, got err=%v, want err=%v", err, tt.error)
}
if tt.error == nil && numSliceInBuf != wantMsg.Len() {
t.Fatalf("decompress() number of slices mismatch, got = %d, want = %d", numSliceInBuf, wantMsg.Len())
}
if diff := cmp.Diff(wantMsg.Materialize(), output.Materialize()); diff != "" {
t.Fatalf("Mismatch in output:\n%s", diff)
}

})
}
}
Loading