diff --git a/systemtest/profiling_test.go b/systemtest/profiling_test.go index 52dff433f79..e27a2fa360e 100644 --- a/systemtest/profiling_test.go +++ b/systemtest/profiling_test.go @@ -111,7 +111,9 @@ func TestProfiling(t *testing.T) { ctx := metadata.AppendToOutgoingContext(context.Background(), "secretToken", secretToken, "projectID", "123", - "hostID", "abc123") + "hostID", "abc123", + "rpcVersion", "1", + ) // We always insert 2 elements in KV indices for each test RPC below. // All RPCs use a columnar format, where arrays of fields are bundled diff --git a/x-pack/apm-server/profiling/auth.go b/x-pack/apm-server/profiling/auth.go index f0448ce9ace..0a62b44d8b5 100644 --- a/x-pack/apm-server/profiling/auth.go +++ b/x-pack/apm-server/profiling/auth.go @@ -11,11 +11,26 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "github.com/elastic/apm-server/internal/beater/auth" "github.com/elastic/apm-server/internal/beater/headers" ) +var ( + // This is the version of the gRPC protocol specified in collection_agent.proto. + // It's retrieved and cached in NewCollector and used to check for protocol mismatches + // against the protocol version that the client sends. + rpcProtocolVersion uint32 +) + +// GetRPCVersion returns the version of the RPC protocol +func GetRPCVersion() uint32 { + // Retrieve protocol version defined in the .proto file + options := File_collection_agent_proto.Options() + return proto.GetExtension(options, E_Version).(uint32) +} + // AuthenticateUnaryCall implements the interceptors.UnaryAuthenticator // interface, extracting the secret token supplied by the Host Agent, // which we will treat the same as an APM secret token. @@ -27,29 +42,62 @@ func (e *ElasticCollector) AuthenticateUnaryCall( ) (auth.AuthenticationDetails, auth.Authorizer, error) { md, _ := metadata.FromIncomingContext(ctx) secretToken := GetFirstOrEmpty(md, MetadataKeySecretToken) - projectID := GetFirstOrEmpty(md, MetadataKeyProjectID) - hostID := GetFirstOrEmpty(md, MetadataKeyHostID) + projectIDStr := GetFirstOrEmpty(md, MetadataKeyProjectID) + hostIDStr := GetFirstOrEmpty(md, MetadataKeyHostID) + rpcVersionStr := GetFirstOrEmpty(md, MetadataKeyRPCVersion) if secretToken == "" { - return auth.AuthenticationDetails{}, nil, status.Errorf(codes.Unauthenticated, "secret token is missing") + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "secret token is missing") + } + if projectIDStr == "" { + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "project ID is missing") } - if projectID == "" { - return auth.AuthenticationDetails{}, nil, status.Errorf(codes.Unauthenticated, "project ID is missing") + if hostIDStr == "" { + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "host ID is missing") } - if hostID == "" { - return auth.AuthenticationDetails{}, nil, status.Errorf(codes.Unauthenticated, "host ID is missing") + if rpcVersionStr == "" { + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "RPC version is missing") } - if _, err := strconv.Atoi(projectID); err != nil { + if _, err := strconv.ParseUint(projectIDStr, 10, 32); err != nil { e.logger.Errorf("possible malicious client request, "+ - "converting project ID from string (%s) to uint failed: %v", projectID, err) - return auth.AuthenticationDetails{}, nil, auth.ErrAuthFailed + "converting project ID from string (%s) to uint failed: %v", projectIDStr, err) + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "invalid project ID") } - if _, err := strconv.ParseUint(hostID, 16, 64); err != nil { + if _, err := strconv.ParseUint(hostIDStr, 16, 64); err != nil { e.logger.Errorf("possible malicious client request, "+ - "converting host ID from string (%s) to uint failed: %v", hostID, err) - return auth.AuthenticationDetails{}, nil, auth.ErrAuthFailed + "converting host ID from string (%s) to uint failed: %v", hostIDStr, err) + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "invalid host ID") + } + + rpcVersion64, err := strconv.ParseUint(rpcVersionStr, 10, 32) + if err != nil { + e.logger.Errorf("converting RPC version from string (%s) to uint failed: %v", + rpcVersionStr, err) + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, "invalid RPC version") + } + + rpcVersion := uint32(rpcVersion64) + if rpcVersion != rpcProtocolVersion { + e.logger.Errorf("incompatible RPC version: %d => %d", rpcVersion, rpcProtocolVersion) + + if rpcVersion < rpcProtocolVersion { + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, + "HostAgent version is unsupported, please upgrade to the latest version") + } + + return auth.AuthenticationDetails{}, nil, + status.Errorf(codes.FailedPrecondition, + "Backend is incompatible with HostAgent, please check your configuration") } return authenticator.Authenticate(ctx, headers.Bearer, secretToken) diff --git a/x-pack/apm-server/profiling/collector.go b/x-pack/apm-server/profiling/collector.go index 97a69cc4755..53cbe335650 100644 --- a/x-pack/apm-server/profiling/collector.go +++ b/x-pack/apm-server/profiling/collector.go @@ -77,9 +77,10 @@ type ElasticCollector struct { clusterID string } -// NewCollector returns a new ElasticCollector uses indexer for storing stack trace data in -// Elasticsearch, and metricsIndexer for storing host agent metrics. Separate indexers are -// used to allow for host agent metrics to be sent to a separate monitoring cluster. +// NewCollector returns a new ElasticCollector which uses indexer for storing stack trace +// data in Elasticsearch, and metricsIndexer for storing host agent metrics. Separate +// indexers are used to allow for host agent metrics to be sent to a separate monitoring +// cluster. func NewCollector( indexer esutil.BulkIndexer, metricsIndexer esutil.BulkIndexer, @@ -104,6 +105,8 @@ func NewCollector( c.indexes[i] = fmt.Sprintf("%s-%dpow%02d", common.EventsIndexPrefix, common.SamplingFactor, i+1) } + + rpcProtocolVersion = GetRPCVersion() return c } diff --git a/x-pack/apm-server/profiling/grpcext.go b/x-pack/apm-server/profiling/grpcext.go index c0997b5005c..9490432d0e5 100644 --- a/x-pack/apm-server/profiling/grpcext.go +++ b/x-pack/apm-server/profiling/grpcext.go @@ -20,6 +20,7 @@ const ( MetadataKeyHostname = "hostname" MetadataKeyKernelVersion = "kernelVersion" MetadataKeyHostID = "hostID" + MetadataKeyRPCVersion = "rpcVersion" // Tags will be auto base64 encoded/decoded MetadataKeyTags = "tags-bin" ) @@ -36,8 +37,8 @@ func GetProjectID(ctx context.Context) uint32 { // Metadata and host ID have been validated in auth interceptor, // no need to error check here. md, _ := metadata.FromIncomingContext(ctx) - projectIDs := GetFirstOrEmpty(md, MetadataKeyProjectID) - projectID64, _ := strconv.Atoi(projectIDs) + projectIDStr := GetFirstOrEmpty(md, MetadataKeyProjectID) + projectID64, _ := strconv.ParseUint(projectIDStr, 10, 32) return uint32(projectID64) } @@ -45,8 +46,8 @@ func GetHostID(ctx context.Context) uint64 { // Metadata and host ID have been validated in auth interceptor, // no need to error check here. md, _ := metadata.FromIncomingContext(ctx) - hostIDs := GetFirstOrEmpty(md, MetadataKeyHostID) - hostID, _ := strconv.ParseUint(hostIDs, 16, 64) + hostIDStr := GetFirstOrEmpty(md, MetadataKeyHostID) + hostID, _ := strconv.ParseUint(hostIDStr, 16, 64) return hostID }