Skip to content

Commit

Permalink
incoming interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
mkysel committed Dec 20, 2024
1 parent f223c36 commit 23220d5
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 4 deletions.
18 changes: 14 additions & 4 deletions pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,20 @@ func NewAPIServer(
return nil, err
}

unary := []grpc.UnaryServerInterceptor{prometheus.UnaryServerInterceptor}
stream := []grpc.StreamServerInterceptor{prometheus.StreamServerInterceptor}
incomingInterceptor, err := server.NewIncomingInterceptor(log)
if err != nil {
return nil, err
}
unary := []grpc.UnaryServerInterceptor{
prometheus.UnaryServerInterceptor,
loggingInterceptor.Unary(),
incomingInterceptor.Unary(),
}
stream := []grpc.StreamServerInterceptor{
prometheus.StreamServerInterceptor,
loggingInterceptor.Stream(),
incomingInterceptor.Stream(),
}

options := []grpc.ServerOption{
grpc.ChainUnaryInterceptor(unary...),
Expand All @@ -81,8 +93,6 @@ func NewAPIServer(
PermitWithoutStream: true,
MinTime: 15 * time.Second,
}),
grpc.ChainUnaryInterceptor(loggingInterceptor.Unary()),
grpc.ChainStreamInterceptor(loggingInterceptor.Stream()),

// grpc.MaxRecvMsgSize(s.Config.Options.MaxMsgSize),
}
Expand Down
75 changes: 75 additions & 0 deletions pkg/interceptors/server/incoming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package server

import (
"context"
"fmt"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"net"
)

type IncomingInterceptor struct {
logger *zap.Logger
}

func NewIncomingInterceptor(logger *zap.Logger) (*IncomingInterceptor, error) {
if logger == nil {
return nil, fmt.Errorf("logger is required")
}

return &IncomingInterceptor{
logger: logger,
}, nil
}

func (i *IncomingInterceptor) logIncomingAddressIfAvailable(ctx context.Context) {
if i.logger.Core().Enabled(zap.DebugLevel) {
if p, ok := peer.FromContext(ctx); ok {
clientAddr := p.Addr.String()
var dnsName []string
// Attempt to resolve the DNS name
host, _, err := net.SplitHostPort(clientAddr)
if err == nil {
dnsName, err = net.LookupAddr(host)
if err != nil || len(dnsName) == 0 {
dnsName = []string{"Unknown"}
}
} else {
dnsName = []string{"Unknown"}
}
i.logger.Debug(
fmt.Sprintf("Incoming request from %s (DNS: %s)", clientAddr, dnsName[0]),
)
}
}
}

// Unary intercepts unary RPC calls to log errors.
func (i *IncomingInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
i.logIncomingAddressIfAvailable(ctx)

// Call the handler to complete the RPC
return handler(ctx, req)
}
}

// Stream intercepts stream RPC calls to log errors.
func (i *IncomingInterceptor) Stream() grpc.StreamServerInterceptor {
return func(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
i.logIncomingAddressIfAvailable(ss.Context())
// Call the handler to complete the RPC
return handler(srv, ss)
}
}
80 changes: 80 additions & 0 deletions pkg/interceptors/server/incoming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package server

import (
"context"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"net"
"testing"
)

func mockUnaryHandler(ctx context.Context, req interface{}) (interface{}, error) {
return "response", nil
}

func mockStreamHandler(srv interface{}, ss grpc.ServerStream) error {
return nil
}

type mockServerStreamWithContext struct {
grpc.ServerStream
ctx context.Context
}

func (m *mockServerStreamWithContext) Context() context.Context {
return m.ctx
}

func TestIncomingInterceptor_Unary(t *testing.T) {
logger, logs := createTestLogger()

interceptor, err := NewIncomingInterceptor(logger)
if err != nil {
t.Fatalf("failed to create interceptor: %v", err)
}

ctx := peer.NewContext(context.Background(), &peer.Peer{
Addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345},
})

// Call the unary interceptor
_, _ = interceptor.Unary()(ctx, nil, nil, mockUnaryHandler)

require.NoError(t, err)
require.Equal(t, 1, logs.Len(), "expected one log entry but got none")

logEntry := logs.All()[0]

require.Equal(t, zapcore.DebugLevel, logEntry.Level)
require.Contains(t, logEntry.Message, "Incoming request")

}

func TestIncomingInterceptor_Stream(t *testing.T) {
logger, logs := createTestLogger()

interceptor, err := NewIncomingInterceptor(logger)
if err != nil {
t.Fatalf("failed to create interceptor: %v", err)
}

ctx := peer.NewContext(context.Background(), &peer.Peer{
Addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345},
})

// Create a mock server stream
stream := &mockServerStreamWithContext{ctx: ctx}

// Call the stream interceptor
_ = interceptor.Stream()(nil, stream, nil, mockStreamHandler)

require.NoError(t, err)
require.Equal(t, 1, logs.Len(), "expected one log entry but got none")

logEntry := logs.All()[0]

require.Equal(t, zapcore.DebugLevel, logEntry.Level)
require.Contains(t, logEntry.Message, "Incoming request")
}

0 comments on commit 23220d5

Please sign in to comment.