diff --git a/gnmi_server/clientCertAuth.go b/gnmi_server/clientCertAuth.go index 1c44d9c5..8a7d1f37 100644 --- a/gnmi_server/clientCertAuth.go +++ b/gnmi_server/clientCertAuth.go @@ -2,6 +2,7 @@ package gnmi import ( "github.com/sonic-net/sonic-gnmi/common_utils" + "github.com/sonic-net/sonic-gnmi/swsscommon" "github.com/golang/glog" "golang.org/x/net/context" "google.golang.org/grpc/codes" @@ -10,7 +11,7 @@ import ( "google.golang.org/grpc/status" ) -func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) { +func ClientCertAuthenAndAuthor(ctx context.Context, serviceConfigTableName string) (context.Context, error) { rc, ctx := common_utils.GetContext(ctx) p, ok := peer.FromContext(ctx) if !ok { @@ -32,10 +33,44 @@ func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) { return ctx, status.Error(codes.Unauthenticated, "invalid username in certificate common name.") } - if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil { - glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err) - return ctx, status.Errorf(codes.Unauthenticated, "") + if serviceConfigTableName != "" { + if err := PopulateAuthStructByCommonName(username, &rc.Auth, serviceConfigTableName); err != nil { + return ctx, err + } + } else { + if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil { + glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err) + return ctx, status.Errorf(codes.Unauthenticated, "") + } } return ctx, nil } + +func PopulateAuthStructByCommonName(certCommonName string, auth *common_utils.AuthInfo, serviceConfigTableName string) error { + if serviceConfigTableName == "" { + return status.Errorf(codes.Unauthenticated, "Service config table name should not be empty") + } + + var configDbConnector = swsscommon.NewConfigDBConnector() + defer swsscommon.DeleteConfigDBConnector_Native(configDbConnector.ConfigDBConnector_Native) + configDbConnector.Connect(false) + + var fieldValuePairs = configDbConnector.Get_entry(serviceConfigTableName, certCommonName) + if fieldValuePairs.Size() > 0 { + if fieldValuePairs.Has_key("role") { + var role = fieldValuePairs.Get("role") + auth.Roles = []string{role} + } + } else { + glog.Warningf("Failed to retrieve cert common name mapping; %s", certCommonName) + } + + swsscommon.DeleteFieldValueMap(fieldValuePairs) + + if len(auth.Roles) == 0 { + return status.Errorf(codes.Unauthenticated, "Invalid cert cname:'%s', not a trusted cert common name.", certCommonName) + } else { + return nil + } +} diff --git a/gnmi_server/debug.go b/gnmi_server/debug.go index 5239b72e..6099630e 100644 --- a/gnmi_server/debug.go +++ b/gnmi_server/debug.go @@ -35,7 +35,7 @@ import ( func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error { ctx := stream.Context() - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return err } diff --git a/gnmi_server/gnoi.go b/gnmi_server/gnoi.go index 2023746b..0f519214 100644 --- a/gnmi_server/gnoi.go +++ b/gnmi_server/gnoi.go @@ -42,7 +42,7 @@ func KillOrRestartProcess(restart bool, serviceName string) error { } func (srv *Server) KillProcess(ctx context.Context, req *gnoi_system_pb.KillProcessRequest) (*gnoi_system_pb.KillProcessResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -78,7 +78,7 @@ func RebootSystem(fileName string) error { func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) { fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp" - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest // TODO: Support GNOI RebootStatus func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootS // TODO: Support GNOI CancelReboot func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -121,7 +121,7 @@ func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelR } func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error { ctx := rs.Context() - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return err } @@ -130,7 +130,7 @@ func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.Syste } func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error { ctx := rs.Context() - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return err } @@ -139,7 +139,7 @@ func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_sys } func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error { ctx := rs.Context() - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return err } @@ -147,7 +147,7 @@ func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error { return status.Errorf(codes.Unimplemented, "") } func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -155,7 +155,7 @@ func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_ return nil, status.Errorf(codes.Unimplemented, "") } func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -192,7 +192,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe } func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -220,7 +220,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s } func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -252,7 +252,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe } func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -283,7 +283,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) ( } func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -315,7 +315,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ } func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -347,7 +347,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques } func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -379,7 +379,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) } func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } diff --git a/gnmi_server/server.go b/gnmi_server/server.go index 770b27ab..c21f36d6 100644 --- a/gnmi_server/server.go +++ b/gnmi_server/server.go @@ -68,6 +68,7 @@ type Config struct { EnableNativeWrite bool ZmqPort string IdleConnDuration int + ConfigTableName string } var AuthLock sync.Mutex @@ -216,30 +217,30 @@ func (srv *Server) Port() int64 { return srv.config.Port } -func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, error) { +func authenticate(config *Config, ctx context.Context) (context.Context, error) { var err error success := false rc, ctx := common_utils.GetContext(ctx) - if !UserAuth.Any() { + if !config.UserAuth.Any() { //No Auth enabled rc.Auth.AuthEnabled = false return ctx, nil } rc.Auth.AuthEnabled = true - if UserAuth.Enabled("password") { + if config.UserAuth.Enabled("password") { ctx, err = BasicAuthenAndAuthor(ctx) if err == nil { success = true } } - if !success && UserAuth.Enabled("jwt") { + if !success && config.UserAuth.Enabled("jwt") { _, ctx, err = JwtAuthenAndAuthor(ctx) if err == nil { success = true } } - if !success && UserAuth.Enabled("cert") { - ctx, err = ClientCertAuthenAndAuthor(ctx) + if !success && config.UserAuth.Enabled("cert") { + ctx, err = ClientCertAuthenAndAuthor(ctx, config.ConfigTableName) if err == nil { success = true } @@ -258,7 +259,7 @@ func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, err // Subscribe implements the gNMI Subscribe RPC. func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error { ctx := stream.Context() - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { return err } @@ -343,7 +344,7 @@ func IsNativeOrigin(origin string) bool { // Get implements the Get RPC in gNMI spec. func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) { common_utils.IncCounter(common_utils.GNMI_GET) - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { common_utils.IncCounter(common_utils.GNMI_GET_FAIL) return nil, err @@ -449,7 +450,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode") } - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, err @@ -550,7 +551,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe } func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) { - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { return nil, err } diff --git a/gnmi_server/server_test.go b/gnmi_server/server_test.go index 80ab79e5..4ec9e71f 100644 --- a/gnmi_server/server_test.go +++ b/gnmi_server/server_test.go @@ -20,8 +20,12 @@ import ( "time" "unsafe" + "crypto/x509" + "crypto/x509/pkix" + spb "github.com/sonic-net/sonic-gnmi/proto" sgpb "github.com/sonic-net/sonic-gnmi/proto/gnoi" + spb_jwt "github.com/sonic-net/sonic-gnmi/proto/gnoi/jwt" sdc "github.com/sonic-net/sonic-gnmi/sonic_data_client" sdcfg "github.com/sonic-net/sonic-gnmi/sonic_db_config" ssc "github.com/sonic-net/sonic-gnmi/sonic_service_client" @@ -42,6 +46,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/peer" "google.golang.org/grpc/status" // Register supported client types. @@ -55,6 +60,7 @@ import ( gnmipb "github.com/openconfig/gnmi/proto/gnmi" gnoi_system_pb "github.com/openconfig/gnoi/system" "github.com/sonic-net/sonic-gnmi/common_utils" + "github.com/sonic-net/sonic-gnmi/swsscommon" ) var clientTypes = []string{gclient.Type} @@ -4199,6 +4205,187 @@ func TestSaveOnSet(t *testing.T) { } } +func TestPopulateAuthStructByCommonName(t *testing.T) { + // check auth with nil cert name + err := PopulateAuthStructByCommonName("certname1", nil, "") + if err == nil { + t.Errorf("PopulateAuthStructByCommonName with empty config table should failed: %v", err) + } +} + +func CreateAuthorizationCtx() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + cert := x509.Certificate{ + Subject: pkix.Name{ + CommonName: "certname1", + }, + } + verifiedCerts := make([][]*x509.Certificate, 1) + verifiedCerts[0] = make([]*x509.Certificate, 1) + verifiedCerts[0][0] = &cert + p := peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + VerifiedChains: verifiedCerts, + }, + }, + } + ctx = peer.NewContext(ctx, &p) + return ctx, cancel +} + + func TestClientCertAuthenAndAuthor(t *testing.T) { + if !swsscommon.SonicDBConfigIsInit() { + swsscommon.SonicDBConfigInitialize() + } + + var configDb = swsscommon.NewDBConnector("CONFIG_DB", uint(0), true) + var gnmiTable = swsscommon.NewTable(configDb, "GNMI_CLIENT_CERT") + configDb.Flushdb() + + // initialize err variable + err := status.Error(codes.Unauthenticated, "") + + // when config table is empty, will authorize with PopulateAuthStruct + mockpopulate := gomonkey.ApplyFunc(PopulateAuthStruct, func(username string, auth *common_utils.AuthInfo, r []string) error { + return nil + }) + defer mockpopulate.Reset() + + // check auth with nil cert name + ctx, cancel := CreateAuthorizationCtx() + ctx, err = ClientCertAuthenAndAuthor(ctx, "") + if err != nil { + t.Errorf("CommonNameMatch with empty config table should success: %v", err) + } + + cancel() + + // check get 1 cert name + ctx, cancel = CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname1", "role", "role1") + ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT") + if err != nil { + t.Errorf("CommonNameMatch with correct cert name should success: %v", err) + } + + cancel() + + // check get multiple cert names + ctx, cancel = CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname1", "role", "role1") + gnmiTable.Hset("certname2", "role", "role2") + ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT") + if err != nil { + t.Errorf("CommonNameMatch with correct cert name should success: %v", err) + } + + cancel() + + // check a invalid cert cname + ctx, cancel = CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname2", "role", "role2") + ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT") + if err == nil { + t.Errorf("CommonNameMatch with invalid cert name should fail: %v", err) + } + + cancel() + + swsscommon.DeleteTable(gnmiTable) + swsscommon.DeleteDBConnector(configDb) +} + +type MockServerStream struct { + grpc.ServerStream +} + +func (x *MockServerStream) Context() context.Context { + return context.Background() +} + +type MockPingServer struct { + MockServerStream +} + +func (x *MockPingServer) Send(m *gnoi_system_pb.PingResponse) error { + return nil +} + +type MockTracerouteServer struct { + MockServerStream +} + +func (x *MockTracerouteServer) Send(m *gnoi_system_pb.TracerouteResponse) error { + return nil +} + +type MockSetPackageServer struct { + MockServerStream +} + +func (x *MockSetPackageServer) Send(m *gnoi_system_pb.SetPackageResponse) error { + return nil +} + +func (x *MockSetPackageServer) SendAndClose(m *gnoi_system_pb.SetPackageResponse) error { + return nil +} + +func (x *MockSetPackageServer) Recv() (*gnoi_system_pb.SetPackageRequest, error) { + return nil, nil +} + +func TestGnoiAuthorization(t *testing.T) { + s := createServer(t, 8081) + go runServer(t, s) + mockAuthenticate := gomonkey.ApplyFunc(s.Authenticate, func(ctx context.Context, req *spb_jwt.AuthenticateRequest) (*spb_jwt.AuthenticateResponse, error) { + return nil, nil + }) + defer mockAuthenticate.Reset() + + err := s.Ping(new(gnoi_system_pb.PingRequest), new(MockPingServer)) + if err == nil { + t.Errorf("Ping should failed, because not implement.") + } + + s.Traceroute(new(gnoi_system_pb.TracerouteRequest), new(MockTracerouteServer)) + if err == nil { + t.Errorf("Traceroute should failed, because not implement.") + } + + s.SetPackage(new(MockSetPackageServer)) + if err == nil { + t.Errorf("SetPackage should failed, because not implement.") + } + + ctx := context.Background() + s.SwitchControlProcessor(ctx, new(gnoi_system_pb.SwitchControlProcessorRequest)) + if err == nil { + t.Errorf("SwitchControlProcessor should failed, because not implement.") + } + + s.Refresh(ctx, new(spb_jwt.RefreshRequest)) + if err == nil { + t.Errorf("Refresh should failed, because not implement.") + } + + s.ClearNeighbors(ctx, new(sgpb.ClearNeighborsRequest)) + if err == nil { + t.Errorf("ClearNeighbors should failed, because not implement.") + } + + s.CopyConfig(ctx, new(sgpb.CopyConfigRequest)) + if err == nil { + t.Errorf("CopyConfig should failed, because not implement.") + } + + s.Stop() +} + func init() { // Enable logs at UT setup flag.Lookup("v").Value.Set("10") diff --git a/telemetry/telemetry.go b/telemetry/telemetry.go index e6b47261..89fe3f3c 100644 --- a/telemetry/telemetry.go +++ b/telemetry/telemetry.go @@ -43,6 +43,7 @@ type TelemetryConfig struct { CaCert *string ServerCert *string ServerKey *string + ConfigTableName *string ZmqAddress *string ZmqPort *string Insecure *bool @@ -150,6 +151,7 @@ func setupFlags(fs *flag.FlagSet) (*TelemetryConfig, *gnmi.Config, error) { CaCert: fs.String("ca_crt", "", "CA certificate for client certificate validation. Optional."), ServerCert: fs.String("server_crt", "", "TLS server certificate"), ServerKey: fs.String("server_key", "", "TLS server private key"), + ConfigTableName: fs.String("config_table_name", "", "Config table name"), ZmqAddress: fs.String("zmq_address", "", "Orchagent ZMQ address, deprecated, please use zmq_port."), ZmqPort: fs.String("zmq_port", "", "Orchagent ZMQ port, when not set or empty string telemetry server will switch to Redis based communication channel."), Insecure: fs.Bool("insecure", false, "Skip providing TLS cert and key, for testing only!"), @@ -224,6 +226,7 @@ func setupFlags(fs *flag.FlagSet) (*TelemetryConfig, *gnmi.Config, error) { cfg.LogLevel = int(*telemetryCfg.LogLevel) cfg.Threshold = int(*telemetryCfg.Threshold) cfg.IdleConnDuration = int(*telemetryCfg.IdleConnDuration) + cfg.ConfigTableName = *telemetryCfg.ConfigTableName // TODO: After other dependent projects are migrated to ZmqPort, remove ZmqAddress zmqAddress := *telemetryCfg.ZmqAddress diff --git a/telemetry/telemetry_test.go b/telemetry/telemetry_test.go index 5bb8d281..acc79a47 100644 --- a/telemetry/telemetry_test.go +++ b/telemetry/telemetry_test.go @@ -384,12 +384,16 @@ func TestSHA512Checksum(t *testing.T) { }() fs := flag.NewFlagSet("testStartGNMIServer", flag.ContinueOnError) - os.Args = []string{"cmd", "-port", "8080", "-server_crt", testServerCert, "-server_key", testServerKey} + os.Args = []string{"cmd", "-port", "8080", "-server_crt", testServerCert, "-server_key", testServerKey, "-config_table_name", "GNMI_CLIENT_CERT"} telemetryCfg, cfg, err := setupFlags(fs) if err != nil { t.Errorf("Expected err to be nil, got err %v", err) } + if cfg.ConfigTableName != "GNMI_CLIENT_CERT" { + t.Errorf("Expected err to be GNMI_CLIENT_CERT, got %s", cfg.ConfigTableName) + } + err = saveCertKeyPair(testServerCert, testServerKey) if err != nil {