Skip to content

Commit

Permalink
refactoring for pr
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlzrv committed Nov 27, 2023
1 parent b0c4e32 commit 58fa8d6
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 61 deletions.
1 change: 1 addition & 0 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ func main() {
if err != nil {
log.Fatal(err)
}

agent.StartClient(ctx, cfg)
}
42 changes: 37 additions & 5 deletions internal/pkg/agent/config/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package config

import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"flag"
"os"

Expand All @@ -15,7 +18,8 @@ type AgentConfig struct {
ReportInterval int `env:"REPORT_INTERVAL" json:"report_interval"`
PollInterval int `env:"POLL_INTERVAL" json:"poll_interval"`
RateLimit int `env:"RATE_LIMIT"`
PublicKey string `env:"CRYPTO_KEY" json:"crypto_key"`
PublicKeyPath string `env:"CRYPTO_KEY" json:"crypto_key"`
PublicKey *rsa.PublicKey
ConfigPath string `env:"CONFIG"`
SignKeyByte []byte
}
Expand All @@ -35,17 +39,26 @@ func NewAgentConfig() (*AgentConfig, error) {
cfg.SignKeyByte = []byte(cfg.SignKey)
}

if err := env.Parse(&cfg); err != nil {
logrus.Errorf("env parsing error: %v", err)
return nil, err
}

if cfg.ConfigPath != "" {
cfgJSON, err := readConfigFile(cfg.ConfigPath)
if err != nil {
return cfgJSON, err
}
}

if err := env.Parse(&cfg); err != nil {
logrus.Errorf("env parsing error: %v", err)
return nil, err
if cfg.PublicKeyPath != "" {
publicKey, err := cfg.getPublicKey()
if err != nil {
logrus.Errorf("error with get public key: %v", err)
}
cfg.PublicKey = publicKey
}

return &cfg, nil
}

Expand All @@ -55,7 +68,7 @@ func (c *AgentConfig) init() {
flag.IntVar(&c.PollInterval, "p", pollIntervalDefault, "Interval of poll metric")
flag.StringVar(&c.SignKey, "k", "", "Server key")
flag.IntVar(&c.RateLimit, "l", rateLimitDefault, "Rate limit")
flag.StringVar(&c.PublicKey, "-crypto-key", "", "Public key path")
flag.StringVar(&c.PublicKeyPath, "-crypto-key", "", "Public key path")
flag.StringVar(&c.ConfigPath, "c", "", "Path to config file")
flag.StringVar(&c.ConfigPath, "config", "", "Path to config file (the same as -c)")
flag.Parse()
Expand All @@ -74,3 +87,22 @@ func readConfigFile(path string) (cfg *AgentConfig, err error) {

return cfg, nil
}

func (c *AgentConfig) getPublicKey() (*rsa.PublicKey, error) {
publicKeyPEM, err := os.ReadFile(c.PublicKeyPath)
if err != nil {
return nil, err
}

publicKeyBlock, _ := pem.Decode(publicKeyPEM)
if publicKeyBlock == nil {
return nil, err
}

publicKey, err := x509.ParsePKCS1PublicKey(publicKeyBlock.Bytes)
if err != nil {
return nil, err
}

return publicKey, nil
}
21 changes: 2 additions & 19 deletions internal/pkg/agent/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,10 @@ package agent
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
)

func encrypt(publicKeyPath string, data []byte) ([]byte, error) {
publicKeyPEM, err := os.ReadFile(publicKeyPath)
if err != nil {
return nil, err
}
publicKeyBlock, _ := pem.Decode(publicKeyPEM)
if publicKeyBlock == nil {
return nil, err
}

publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes)
if err != nil {
return nil, err
}

ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey.(*rsa.PublicKey), data)
func encrypt(key *rsa.PublicKey, data []byte) ([]byte, error) {
ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, key, data)
if err != nil {
return nil, err
}
Expand Down
18 changes: 9 additions & 9 deletions internal/pkg/agent/reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age
case <-ctx.Done():
return
case <-reportTicker.C:
ok, err := SendMetrics(ctx, s, c.ServerAddress, c.SignKeyByte, c.PublicKey)
ok, err := SendMetrics(ctx, s, c)
if err != nil {
logrus.Errorf("Error send metrics %v", err)
}
Expand All @@ -37,7 +37,7 @@ func RunSendMetric(ctx context.Context, reportTicker *time.Ticker, c *config.Age
}
}

func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, signKey []byte, publicKey string) (bool, error) {
func SendMetrics(ctx context.Context, s storage.Store, c *config.AgentConfig) (bool, error) {
metricsMap, err := s.GetMetrics(ctx)
if err != nil {
logrus.Errorf("Some error ocured during metrics get: %q", err)
Expand All @@ -49,22 +49,22 @@ func SendMetrics(ctx context.Context, s storage.Store, serverAddress string, sig
metricsBatch = append(metricsBatch, v)
}

url := fmt.Sprintf("http://%s/updates/", serverAddress)
url := fmt.Sprintf("http://%s/updates/", c.ServerAddress)

if err = SendBatchJSON(url, metricsBatch, signKey, publicKey); err != nil {
if err = SendBatchJSON(url, metricsBatch, c); err != nil {
return false, fmt.Errorf("error create post request %w", err)
}
return true, nil
}

func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte, publicKey string) error {
func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, c *config.AgentConfig) error {
body, err := json.Marshal(metricsBatch)
if err != nil {
return fmt.Errorf("error encoding metric %w", err)
}

if publicKey != "" {
body, err = encrypt(publicKey, body)
if c.PublicKeyPath != "" {
body, err = encrypt(c.PublicKey, body)
if err != nil {
return err
}
Expand All @@ -89,8 +89,8 @@ func SendBatchJSON(url string, metricsBatch []*metrics.Metrics, signKey []byte,
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip")

if signKey != nil {
h := hmac.New(sha256.New, signKey)
if c.SignKeyByte != nil {
h := hmac.New(sha256.New, c.SignKeyByte)
h.Write(body)
serverHash := hex.EncodeToString(h.Sum(nil))
req.Header.Set("HashSHA256", serverHash)
Expand Down
23 changes: 2 additions & 21 deletions internal/pkg/middleware/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"io"
"net/http"
"os"
)

func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler {
Expand Down Expand Up @@ -58,32 +55,16 @@ func CryptMiddleware(signKey []byte) func(handler http.Handler) http.Handler {
}
}

func DecryptMiddleware(privateKeyPath string) func(http.Handler) http.Handler {
func DecryptMiddleware(key *rsa.PrivateKey) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if privateKeyPath == "" {
return
}

privateKeyPEM, err := os.ReadFile(privateKeyPath)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
privateKeyBlock, _ := pem.Decode(privateKeyPEM)
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, body)
plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, key, body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down
28 changes: 23 additions & 5 deletions internal/pkg/server/config/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package config

import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"flag"
"os"

Expand Down Expand Up @@ -35,17 +38,17 @@ func NewServerConfig() (*ServerConfig, error) {
cfg.SignKeyByte = []byte(cfg.SignKey)
}

if err := env.Parse(&cfg); err != nil {
logrus.Errorf("env parsing error: %v", err)
return nil, err
}

if cfg.ConfigPath != "" {
cfgJSON, err := readConfigFile(cfg.ConfigPath)
if err != nil {
return cfgJSON, err
}
}

if err := env.Parse(&cfg); err != nil {
logrus.Errorf("env parsing error: %v", err)
return nil, err
}
return &cfg, nil
}

Expand Down Expand Up @@ -75,3 +78,18 @@ func readConfigFile(path string) (cfg *ServerConfig, err error) {

return cfg, nil
}

func (c *ServerConfig) GetPrivateKey() (*rsa.PrivateKey, error) {
privateKeyPEM, err := os.ReadFile(c.PrivateKey)
if err != nil {
return nil, err
}

privateKeyBlock, _ := pem.Decode(privateKeyPEM)
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
return nil, err
}

return privateKey, nil
}
12 changes: 10 additions & 2 deletions internal/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,24 @@ func StartListener(c *config.ServerConfig) {
mux.Use(
middleware.LoggingMiddleware,
middleware.CryptMiddleware(c.SignKeyByte),
middleware.DecryptMiddleware(c.PrivateKey),
)

if c.PrivateKey != "" {
privateKey, err := c.GetPrivateKey()
if err != nil {
logrus.Errorf("Error get private key: %v", err)
}
mux.Use(middleware.DecryptMiddleware(privateKey))
}

RegisterHandlers(mux, metricStore)

if c.Restore {
if err = metricStore.LoadMetrics(c.FileStoragePath); err != nil {
logrus.Errorf("Error update metric from file %v", err)
}
}

if c.StoreInterval > 0 {
storeInterval := time.NewTicker(time.Duration(c.StoreInterval) * time.Second)
defer storeInterval.Stop()
Expand All @@ -76,8 +84,8 @@ func StartListener(c *config.ServerConfig) {
}
}()
}
wg := &sync.WaitGroup{}

wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
Expand Down

0 comments on commit 58fa8d6

Please sign in to comment.