Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[202405] Add cert authorization with common name support. #322

Merged
merged 2 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions gnmi_server/clientCertAuth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gnmi

import (
"github.com/sonic-net/sonic-gnmi/common_utils"
"github.com/sonic-net/sonic-gnmi/swsscommon"
"github.com/golang/glog"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
Expand All @@ -10,7 +11,7 @@ import (
"google.golang.org/grpc/status"
)

func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
func ClientCertAuthenAndAuthor(ctx context.Context, serviceConfigTableName string) (context.Context, error) {
rc, ctx := common_utils.GetContext(ctx)
p, ok := peer.FromContext(ctx)
if !ok {
Expand All @@ -32,10 +33,44 @@ func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
return ctx, status.Error(codes.Unauthenticated, "invalid username in certificate common name.")
}

if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
if serviceConfigTableName != "" {
if err := PopulateAuthStructByCommonName(username, &rc.Auth, serviceConfigTableName); err != nil {
return ctx, err
}
} else {
if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
}
}

return ctx, nil
}

func PopulateAuthStructByCommonName(certCommonName string, auth *common_utils.AuthInfo, serviceConfigTableName string) error {
if serviceConfigTableName == "" {
return status.Errorf(codes.Unauthenticated, "Service config table name should not be empty")
}

var configDbConnector = swsscommon.NewConfigDBConnector()
defer swsscommon.DeleteConfigDBConnector_Native(configDbConnector.ConfigDBConnector_Native)
configDbConnector.Connect(false)

var fieldValuePairs = configDbConnector.Get_entry(serviceConfigTableName, certCommonName)
if fieldValuePairs.Size() > 0 {
if fieldValuePairs.Has_key("role") {
var role = fieldValuePairs.Get("role")
auth.Roles = []string{role}
}
} else {
glog.Warningf("Failed to retrieve cert common name mapping; %s", certCommonName)
}

swsscommon.DeleteFieldValueMap(fieldValuePairs)

if len(auth.Roles) == 0 {
return status.Errorf(codes.Unauthenticated, "Invalid cert cname:'%s', not a trusted cert common name.", certCommonName)
} else {
return nil
}
}
2 changes: 1 addition & 1 deletion gnmi_server/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (

func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error {
ctx := stream.Context()
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand Down
32 changes: 16 additions & 16 deletions gnmi_server/gnoi.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func KillOrRestartProcess(restart bool, serviceName string) error {
}

func (srv *Server) KillProcess(ctx context.Context, req *gnoi_system_pb.KillProcessRequest) (*gnoi_system_pb.KillProcessResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func RebootSystem(fileName string) error {
func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) {
fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp"

_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -102,7 +102,7 @@ func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest

// TODO: Support GNOI RebootStatus
func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -112,7 +112,7 @@ func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootS

// TODO: Support GNOI CancelReboot
func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -121,7 +121,7 @@ func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelR
}
func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -130,7 +130,7 @@ func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.Syste
}
func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -139,23 +139,23 @@ func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_sys
}
func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
log.V(1).Info("gNOI: SetPackage")
return status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
log.V(1).Info("gNOI: SwitchControlProcessor")
return nil, status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -192,7 +192,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe

}
func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s
}

func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -252,7 +252,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe
}

func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (
}

func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ
}

func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques
}

func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -379,7 +379,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest)
}

func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down
21 changes: 11 additions & 10 deletions gnmi_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Config struct {
EnableNativeWrite bool
ZmqPort string
IdleConnDuration int
ConfigTableName string
}

var AuthLock sync.Mutex
Expand Down Expand Up @@ -216,30 +217,30 @@ func (srv *Server) Port() int64 {
return srv.config.Port
}

func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, error) {
func authenticate(config *Config, ctx context.Context) (context.Context, error) {
var err error
success := false
rc, ctx := common_utils.GetContext(ctx)
if !UserAuth.Any() {
if !config.UserAuth.Any() {
//No Auth enabled
rc.Auth.AuthEnabled = false
return ctx, nil
}
rc.Auth.AuthEnabled = true
if UserAuth.Enabled("password") {
if config.UserAuth.Enabled("password") {
ctx, err = BasicAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("jwt") {
if !success && config.UserAuth.Enabled("jwt") {
_, ctx, err = JwtAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx)
if !success && config.UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx, config.ConfigTableName)
if err == nil {
success = true
}
Expand All @@ -258,7 +259,7 @@ func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, err
// Subscribe implements the gNMI Subscribe RPC.
func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error {
ctx := stream.Context()
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -343,7 +344,7 @@ func IsNativeOrigin(origin string) bool {
// Get implements the Get RPC in gNMI spec.
func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) {
common_utils.IncCounter(common_utils.GNMI_GET)
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_GET_FAIL)
return nil, err
Expand Down Expand Up @@ -449,7 +450,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode")
}
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, err
Expand Down Expand Up @@ -550,7 +551,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
}

func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) {
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading