From 5466e7620c3945b8bcd5f3d664fb176f2962adfb Mon Sep 17 00:00:00 2001 From: Alexander Lazarev Date: Wed, 15 Nov 2023 15:03:10 +0300 Subject: [PATCH 1/3] added crypt keys --- cmd/key/main.go | 45 ++++++++++++++++++++++++++++ internal/pkg/agent/config/config.go | 2 ++ internal/pkg/agent/encrypt.go | 32 ++++++++++++++++++++ internal/pkg/agent/reporter.go | 15 +++++++--- internal/pkg/agent/reporter_test.go | 3 +- internal/pkg/middleware/crypt.go | 43 ++++++++++++++++++++++++++ internal/pkg/server/config/config.go | 2 ++ internal/pkg/server/server.go | 1 + 8 files changed, 138 insertions(+), 5 deletions(-) create mode 100644 cmd/key/main.go create mode 100644 internal/pkg/agent/encrypt.go diff --git a/cmd/key/main.go b/cmd/key/main.go new file mode 100644 index 0000000..ae41f4a --- /dev/null +++ b/cmd/key/main.go @@ -0,0 +1,45 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "log" + "os" +) + +func main() { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatal(err) + } + + publicKey := &privateKey.PublicKey + privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "SERVER PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + err = os.WriteFile("./private.pem", privateKeyPEM, 0644) + if err != nil { + log.Fatal(err) + } + + publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + log.Fatal(err) + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "AGENT PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + err = os.WriteFile("./public.pem", publicKeyPEM, 0644) + if err != nil { + log.Fatal(err) + } +} diff --git a/internal/pkg/agent/config/config.go b/internal/pkg/agent/config/config.go index 28bcc56..2cd9c46 100644 --- a/internal/pkg/agent/config/config.go +++ b/internal/pkg/agent/config/config.go @@ -13,6 +13,7 @@ type AgentConfig struct { ReportInterval int `env:"REPORT_INTERVAL"` PollInterval int `env:"POLL_INTERVAL"` RateLimit int `env:"RATE_LIMIT"` + PublicKey string `env:"CRYPTO_KEY"` SignKeyByte []byte } @@ -44,5 +45,6 @@ func (c *AgentConfig) init() { flag.IntVar(&c.PollInterval, "p", pollIntervalDefault, "Interval of poll metric") flag.StringVar(&c.SignKey, "k", "", "Server key") flag.IntVar(&c.RateLimit, "l", rateLimitDefault, "Rate limit") + flag.StringVar(&c.PublicKey, "-crypto-key", "", "Public key path") flag.Parse() } diff --git a/internal/pkg/agent/encrypt.go b/internal/pkg/agent/encrypt.go new file mode 100644 index 0000000..c26e912 --- /dev/null +++ b/internal/pkg/agent/encrypt.go @@ -0,0 +1,32 @@ +package agent + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" +) + +func encrypt(publicKeyPath string, data []byte) ([]byte, error) { + publicKeyPEM, err := os.ReadFile(publicKeyPath) + if err != nil { + return nil, err + } + publicKeyBlock, _ := pem.Decode(publicKeyPEM) + if publicKeyBlock == nil { + return nil, err + } + + publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes) + if err != nil { + return nil, err + } + + ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey.(*rsa.PublicKey), data) + if err != nil { + return nil, err + } + + return ciphertext, nil +} diff --git a/internal/pkg/agent/reporter.go b/internal/pkg/agent/reporter.go index 53ef4a6..3355794 100644 --- a/internal/pkg/agent/reporter.go +++ b/internal/pkg/agent/reporter.go @@ -24,7 +24,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age case <-ctx.Done(): return case <-reportTicker.C: - ok, err := SendMetrics(ctx, s, c.ServerAddress, c.SignKeyByte) + ok, err := SendMetrics(ctx, s, c.ServerAddress, c.SignKeyByte, c.PublicKey) if err != nil { logrus.Errorf("Error send metrics %v", err) } @@ -37,7 +37,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age } } -func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, signKey []byte) (bool, error) { +func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, signKey []byte, publicKey string) (bool, error) { metricsMap, err := s.GetMetrics(ctx) if err != nil { logrus.Errorf("Some error ocured during metrics get: %q", err) @@ -51,18 +51,25 @@ func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, sig url := fmt.Sprintf("http://%s/updates/", serverAddress) - if err = SendBatchJSON(url, metricsBatch, signKey); err != nil { + if err = SendBatchJSON(url, metricsBatch, signKey, publicKey); err != nil { return false, fmt.Errorf("error create post request %w", err) } return true, nil } -func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte) error { +func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte, publicKey string) error { body, err := json.Marshal(metricsBatch) if err != nil { return fmt.Errorf("error encoding metric %w", err) } + if publicKey != "" { + body, err = encrypt(publicKey, body) + if err != nil { + return err + } + } + var buf bytes.Buffer gz := gzip.NewWriter(&buf) if _, err = gz.Write(body); err != nil { diff --git a/internal/pkg/agent/reporter_test.go b/internal/pkg/agent/reporter_test.go index dc9c3ab..853f241 100644 --- a/internal/pkg/agent/reporter_test.go +++ b/internal/pkg/agent/reporter_test.go @@ -23,6 +23,7 @@ const ( func TestSendReport(t *testing.T) { secretKey := []byte("secret") + publicKey := "" mtr := storage.NewMetrics() err := agent.UpdateMetrics(context.Background(), mtr) @@ -70,5 +71,5 @@ func TestSendReport(t *testing.T) { })) defer server.Close() - agent.SendMetrics(context.Background(), mtr, server.URL, secretKey) + agent.SendMetrics(context.Background(), mtr, server.URL, secretKey, publicKey) } diff --git a/internal/pkg/middleware/crypt.go b/internal/pkg/middleware/crypt.go index 68fe362..9d6b544 100644 --- a/internal/pkg/middleware/crypt.go +++ b/internal/pkg/middleware/crypt.go @@ -1,11 +1,17 @@ package middleware import ( + "bytes" "crypto/hmac" + "crypto/rand" + "crypto/rsa" "crypto/sha256" + "crypto/x509" "encoding/hex" + "encoding/pem" "io" "net/http" + "os" ) func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler { @@ -51,3 +57,40 @@ func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler { }) } } + +func DecryptMiddleware(privateKeyPath string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if privateKeyPath == "" { + return + } + + privateKeyPEM, err := os.ReadFile(privateKeyPath) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + privateKeyBlock, _ := pem.Decode(privateKeyPEM) + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + r.Body = io.NopCloser(bytes.NewBuffer(plaintext)) + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/pkg/server/config/config.go b/internal/pkg/server/config/config.go index 05f0dfb..b3f66b9 100644 --- a/internal/pkg/server/config/config.go +++ b/internal/pkg/server/config/config.go @@ -14,6 +14,7 @@ type ServerConfig struct { DatabaseDSN string `env:"DATABASE_DSN"` StoreInterval int `env:"STORE_INTERVAL"` Restore bool `env:"RESTORE"` + PrivateKey string `env:"CRYPTO_KEY"` SignKeyByte []byte } @@ -45,5 +46,6 @@ func (c *ServerConfig) init() { flag.BoolVar(&c.Restore, "r", true, "Restore") flag.StringVar(&c.DatabaseDSN, "d", "", "Connect database string") flag.StringVar(&c.SignKey, "k", "", "Server key") + flag.StringVar(&c.PrivateKey, "-crypto-key", "", "Private key path") flag.Parse() } diff --git a/internal/pkg/server/server.go b/internal/pkg/server/server.go index 1a88051..ae03bc5 100644 --- a/internal/pkg/server/server.go +++ b/internal/pkg/server/server.go @@ -54,6 +54,7 @@ func StartListener(c *config.ServerConfig) { mux.Use( middleware.LoggingMiddleware, middleware.CryptMiddleware(c.SignKeyByte), + middleware.DecryptMiddleware(c.PrivateKey), ) RegisterHandlers(mux, metricStore) From b0c4e32c380acb42cf143750c2d4cacad623dc6d Mon Sep 17 00:00:00 2001 From: Alexander Lazarev Date: Wed, 15 Nov 2023 15:33:36 +0300 Subject: [PATCH 2/3] added json config --- cmd/agent/main.go | 5 ++- cmd/server/main.go | 5 ++- internal/pkg/agent/config/config.go | 40 +++++++++++++++++---- internal/pkg/agent/config/config_test.go | 2 +- internal/pkg/server/config/config.go | 44 ++++++++++++++++++----- internal/pkg/server/config/config_test.go | 2 +- 6 files changed, 78 insertions(+), 20 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index bc5f53a..e941db6 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -29,6 +29,9 @@ func main() { log.Fatal(err) } - cfg := config.NewAgentConfig() + cfg, err := config.NewAgentConfig() + if err != nil { + log.Fatal(err) + } agent.StartClient(ctx, cfg) } diff --git a/cmd/server/main.go b/cmd/server/main.go index 008132d..53074d5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -19,6 +19,9 @@ func main() { log.Fatal(err) } - cfg := config.NewServerConfig() + cfg, err := config.NewServerConfig() + if err != nil { + log.Fatal(err) + } server.StartListener(cfg) } diff --git a/internal/pkg/agent/config/config.go b/internal/pkg/agent/config/config.go index 2cd9c46..2e8d991 100644 --- a/internal/pkg/agent/config/config.go +++ b/internal/pkg/agent/config/config.go @@ -1,19 +1,22 @@ package config import ( + "encoding/json" "flag" + "os" "github.com/caarlos0/env/v6" "github.com/sirupsen/logrus" ) type AgentConfig struct { - ServerAddress string `env:"ADDRESS"` + ServerAddress string `env:"ADDRESS" json:"server_address"` SignKey string `env:"KEY"` - ReportInterval int `env:"REPORT_INTERVAL"` - PollInterval int `env:"POLL_INTERVAL"` + ReportInterval int `env:"REPORT_INTERVAL" json:"report_interval"` + PollInterval int `env:"POLL_INTERVAL" json:"poll_interval"` RateLimit int `env:"RATE_LIMIT"` - PublicKey string `env:"CRYPTO_KEY"` + PublicKey string `env:"CRYPTO_KEY" json:"crypto_key"` + ConfigPath string `env:"CONFIG"` SignKeyByte []byte } @@ -24,7 +27,7 @@ const ( rateLimitDefault = 3 ) -func NewAgentConfig() *AgentConfig { +func NewAgentConfig() (*AgentConfig, error) { cfg := AgentConfig{} cfg.init() @@ -32,11 +35,18 @@ func NewAgentConfig() *AgentConfig { cfg.SignKeyByte = []byte(cfg.SignKey) } + if cfg.ConfigPath != "" { + cfgJSON, err := readConfigFile(cfg.ConfigPath) + if err != nil { + return cfgJSON, err + } + } + if err := env.Parse(&cfg); err != nil { logrus.Errorf("env parsing error: %v", err) - return nil + return nil, err } - return &cfg + return &cfg, nil } func (c *AgentConfig) init() { @@ -46,5 +56,21 @@ func (c *AgentConfig) init() { flag.StringVar(&c.SignKey, "k", "", "Server key") flag.IntVar(&c.RateLimit, "l", rateLimitDefault, "Rate limit") flag.StringVar(&c.PublicKey, "-crypto-key", "", "Public key path") + flag.StringVar(&c.ConfigPath, "c", "", "Path to config file") + flag.StringVar(&c.ConfigPath, "config", "", "Path to config file (the same as -c)") flag.Parse() } + +func readConfigFile(path string) (cfg *AgentConfig, err error) { + data, err := os.ReadFile(path) + if err != nil { + return cfg, err + } + + err = json.Unmarshal(data, &cfg) + if err != nil { + return cfg, err + } + + return cfg, nil +} diff --git a/internal/pkg/agent/config/config_test.go b/internal/pkg/agent/config/config_test.go index 327847c..8bc1c58 100644 --- a/internal/pkg/agent/config/config_test.go +++ b/internal/pkg/agent/config/config_test.go @@ -39,7 +39,7 @@ func TestNewAgentConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewAgentConfig(); !reflect.DeepEqual(got, tt.want) { + if got, _ := NewAgentConfig(); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewAgentConfig() = %v, want %v", got, tt.want) } }) diff --git a/internal/pkg/server/config/config.go b/internal/pkg/server/config/config.go index b3f66b9..c32138c 100644 --- a/internal/pkg/server/config/config.go +++ b/internal/pkg/server/config/config.go @@ -1,20 +1,23 @@ package config import ( + "encoding/json" "flag" + "os" "github.com/caarlos0/env/v6" "github.com/sirupsen/logrus" ) type ServerConfig struct { - ServerAddress string `env:"ADDRESS"` + ServerAddress string `env:"ADDRESS" json:"server_address"` SignKey string `env:"KEY"` - FileStoragePath string `env:"FILE_STORAGE_PATH"` - DatabaseDSN string `env:"DATABASE_DSN"` - StoreInterval int `env:"STORE_INTERVAL"` - Restore bool `env:"RESTORE"` - PrivateKey string `env:"CRYPTO_KEY"` + FileStoragePath string `env:"FILE_STORAGE_PATH" json:"file_storage_path"` + DatabaseDSN string `env:"DATABASE_DSN" json:"database_dsn"` + StoreInterval int `env:"STORE_INTERVAL" json:"store_interval"` + Restore bool `env:"RESTORE" json:"restore"` + PrivateKey string `env:"CRYPTO_KEY" json:"crypto_key"` + ConfigPath string `env:"CONFIG"` SignKeyByte []byte } @@ -24,7 +27,7 @@ const ( filePathDefault = "/tmp/metrics-db.json" ) -func NewServerConfig() *ServerConfig { +func NewServerConfig() (*ServerConfig, error) { cfg := ServerConfig{} cfg.init() @@ -32,11 +35,18 @@ func NewServerConfig() *ServerConfig { cfg.SignKeyByte = []byte(cfg.SignKey) } + if cfg.ConfigPath != "" { + cfgJSON, err := readConfigFile(cfg.ConfigPath) + if err != nil { + return cfgJSON, err + } + } + if err := env.Parse(&cfg); err != nil { logrus.Errorf("env parsing error: %v", err) - return nil + return nil, err } - return &cfg + return &cfg, nil } func (c *ServerConfig) init() { @@ -47,5 +57,21 @@ func (c *ServerConfig) init() { flag.StringVar(&c.DatabaseDSN, "d", "", "Connect database string") flag.StringVar(&c.SignKey, "k", "", "Server key") flag.StringVar(&c.PrivateKey, "-crypto-key", "", "Private key path") + flag.StringVar(&c.ConfigPath, "c", "", "Path to config file") + flag.StringVar(&c.ConfigPath, "config", "", "Path to config file (the same as -c)") flag.Parse() } + +func readConfigFile(path string) (cfg *ServerConfig, err error) { + data, err := os.ReadFile(path) + if err != nil { + return cfg, err + } + + err = json.Unmarshal(data, &cfg) + if err != nil { + return cfg, err + } + + return cfg, nil +} diff --git a/internal/pkg/server/config/config_test.go b/internal/pkg/server/config/config_test.go index d1a18c4..884eddb 100644 --- a/internal/pkg/server/config/config_test.go +++ b/internal/pkg/server/config/config_test.go @@ -42,7 +42,7 @@ func TestNewServerConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewServerConfig(); !reflect.DeepEqual(got, tt.want) { + if got, _ := NewServerConfig(); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewServerConfig() = %v, want %v", got, tt.want) } }) From 6ca2997978df58c3e75ea88431cd09d1c22f8dae Mon Sep 17 00:00:00 2001 From: Alexander Lazarev Date: Mon, 27 Nov 2023 10:14:05 +0300 Subject: [PATCH 3/3] refactoring for pr --- cmd/agent/main.go | 1 + internal/pkg/agent/config/config.go | 42 ++++++++++++++++++++++++---- internal/pkg/agent/encrypt.go | 21 ++------------ internal/pkg/agent/reporter.go | 18 ++++++------ internal/pkg/agent/reporter_test.go | 6 ++-- internal/pkg/middleware/crypt.go | 23 ++------------- internal/pkg/server/config/config.go | 28 +++++++++++++++---- internal/pkg/server/server.go | 12 ++++++-- 8 files changed, 87 insertions(+), 64 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index e941db6..8160b3c 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -33,5 +33,6 @@ func main() { if err != nil { log.Fatal(err) } + agent.StartClient(ctx, cfg) } diff --git a/internal/pkg/agent/config/config.go b/internal/pkg/agent/config/config.go index 2e8d991..b610e3f 100644 --- a/internal/pkg/agent/config/config.go +++ b/internal/pkg/agent/config/config.go @@ -1,7 +1,10 @@ package config import ( + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "flag" "os" @@ -15,7 +18,8 @@ type AgentConfig struct { ReportInterval int `env:"REPORT_INTERVAL" json:"report_interval"` PollInterval int `env:"POLL_INTERVAL" json:"poll_interval"` RateLimit int `env:"RATE_LIMIT"` - PublicKey string `env:"CRYPTO_KEY" json:"crypto_key"` + PublicKeyPath string `env:"CRYPTO_KEY" json:"crypto_key"` + PublicKey *rsa.PublicKey ConfigPath string `env:"CONFIG"` SignKeyByte []byte } @@ -35,6 +39,11 @@ func NewAgentConfig() (*AgentConfig, error) { cfg.SignKeyByte = []byte(cfg.SignKey) } + if err := env.Parse(&cfg); err != nil { + logrus.Errorf("env parsing error: %v", err) + return nil, err + } + if cfg.ConfigPath != "" { cfgJSON, err := readConfigFile(cfg.ConfigPath) if err != nil { @@ -42,10 +51,14 @@ func NewAgentConfig() (*AgentConfig, error) { } } - if err := env.Parse(&cfg); err != nil { - logrus.Errorf("env parsing error: %v", err) - return nil, err + if cfg.PublicKeyPath != "" { + publicKey, err := cfg.getPublicKey() + if err != nil { + logrus.Errorf("error with get public key: %v", err) + } + cfg.PublicKey = publicKey } + return &cfg, nil } @@ -55,7 +68,7 @@ func (c *AgentConfig) init() { flag.IntVar(&c.PollInterval, "p", pollIntervalDefault, "Interval of poll metric") flag.StringVar(&c.SignKey, "k", "", "Server key") flag.IntVar(&c.RateLimit, "l", rateLimitDefault, "Rate limit") - flag.StringVar(&c.PublicKey, "-crypto-key", "", "Public key path") + flag.StringVar(&c.PublicKeyPath, "-crypto-key", "", "Public key path") flag.StringVar(&c.ConfigPath, "c", "", "Path to config file") flag.StringVar(&c.ConfigPath, "config", "", "Path to config file (the same as -c)") flag.Parse() @@ -74,3 +87,22 @@ func readConfigFile(path string) (cfg *AgentConfig, err error) { return cfg, nil } + +func (c *AgentConfig) getPublicKey() (*rsa.PublicKey, error) { + publicKeyPEM, err := os.ReadFile(c.PublicKeyPath) + if err != nil { + return nil, err + } + + publicKeyBlock, _ := pem.Decode(publicKeyPEM) + if publicKeyBlock == nil { + return nil, err + } + + publicKey, err := x509.ParsePKCS1PublicKey(publicKeyBlock.Bytes) + if err != nil { + return nil, err + } + + return publicKey, nil +} diff --git a/internal/pkg/agent/encrypt.go b/internal/pkg/agent/encrypt.go index c26e912..ba0f0f0 100644 --- a/internal/pkg/agent/encrypt.go +++ b/internal/pkg/agent/encrypt.go @@ -3,27 +3,10 @@ package agent import ( "crypto/rand" "crypto/rsa" - "crypto/x509" - "encoding/pem" - "os" ) -func encrypt(publicKeyPath string, data []byte) ([]byte, error) { - publicKeyPEM, err := os.ReadFile(publicKeyPath) - if err != nil { - return nil, err - } - publicKeyBlock, _ := pem.Decode(publicKeyPEM) - if publicKeyBlock == nil { - return nil, err - } - - publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes) - if err != nil { - return nil, err - } - - ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey.(*rsa.PublicKey), data) +func encrypt(key *rsa.PublicKey, data []byte) ([]byte, error) { + ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, key, data) if err != nil { return nil, err } diff --git a/internal/pkg/agent/reporter.go b/internal/pkg/agent/reporter.go index 3355794..93653d9 100644 --- a/internal/pkg/agent/reporter.go +++ b/internal/pkg/agent/reporter.go @@ -24,7 +24,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age case <-ctx.Done(): return case <-reportTicker.C: - ok, err := SendMetrics(ctx, s, c.ServerAddress, c.SignKeyByte, c.PublicKey) + ok, err := SendMetrics(ctx, s, c) if err != nil { logrus.Errorf("Error send metrics %v", err) } @@ -37,7 +37,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age } } -func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, signKey []byte, publicKey string) (bool, error) { +func SendMetrics(ctx context.Context, s storage.Store, c *config.AgentConfig) (bool, error) { metricsMap, err := s.GetMetrics(ctx) if err != nil { logrus.Errorf("Some error ocured during metrics get: %q", err) @@ -49,22 +49,22 @@ func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, sig metricsBatch = append(metricsBatch, v) } - url := fmt.Sprintf("http://%s/updates/", serverAddress) + url := fmt.Sprintf("http://%s/updates/", c.ServerAddress) - if err = SendBatchJSON(url, metricsBatch, signKey, publicKey); err != nil { + if err = SendBatchJSON(url, metricsBatch, c); err != nil { return false, fmt.Errorf("error create post request %w", err) } return true, nil } -func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte, publicKey string) error { +func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, c *config.AgentConfig) error { body, err := json.Marshal(metricsBatch) if err != nil { return fmt.Errorf("error encoding metric %w", err) } - if publicKey != "" { - body, err = encrypt(publicKey, body) + if c.PublicKeyPath != "" { + body, err = encrypt(c.PublicKey, body) if err != nil { return err } @@ -89,8 +89,8 @@ func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte, req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Encoding", "gzip") - if signKey != nil { - h := hmac.New(sha256.New, signKey) + if c.SignKeyByte != nil { + h := hmac.New(sha256.New, c.SignKeyByte) h.Write(body) serverHash := hex.EncodeToString(h.Sum(nil)) req.Header.Set("HashSHA256", serverHash) diff --git a/internal/pkg/agent/reporter_test.go b/internal/pkg/agent/reporter_test.go index 853f241..d6a0433 100644 --- a/internal/pkg/agent/reporter_test.go +++ b/internal/pkg/agent/reporter_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/mayr0y/animated-octo-couscous.git/internal/pkg/agent" + "github.com/mayr0y/animated-octo-couscous.git/internal/pkg/agent/config" "github.com/mayr0y/animated-octo-couscous.git/internal/pkg/metrics" "github.com/mayr0y/animated-octo-couscous.git/internal/pkg/storage" ) @@ -22,8 +23,7 @@ const ( ) func TestSendReport(t *testing.T) { - secretKey := []byte("secret") - publicKey := "" + c, _ := config.NewAgentConfig() mtr := storage.NewMetrics() err := agent.UpdateMetrics(context.Background(), mtr) @@ -71,5 +71,5 @@ func TestSendReport(t *testing.T) { })) defer server.Close() - agent.SendMetrics(context.Background(), mtr, server.URL, secretKey, publicKey) + agent.SendMetrics(context.Background(), mtr, c) } diff --git a/internal/pkg/middleware/crypt.go b/internal/pkg/middleware/crypt.go index 9d6b544..21e336b 100644 --- a/internal/pkg/middleware/crypt.go +++ b/internal/pkg/middleware/crypt.go @@ -6,12 +6,9 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" - "crypto/x509" "encoding/hex" - "encoding/pem" "io" "net/http" - "os" ) func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler { @@ -58,32 +55,16 @@ func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler { } } -func DecryptMiddleware(privateKeyPath string) func(http.Handler) http.Handler { +func DecryptMiddleware(key *rsa.PrivateKey) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if privateKeyPath == "" { - return - } - - privateKeyPEM, err := os.ReadFile(privateKeyPath) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - privateKeyBlock, _ := pem.Decode(privateKeyPEM) - privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, body) + plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, key, body) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/internal/pkg/server/config/config.go b/internal/pkg/server/config/config.go index c32138c..fd37652 100644 --- a/internal/pkg/server/config/config.go +++ b/internal/pkg/server/config/config.go @@ -1,7 +1,10 @@ package config import ( + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "flag" "os" @@ -35,17 +38,17 @@ func NewServerConfig() (*ServerConfig, error) { cfg.SignKeyByte = []byte(cfg.SignKey) } + if err := env.Parse(&cfg); err != nil { + logrus.Errorf("env parsing error: %v", err) + return nil, err + } + if cfg.ConfigPath != "" { cfgJSON, err := readConfigFile(cfg.ConfigPath) if err != nil { return cfgJSON, err } } - - if err := env.Parse(&cfg); err != nil { - logrus.Errorf("env parsing error: %v", err) - return nil, err - } return &cfg, nil } @@ -75,3 +78,18 @@ func readConfigFile(path string) (cfg *ServerConfig, err error) { return cfg, nil } + +func (c *ServerConfig) GetPrivateKey() (*rsa.PrivateKey, error) { + privateKeyPEM, err := os.ReadFile(c.PrivateKey) + if err != nil { + return nil, err + } + + privateKeyBlock, _ := pem.Decode(privateKeyPEM) + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + return nil, err + } + + return privateKey, nil +} diff --git a/internal/pkg/server/server.go b/internal/pkg/server/server.go index ae03bc5..013524c 100644 --- a/internal/pkg/server/server.go +++ b/internal/pkg/server/server.go @@ -54,9 +54,16 @@ func StartListener(c *config.ServerConfig) { mux.Use( middleware.LoggingMiddleware, middleware.CryptMiddleware(c.SignKeyByte), - middleware.DecryptMiddleware(c.PrivateKey), ) + if c.PrivateKey != "" { + privateKey, err := c.GetPrivateKey() + if err != nil { + logrus.Errorf("Error get private key: %v", err) + } + mux.Use(middleware.DecryptMiddleware(privateKey)) + } + RegisterHandlers(mux, metricStore) if c.Restore { @@ -64,6 +71,7 @@ func StartListener(c *config.ServerConfig) { logrus.Errorf("Error update metric from file %v", err) } } + if c.StoreInterval > 0 { storeInterval := time.NewTicker(time.Duration(c.StoreInterval) * time.Second) defer storeInterval.Stop() @@ -76,8 +84,8 @@ func StartListener(c *config.ServerConfig) { } }() } - wg := &sync.WaitGroup{} + wg := &sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done()