Skip to content

Commit

Permalink
Thread through client IP and port to Lua context. #226
Browse files Browse the repository at this point in the history
  • Loading branch information
mofirouz committed Aug 6, 2018
1 parent 066cd7e commit f3d3857
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 46 deletions.
28 changes: 26 additions & 2 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"compress/flate"
"compress/gzip"

"github.com/dgrijalva/jwt-go"
"github.com/gofrs/uuid"
"github.com/golang/protobuf/jsonpb"
Expand All @@ -49,6 +50,7 @@ import (
"google.golang.org/grpc/credentials"
_ "google.golang.org/grpc/encoding/gzip" // enable gzip compression on server for grpc
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -256,8 +258,30 @@ func apiInterceptorFunc(logger *zap.Logger, config Config, runtimePool *RuntimeP
startNanos := time.Now().UTC().UnixNano()
span := trace.NewSpan(name, nil, trace.StartOptions{})

clientAddr := ""
clientIP := ""
clientPort := ""
md, _ := metadata.FromIncomingContext(ctx)
if ips := md.Get("x-forwarded-for"); len(ips) > 0 {
// look for gRPC-Gateway / LB header
clientAddr = strings.Split(ips[0], ",")[0]
} else if peerInfo, ok := peer.FromContext(ctx); ok {
// if missing, try to look up gRPC peer info
clientAddr = peerInfo.Addr.String()
}

clientAddr = strings.TrimSpace(clientAddr)
if host, port, err := net.SplitHostPort(clientAddr); err == nil {
clientIP = host
clientPort = port
} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
clientIP = clientAddr
} else {
logger.Debug("Could not extract client address from request.", zap.Error(err))
}

// Actual before hook function execution.
beforeHookResult, hookErr := invokeReqBeforeHook(logger, config, runtimePool, jsonpbMarshaler, jsonpbUnmarshaler, "", uid, username, expiry, info.FullMethod, req)
beforeHookResult, hookErr := invokeReqBeforeHook(logger, config, runtimePool, jsonpbMarshaler, jsonpbUnmarshaler, "", uid, username, expiry, clientIP, clientPort, info.FullMethod, req)

// Stats measurement end boundary.
span.End()
Expand All @@ -283,7 +307,7 @@ func apiInterceptorFunc(logger *zap.Logger, config Config, runtimePool *RuntimeP
span := trace.NewSpan(name, nil, trace.StartOptions{})

// Actual after hook function execution.
invokeReqAfterHook(logger, config, runtimePool, jsonpbMarshaler, "", uid, username, expiry, info.FullMethod, handlerResult)
invokeReqAfterHook(logger, config, runtimePool, jsonpbMarshaler, "", uid, username, expiry, info.FullMethod, clientIP, clientPort, handlerResult)

// Stats measurement end boundary.
span.End()
Expand Down
26 changes: 25 additions & 1 deletion server/api_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package server

import (
"net"
"strings"

"github.com/gofrs/uuid"
Expand All @@ -25,6 +26,7 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -71,7 +73,29 @@ func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error)
return nil, status.Error(codes.NotFound, "RPC function not found")
}

result, fnErr, code := runtime.InvokeFunction(ExecutionModeRPC, lf, queryParams, uid, username, expiry, "", in.Payload)
clientAddr := ""
clientIP := ""
clientPort := ""
md, _ := metadata.FromIncomingContext(ctx)
if ips := md.Get("x-forwarded-for"); len(ips) > 0 {
// look for gRPC-Gateway / LB header
clientAddr = strings.Split(ips[0], ",")[0]
} else if peerInfo, ok := peer.FromContext(ctx); ok {
// if missing, try to look up gRPC peer info
clientAddr = peerInfo.Addr.String()
}

clientAddr = strings.TrimSpace(clientAddr)
if host, port, err := net.SplitHostPort(clientAddr); err == nil {
clientIP = host
clientPort = port
} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
clientIP = clientAddr
} else {
s.logger.Debug("Could not extract client address from request.", zap.Error(err))
}

result, fnErr, code := runtime.InvokeFunction(ExecutionModeRPC, lf, queryParams, uid, username, expiry, "", clientIP, clientPort, in.Payload)
s.runtimePool.Put(runtime)

if fnErr != nil {
Expand Down
9 changes: 5 additions & 4 deletions server/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ import (
"fmt"

"context"
"strings"
"time"

"github.com/golang/protobuf/jsonpb"
"github.com/heroiclabs/nakama/rtapi"
"go.opencensus.io/stats"
"go.opencensus.io/tag"
"go.opencensus.io/trace"
"go.uber.org/zap"
"strings"
"time"
)

type Pipeline struct {
Expand Down Expand Up @@ -147,7 +148,7 @@ func (p *Pipeline) ProcessRequest(logger *zap.Logger, session Session, envelope
span := trace.NewSpan(name, nil, trace.StartOptions{})

// Actual before hook function execution.
hookResult, hookErr := invokeReqBeforeHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, p.jsonpbUnmarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), messageName, envelope)
hookResult, hookErr := invokeReqBeforeHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, p.jsonpbUnmarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), session.ClientIP(), session.ClientPort(), messageName, envelope)

// Stats measurement end boundary.
span.End()
Expand Down Expand Up @@ -202,7 +203,7 @@ func (p *Pipeline) ProcessRequest(logger *zap.Logger, session Session, envelope
span := trace.NewSpan(name, nil, trace.StartOptions{})

// Actual after hook function execution.
invokeReqAfterHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), messageName, envelope)
invokeReqAfterHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), session.ClientIP(), session.ClientPort(), messageName, envelope)

// Stats measurement end boundary.
span.End()
Expand Down
2 changes: 1 addition & 1 deletion server/pipeline_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (p *Pipeline) rpc(logger *zap.Logger, session Session, envelope *rtapi.Enve
return
}

result, fnErr, _ := runtime.InvokeFunction(ExecutionModeRPC, lf, nil, session.UserID().String(), session.Username(), session.Expiry(), session.ID().String(), rpcMessage.Payload)
result, fnErr, _ := runtime.InvokeFunction(ExecutionModeRPC, lf, nil, session.UserID().String(), session.Username(), session.Expiry(), session.ID().String(), session.ClientIP(), session.ClientPort(), rpcMessage.Payload)
p.runtimePool.Put(runtime)
if fnErr != nil {
logger.Error("Runtime RPC function caused an error", zap.String("id", rpcMessage.Id), zap.Error(fnErr))
Expand Down
4 changes: 2 additions & 2 deletions server/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ func (r *Runtime) GetCallback(e ExecutionMode, key string) *lua.LFunction {
return nil
}

func (r *Runtime) InvokeFunction(execMode ExecutionMode, fn *lua.LFunction, queryParams map[string][]string, uid string, username string, sessionExpiry int64, sid string, payload interface{}) (interface{}, error, codes.Code) {
ctx := NewLuaContext(r.vm, r.luaEnv, execMode, queryParams, uid, username, sessionExpiry, sid)
func (r *Runtime) InvokeFunction(execMode ExecutionMode, fn *lua.LFunction, queryParams map[string][]string, uid string, username string, sessionExpiry int64, sid string, clientIP string, clientPort string, payload interface{}) (interface{}, error, codes.Code) {
ctx := NewLuaContext(r.vm, r.luaEnv, execMode, queryParams, sessionExpiry, uid, username, sid, clientIP, clientPort)
var lv lua.LValue
if payload != nil {
lv = ConvertValue(r.vm, payload)
Expand Down
10 changes: 5 additions & 5 deletions server/runtime_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"google.golang.org/grpc/status"
)

func invokeReqBeforeHook(logger *zap.Logger, config Config, runtimePool *RuntimePool, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, sessionID string, uid uuid.UUID, username string, expiry int64, callbackID string, req interface{}) (interface{}, error) {
func invokeReqBeforeHook(logger *zap.Logger, config Config, runtimePool *RuntimePool, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, sessionID string, uid uuid.UUID, username string, expiry int64, clientIP string, clientPort string, callbackID string, req interface{}) (interface{}, error) {
id := strings.ToLower(callbackID)
if !runtimePool.HasCallback(ExecutionModeBefore, id) {
return req, nil
Expand Down Expand Up @@ -65,7 +65,7 @@ func invokeReqBeforeHook(logger *zap.Logger, config Config, runtimePool *Runtime
if uid != uuid.Nil {
userID = uid.String()
}
result, fnErr, code := runtime.InvokeFunction(ExecutionModeBefore, lf, nil, userID, username, expiry, sessionID, reqMap)
result, fnErr, code := runtime.InvokeFunction(ExecutionModeBefore, lf, nil, userID, username, expiry, sessionID, clientIP, clientPort, reqMap)
runtimePool.Put(runtime)

if fnErr != nil {
Expand Down Expand Up @@ -105,7 +105,7 @@ func invokeReqBeforeHook(logger *zap.Logger, config Config, runtimePool *Runtime
return reqProto, nil
}

func invokeReqAfterHook(logger *zap.Logger, config Config, runtimePool *RuntimePool, jsonpbMarshaler *jsonpb.Marshaler, sessionID string, uid uuid.UUID, username string, expiry int64, callbackID string, req interface{}) {
func invokeReqAfterHook(logger *zap.Logger, config Config, runtimePool *RuntimePool, jsonpbMarshaler *jsonpb.Marshaler, sessionID string, uid uuid.UUID, username string, expiry int64, clientIP string, clientPort string, callbackID string, req interface{}) {
id := strings.ToLower(callbackID)
if !runtimePool.HasCallback(ExecutionModeAfter, id) {
return
Expand Down Expand Up @@ -143,7 +143,7 @@ func invokeReqAfterHook(logger *zap.Logger, config Config, runtimePool *RuntimeP
if uid != uuid.Nil {
userID = uid.String()
}
_, fnErr, _ := runtime.InvokeFunction(ExecutionModeAfter, lf, nil, userID, username, expiry, sessionID, reqMap)
_, fnErr, _ := runtime.InvokeFunction(ExecutionModeAfter, lf, nil, userID, username, expiry, sessionID, clientIP, clientPort, reqMap)
runtimePool.Put(runtime)

if fnErr != nil {
Expand Down Expand Up @@ -176,7 +176,7 @@ func invokeMatchmakerMatchedHook(logger *zap.Logger, runtimePool *RuntimePool, e
return "", false
}

ctx := NewLuaContext(runtime.vm, runtime.luaEnv, ExecutionModeMatchmaker, nil, "", "", 0, "")
ctx := NewLuaContext(runtime.vm, runtime.luaEnv, ExecutionModeMatchmaker, nil, 0, "", "", "", "", "")

entriesTable := runtime.vm.CreateTable(len(entries), 0)
for i, entry := range entries {
Expand Down
32 changes: 25 additions & 7 deletions server/runtime_lua_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,31 @@ const (
__CTX_USERNAME = "username"
__CTX_USER_SESSION_EXP = "user_session_exp"
__CTX_SESSION_ID = "session_id"
__CTX_CLIENT_IP = "client_ip"
__CTX_CLIENT_PORT = "client_port"
__CTX_MATCH_ID = "match_id"
__CTX_MATCH_NODE = "match_node"
__CTX_MATCH_LABEL = "match_label"
__CTX_MATCH_TICK_RATE = "match_tick_rate"
)

func NewLuaContext(l *lua.LState, env *lua.LTable, mode ExecutionMode, queryParams map[string][]string, uid string, username string, sessionExpiry int64, sid string) *lua.LTable {
func NewLuaContext(l *lua.LState, env *lua.LTable, mode ExecutionMode, queryParams map[string][]string, sessionExpiry int64,
userID, username, sessionID, clientIP, clientPort string) *lua.LTable {
size := 3
if uid != "" {
if userID != "" {
size += 3
if sid != "" {
if sessionID != "" {
size++
}
}

if clientIP != "" {
size++
}
if clientPort != "" {
size++
}

lt := l.CreateTable(0, size)
lt.RawSetString(__CTX_ENV, env)
lt.RawSetString(__CTX_MODE, lua.LString(mode.String()))
Expand All @@ -82,15 +92,23 @@ func NewLuaContext(l *lua.LState, env *lua.LTable, mode ExecutionMode, queryPara
lt.RawSetString(__CTX_QUERY_PARAMS, ConvertValue(l, queryParams))
}

if uid != "" {
lt.RawSetString(__CTX_USER_ID, lua.LString(uid))
if userID != "" {
lt.RawSetString(__CTX_USER_ID, lua.LString(userID))
lt.RawSetString(__CTX_USERNAME, lua.LString(username))
lt.RawSetString(__CTX_USER_SESSION_EXP, lua.LNumber(sessionExpiry))
if sid != "" {
lt.RawSetString(__CTX_SESSION_ID, lua.LString(sid))
if sessionID != "" {
lt.RawSetString(__CTX_SESSION_ID, lua.LString(sessionID))
}
}

if clientIP != "" {
lt.RawSetString(__CTX_CLIENT_IP, lua.LString(clientIP))
}

if clientPort != "" {
lt.RawSetString(__CTX_CLIENT_PORT, lua.LString(clientPort))
}

return lt
}

Expand Down
3 changes: 2 additions & 1 deletion server/runtime_nakama_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"crypto/sha256"

"crypto/md5"

"github.com/gofrs/uuid"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/golang/protobuf/ptypes/wrappers"
Expand Down Expand Up @@ -327,7 +328,7 @@ func (n *NakamaModule) runOnce(l *lua.LState) int {
return
}

ctx := NewLuaContext(l, ConvertMap(l, n.config.GetRuntime().Environment), ExecutionModeRunOnce, nil, "", "", 0, "")
ctx := NewLuaContext(l, ConvertMap(l, n.config.GetRuntime().Environment), ExecutionModeRunOnce, nil, 0, "", "", "", "", "")

l.Push(LSentinel)
l.Push(fn)
Expand Down
2 changes: 2 additions & 0 deletions server/session_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type Session interface {
Logger() *zap.Logger
ID() uuid.UUID
UserID() uuid.UUID
ClientIP() string
ClientPort() string

Username() string
SetUsername(string)
Expand Down
38 changes: 25 additions & 13 deletions server/session_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ var ErrSessionQueueFull = errors.New("session outgoing queue full")

type sessionWS struct {
sync.Mutex
logger *zap.Logger
config Config
id uuid.UUID
userID uuid.UUID
username *atomic.String
expiry int64
logger *zap.Logger
config Config
id uuid.UUID
userID uuid.UUID
username *atomic.String
expiry int64
clientIP string
clientPort string

jsonpbMarshaler *jsonpb.Marshaler
jsonpbUnmarshaler *jsonpb.Unmarshaler
Expand All @@ -61,19 +63,21 @@ type sessionWS struct {
outgoingStopCh chan struct{}
}

func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username string, expiry int64, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username string, expiry int64, clientIP string, clientPort string, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
sessionID := uuid.Must(uuid.NewV4())
sessionLogger := logger.With(zap.String("uid", userID.String()), zap.String("sid", sessionID.String()))

sessionLogger.Info("New WebSocket session connected")

return &sessionWS{
logger: sessionLogger,
config: config,
id: sessionID,
userID: userID,
username: atomic.NewString(username),
expiry: expiry,
logger: sessionLogger,
config: config,
id: sessionID,
userID: userID,
username: atomic.NewString(username),
expiry: expiry,
clientIP: clientIP,
clientPort: clientPort,

jsonpbMarshaler: jsonpbMarshaler,
jsonpbUnmarshaler: jsonpbUnmarshaler,
Expand Down Expand Up @@ -107,6 +111,14 @@ func (s *sessionWS) UserID() uuid.UUID {
return s.userID
}

func (s *sessionWS) ClientIP() string {
return s.clientIP
}

func (s *sessionWS) ClientPort() string {
return s.clientPort
}

func (s *sessionWS) Username() string {
return s.username.Load()
}
Expand Down
Loading

0 comments on commit f3d3857

Please sign in to comment.