diff --git a/cmd/agent/main.go b/cmd/agent/main.go index e941db6..8160b3c 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -33,5 +33,6 @@ func main() { if err != nil { log.Fatal(err) } + agent.StartClient(ctx, cfg) } diff --git a/internal/pkg/agent/config/config.go b/internal/pkg/agent/config/config.go index 2e8d991..b610e3f 100644 --- a/internal/pkg/agent/config/config.go +++ b/internal/pkg/agent/config/config.go @@ -1,7 +1,10 @@ package config import ( + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "flag" "os" @@ -15,7 +18,8 @@ type AgentConfig struct { ReportInterval int `env:"REPORT_INTERVAL" json:"report_interval"` PollInterval int `env:"POLL_INTERVAL" json:"poll_interval"` RateLimit int `env:"RATE_LIMIT"` - PublicKey string `env:"CRYPTO_KEY" json:"crypto_key"` + PublicKeyPath string `env:"CRYPTO_KEY" json:"crypto_key"` + PublicKey *rsa.PublicKey ConfigPath string `env:"CONFIG"` SignKeyByte []byte } @@ -35,6 +39,11 @@ func NewAgentConfig() (*AgentConfig, error) { cfg.SignKeyByte = []byte(cfg.SignKey) } + if err := env.Parse(&cfg); err != nil { + logrus.Errorf("env parsing error: %v", err) + return nil, err + } + if cfg.ConfigPath != "" { cfgJSON, err := readConfigFile(cfg.ConfigPath) if err != nil { @@ -42,10 +51,14 @@ func NewAgentConfig() (*AgentConfig, error) { } } - if err := env.Parse(&cfg); err != nil { - logrus.Errorf("env parsing error: %v", err) - return nil, err + if cfg.PublicKeyPath != "" { + publicKey, err := cfg.getPublicKey() + if err != nil { + logrus.Errorf("error with get public key: %v", err) + } + cfg.PublicKey = publicKey } + return &cfg, nil } @@ -55,7 +68,7 @@ func (c *AgentConfig) init() { flag.IntVar(&c.PollInterval, "p", pollIntervalDefault, "Interval of poll metric") flag.StringVar(&c.SignKey, "k", "", "Server key") flag.IntVar(&c.RateLimit, "l", rateLimitDefault, "Rate limit") - flag.StringVar(&c.PublicKey, "-crypto-key", "", "Public key path") + flag.StringVar(&c.PublicKeyPath, "-crypto-key", "", "Public key path") flag.StringVar(&c.ConfigPath, "c", "", "Path to config file") flag.StringVar(&c.ConfigPath, "config", "", "Path to config file (the same as -c)") flag.Parse() @@ -74,3 +87,22 @@ func readConfigFile(path string) (cfg *AgentConfig, err error) { return cfg, nil } + +func (c *AgentConfig) getPublicKey() (*rsa.PublicKey, error) { + publicKeyPEM, err := os.ReadFile(c.PublicKeyPath) + if err != nil { + return nil, err + } + + publicKeyBlock, _ := pem.Decode(publicKeyPEM) + if publicKeyBlock == nil { + return nil, err + } + + publicKey, err := x509.ParsePKCS1PublicKey(publicKeyBlock.Bytes) + if err != nil { + return nil, err + } + + return publicKey, nil +} diff --git a/internal/pkg/agent/encrypt.go b/internal/pkg/agent/encrypt.go index c26e912..ba0f0f0 100644 --- a/internal/pkg/agent/encrypt.go +++ b/internal/pkg/agent/encrypt.go @@ -3,27 +3,10 @@ package agent import ( "crypto/rand" "crypto/rsa" - "crypto/x509" - "encoding/pem" - "os" ) -func encrypt(publicKeyPath string, data []byte) ([]byte, error) { - publicKeyPEM, err := os.ReadFile(publicKeyPath) - if err != nil { - return nil, err - } - publicKeyBlock, _ := pem.Decode(publicKeyPEM) - if publicKeyBlock == nil { - return nil, err - } - - publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes) - if err != nil { - return nil, err - } - - ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey.(*rsa.PublicKey), data) +func encrypt(key *rsa.PublicKey, data []byte) ([]byte, error) { + ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, key, data) if err != nil { return nil, err } diff --git a/internal/pkg/agent/reporter.go b/internal/pkg/agent/reporter.go index 3355794..93653d9 100644 --- a/internal/pkg/agent/reporter.go +++ b/internal/pkg/agent/reporter.go @@ -24,7 +24,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age case <-ctx.Done(): return case <-reportTicker.C: - ok, err := SendMetrics(ctx, s, c.ServerAddress, c.SignKeyByte, c.PublicKey) + ok, err := SendMetrics(ctx, s, c) if err != nil { logrus.Errorf("Error send metrics %v", err) } @@ -37,7 +37,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age } } -func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, signKey []byte, publicKey string) (bool, error) { +func SendMetrics(ctx context.Context, s storage.Store, c *config.AgentConfig) (bool, error) { metricsMap, err := s.GetMetrics(ctx) if err != nil { logrus.Errorf("Some error ocured during metrics get: %q", err) @@ -49,22 +49,22 @@ func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, sig metricsBatch = append(metricsBatch, v) } - url := fmt.Sprintf("http://%s/updates/", serverAddress) + url := fmt.Sprintf("http://%s/updates/", c.ServerAddress) - if err = SendBatchJSON(url, metricsBatch, signKey, publicKey); err != nil { + if err = SendBatchJSON(url, metricsBatch, c); err != nil { return false, fmt.Errorf("error create post request %w", err) } return true, nil } -func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte, publicKey string) error { +func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, c *config.AgentConfig) error { body, err := json.Marshal(metricsBatch) if err != nil { return fmt.Errorf("error encoding metric %w", err) } - if publicKey != "" { - body, err = encrypt(publicKey, body) + if c.PublicKeyPath != "" { + body, err = encrypt(c.PublicKey, body) if err != nil { return err } @@ -89,8 +89,8 @@ func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte, req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Encoding", "gzip") - if signKey != nil { - h := hmac.New(sha256.New, signKey) + if c.SignKeyByte != nil { + h := hmac.New(sha256.New, c.SignKeyByte) h.Write(body) serverHash := hex.EncodeToString(h.Sum(nil)) req.Header.Set("HashSHA256", serverHash) diff --git a/internal/pkg/middleware/crypt.go b/internal/pkg/middleware/crypt.go index 9d6b544..21e336b 100644 --- a/internal/pkg/middleware/crypt.go +++ b/internal/pkg/middleware/crypt.go @@ -6,12 +6,9 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" - "crypto/x509" "encoding/hex" - "encoding/pem" "io" "net/http" - "os" ) func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler { @@ -58,32 +55,16 @@ func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler { } } -func DecryptMiddleware(privateKeyPath string) func(http.Handler) http.Handler { +func DecryptMiddleware(key *rsa.PrivateKey) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if privateKeyPath == "" { - return - } - - privateKeyPEM, err := os.ReadFile(privateKeyPath) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - privateKeyBlock, _ := pem.Decode(privateKeyPEM) - privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, body) + plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, key, body) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/internal/pkg/server/config/config.go b/internal/pkg/server/config/config.go index c32138c..fd37652 100644 --- a/internal/pkg/server/config/config.go +++ b/internal/pkg/server/config/config.go @@ -1,7 +1,10 @@ package config import ( + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "flag" "os" @@ -35,17 +38,17 @@ func NewServerConfig() (*ServerConfig, error) { cfg.SignKeyByte = []byte(cfg.SignKey) } + if err := env.Parse(&cfg); err != nil { + logrus.Errorf("env parsing error: %v", err) + return nil, err + } + if cfg.ConfigPath != "" { cfgJSON, err := readConfigFile(cfg.ConfigPath) if err != nil { return cfgJSON, err } } - - if err := env.Parse(&cfg); err != nil { - logrus.Errorf("env parsing error: %v", err) - return nil, err - } return &cfg, nil } @@ -75,3 +78,18 @@ func readConfigFile(path string) (cfg *ServerConfig, err error) { return cfg, nil } + +func (c *ServerConfig) GetPrivateKey() (*rsa.PrivateKey, error) { + privateKeyPEM, err := os.ReadFile(c.PrivateKey) + if err != nil { + return nil, err + } + + privateKeyBlock, _ := pem.Decode(privateKeyPEM) + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + return nil, err + } + + return privateKey, nil +} diff --git a/internal/pkg/server/server.go b/internal/pkg/server/server.go index ae03bc5..013524c 100644 --- a/internal/pkg/server/server.go +++ b/internal/pkg/server/server.go @@ -54,9 +54,16 @@ func StartListener(c *config.ServerConfig) { mux.Use( middleware.LoggingMiddleware, middleware.CryptMiddleware(c.SignKeyByte), - middleware.DecryptMiddleware(c.PrivateKey), ) + if c.PrivateKey != "" { + privateKey, err := c.GetPrivateKey() + if err != nil { + logrus.Errorf("Error get private key: %v", err) + } + mux.Use(middleware.DecryptMiddleware(privateKey)) + } + RegisterHandlers(mux, metricStore) if c.Restore { @@ -64,6 +71,7 @@ func StartListener(c *config.ServerConfig) { logrus.Errorf("Error update metric from file %v", err) } } + if c.StoreInterval > 0 { storeInterval := time.NewTicker(time.Duration(c.StoreInterval) * time.Second) defer storeInterval.Stop() @@ -76,8 +84,8 @@ func StartListener(c *config.ServerConfig) { } }() } - wg := &sync.WaitGroup{} + wg := &sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done()