diff --git a/example/go.mod b/example/go.mod new file mode 100644 index 0000000..2873377 --- /dev/null +++ b/example/go.mod @@ -0,0 +1,10 @@ +module main.go + +go 1.21 + +require github.com/uw-labs/strongbox v1.1.0 + +require ( + github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/example/go.sum b/example/go.sum new file mode 100644 index 0000000..64d4324 --- /dev/null +++ b/example/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115 h1:YuDUUFNM21CAbyPOpOP8BicaTD/0klJEKt5p8yuw+uY= +github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115/go.mod h1:LadVJg0XuawGk+8L1rYnIED8451UyNxEMdTWCEt5kmU= +github.com/jacobsa/oglematchers v0.0.0-20150720000706-141901ea67cd h1:9GCSedGjMcLZCrusBZuo4tyKLpKUPenUUqi34AkuFmA= +github.com/jacobsa/oglematchers v0.0.0-20150720000706-141901ea67cd/go.mod h1:TlmyIZDpGmwRoTWiakdr+HA1Tukze6C6XbRVidYq02M= +github.com/jacobsa/oglemock v0.0.0-20150831005832-e94d794d06ff h1:2xRHTvkpJ5zJmglXLRqHiZQNjUoOkhUyhTAhEQvPAWw= +github.com/jacobsa/oglemock v0.0.0-20150831005832-e94d794d06ff/go.mod h1:gJWba/XXGl0UoOmBQKRWCJdHrr3nE0T65t6ioaj3mLI= +github.com/jacobsa/ogletest v0.0.0-20170503003838-80d50a735a11 h1:BMb8s3ENQLt5ulwVIHVDWFHp8eIXmbfSExkvdn9qMXI= +github.com/jacobsa/ogletest v0.0.0-20170503003838-80d50a735a11/go.mod h1:+DBdDyfoO2McrOyDemRBq0q9CMEByef7sYl7JH5Q3BI= +github.com/jacobsa/reqtrace v0.0.0-20150505043853-245c9e0234cb h1:uSWBjJdMf47kQlXMwWEfmc864bA1wAC+Kl3ApryuG9Y= +github.com/jacobsa/reqtrace v0.0.0-20150505043853-245c9e0234cb/go.mod h1:ivcmUvxXWjb27NsPEaiYK7AidlZXS7oQ5PowUS9z3I4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/uw-labs/strongbox v1.1.0 h1:gIFhB+YFkY4wbD6ZU4/nZI26d1O6/TnSPg2ADJTV8Z4= +github.com/uw-labs/strongbox v1.1.0/go.mod h1:MeDTE5Nj3SAPmhZXuqju0KcZWJW3D1HPmU14buyWgqU= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d h1:20cMwl2fHAzkJMEA+8J4JgqBQcQGzbisXo31MIeenXI= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..59194c7 --- /dev/null +++ b/example/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "log" + "os" + + "github.com/uw-labs/strongbox" +) + +var ( + path = "" // Path to encrypted file or directory containing encrypted files + keyPath = "" // Path to strongbox key +) + +func main() { + key, err := os.ReadFile(keyPath) + if err != nil { + log.Printf("Error reading file: %v\n", err) + return + } + + keyBytes := []byte(key) // Convert key string into byte slice + + // Decode the key + dk, err := strongbox.Decode([]byte(keyBytes)) + if err != nil { + log.Fatalf("Unable to decode given private key %v", err) + } + + // Decrypt file(s) at the path provided + if err := strongbox.RecursiveDecrypt(path, dk); err != nil { + log.Fatalln(err) + } +} diff --git a/go.mod b/go.mod index 88215af..e6cb00b 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,17 @@ module github.com/uw-labs/strongbox +replace github.com/uw-labs/strongbox/strongbox-lib => ./strongbox-lib + go 1.21 require ( - github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115 - github.com/jacobsa/oglematchers v0.0.0-20150720000706-141901ea67cd // indirect - github.com/jacobsa/oglemock v0.0.0-20150831005832-e94d794d06ff // indirect - github.com/jacobsa/ogletest v0.0.0-20170503003838-80d50a735a11 // indirect - github.com/jacobsa/reqtrace v0.0.0-20150505043853-245c9e0234cb // indirect + github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115 // indirect github.com/stretchr/testify v1.7.0 - golang.org/x/net v0.7.0 // indirect gopkg.in/yaml.v2 v2.4.0 ) +require github.com/uw-labs/strongbox/strongbox-lib v0.0.0-00010101000000-000000000000 + require ( github.com/davecgh/go-spew v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 7ba9094..1a75760 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/keyring.go b/keyring.go index a1df914..ded1bb9 100644 --- a/keyring.go +++ b/keyring.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "github.com/uw-labs/strongbox/strongbox-lib" "gopkg.in/yaml.v2" ) @@ -30,17 +31,17 @@ type keyEntry struct { func (kr *fileKeyRing) AddKey(desc string, keyID []byte, key []byte) { kr.KeyEntries = append(kr.KeyEntries, keyEntry{ Description: desc, - KeyID: string(encode(keyID[:])), - Key: string(encode(key[:])), + KeyID: string(strongbox.Encode(keyID[:])), + Key: string(strongbox.Encode(key[:])), }) } func (kr *fileKeyRing) Key(keyID []byte) ([]byte, error) { - b64 := string(encode(keyID[:])) + b64 := string(strongbox.Encode(keyID[:])) for _, ke := range kr.KeyEntries { if ke.KeyID == b64 { - dec, err := decode([]byte(ke.Key)) + dec, err := strongbox.Decode([]byte(ke.Key)) if err != nil { return []byte{}, err } diff --git a/strongbox-lib/go.mod b/strongbox-lib/go.mod new file mode 100644 index 0000000..7ece1c5 --- /dev/null +++ b/strongbox-lib/go.mod @@ -0,0 +1,16 @@ +module strongbox + +go 1.21 + +require ( + github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115 + gopkg.in/yaml.v2 v2.4.0 +) + +require ( + github.com/jacobsa/oglematchers v0.0.0-20150720000706-141901ea67cd // indirect + github.com/jacobsa/oglemock v0.0.0-20150831005832-e94d794d06ff // indirect + github.com/jacobsa/ogletest v0.0.0-20170503003838-80d50a735a11 // indirect + github.com/jacobsa/reqtrace v0.0.0-20150505043853-245c9e0234cb // indirect + golang.org/x/net v0.15.0 // indirect +) diff --git a/strongbox-lib/go.sum b/strongbox-lib/go.sum new file mode 100644 index 0000000..4e0494d --- /dev/null +++ b/strongbox-lib/go.sum @@ -0,0 +1,16 @@ +github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115 h1:YuDUUFNM21CAbyPOpOP8BicaTD/0klJEKt5p8yuw+uY= +github.com/jacobsa/crypto v0.0.0-20190317225127-9f44e2d11115/go.mod h1:LadVJg0XuawGk+8L1rYnIED8451UyNxEMdTWCEt5kmU= +github.com/jacobsa/oglematchers v0.0.0-20150720000706-141901ea67cd h1:9GCSedGjMcLZCrusBZuo4tyKLpKUPenUUqi34AkuFmA= +github.com/jacobsa/oglematchers v0.0.0-20150720000706-141901ea67cd/go.mod h1:TlmyIZDpGmwRoTWiakdr+HA1Tukze6C6XbRVidYq02M= +github.com/jacobsa/oglemock v0.0.0-20150831005832-e94d794d06ff h1:2xRHTvkpJ5zJmglXLRqHiZQNjUoOkhUyhTAhEQvPAWw= +github.com/jacobsa/oglemock v0.0.0-20150831005832-e94d794d06ff/go.mod h1:gJWba/XXGl0UoOmBQKRWCJdHrr3nE0T65t6ioaj3mLI= +github.com/jacobsa/ogletest v0.0.0-20170503003838-80d50a735a11 h1:BMb8s3ENQLt5ulwVIHVDWFHp8eIXmbfSExkvdn9qMXI= +github.com/jacobsa/ogletest v0.0.0-20170503003838-80d50a735a11/go.mod h1:+DBdDyfoO2McrOyDemRBq0q9CMEByef7sYl7JH5Q3BI= +github.com/jacobsa/reqtrace v0.0.0-20150505043853-245c9e0234cb h1:uSWBjJdMf47kQlXMwWEfmc864bA1wAC+Kl3ApryuG9Y= +github.com/jacobsa/reqtrace v0.0.0-20150505043853-245c9e0234cb/go.mod h1:ivcmUvxXWjb27NsPEaiYK7AidlZXS7oQ5PowUS9z3I4= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/strongbox-lib/keyring.go b/strongbox-lib/keyring.go new file mode 100644 index 0000000..4fbb85f --- /dev/null +++ b/strongbox-lib/keyring.go @@ -0,0 +1,84 @@ +package strongbox + +import ( + "fmt" + "log" + "os" + "path/filepath" + + "gopkg.in/yaml.v2" +) + +type keyRing interface { + Load() error + Save() error + AddKey(name string, keyID []byte, key []byte) + Key(keyID []byte) ([]byte, error) +} + +type fileKeyRing struct { + fileName string + KeyEntries []keyEntry +} + +type keyEntry struct { + Description string `yaml:"description"` + KeyID string `yaml:"key-id"` + Key string `yaml:"key"` +} + +func (kr *fileKeyRing) AddKey(desc string, keyID []byte, key []byte) { + kr.KeyEntries = append(kr.KeyEntries, keyEntry{ + Description: desc, + KeyID: string(Encode(keyID[:])), + Key: string(Encode(key[:])), + }) +} + +func (kr *fileKeyRing) Key(keyID []byte) ([]byte, error) { + b64 := string(Encode(keyID[:])) + + for _, ke := range kr.KeyEntries { + if ke.KeyID == b64 { + dec, err := Decode([]byte(ke.Key)) + if err != nil { + return []byte{}, err + } + if len(dec) != 32 { + return []byte{}, fmt.Errorf("unexpected length of key: %d", len(dec)) + } + return dec, nil + } + } + + return []byte{}, errKeyNotFound +} + +func (kr *fileKeyRing) Load() error { + + bytes, err := os.ReadFile(kr.fileName) + if err != nil { + return err + } + + err = yaml.Unmarshal(bytes, kr) + return err +} + +func (kr *fileKeyRing) Save() error { + ser, err := yaml.Marshal(kr) + if err != nil { + log.Fatal(err) + } + + path := filepath.Dir(kr.fileName) + _, err = os.Stat(path) + if os.IsNotExist(err) { + err := os.MkdirAll(path, 0700) + if err != nil { + return fmt.Errorf("error creating strongbox home folder: %s", err) + } + } + + return os.WriteFile(kr.fileName, ser, 0600) +} diff --git a/strongbox-lib/main.go b/strongbox-lib/main.go new file mode 100644 index 0000000..0462edd --- /dev/null +++ b/strongbox-lib/main.go @@ -0,0 +1,425 @@ +package strongbox + +import ( + "bytes" + "compress/gzip" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "log" + "os" + "os/exec" + "os/user" + "path/filepath" + "strings" + + "github.com/jacobsa/crypto/siv" +) + +var ( + keyLoader = Key + kr keyRing + prefix = []byte("# STRONGBOX ENCRYPTED RESOURCE ;") + defaultPrefix = []byte("# STRONGBOX ENCRYPTED RESOURCE ; See https://github.com/uw-labs/strongbox\n") + + errKeyNotFound = errors.New("key not found") +) + +func DeriveHome() string { + // try explicitly set STRONGBOX_HOME + if home := os.Getenv("STRONGBOX_HOME"); home != "" { + return home + } + // Try user.Current which works in most cases, but may not work with CGO disabled. + u, err := user.Current() + if err == nil && u.HomeDir != "" { + return u.HomeDir + } + // try HOME env var + if home := os.Getenv("HOME"); home != "" { + return home + } + + log.Fatal("Could not call os/user.Current() or find $STRONGBOX_HOME or $HOME. Please recompile with CGO enabled or set $STRONGBOX_HOME or $HOME") + // not reached + return "" +} + +func DecryptCLI(flagKey string) { + var fn string + if flag.Arg(0) == "" { + // no file passed, try to read stdin + fn = "/dev/stdin" + } else { + fn = flag.Arg(0) + } + fb, err := os.ReadFile(fn) + if err != nil { + log.Fatalf("Unable to read file to decrypt %v", err) + } + dk, err := Decode([]byte(flagKey)) + if err != nil { + log.Fatalf("Unable to decode private key %v", err) + } + out, err := decrypt(fb, dk) + if err != nil { + log.Fatalf("Unable to decrypt %v", err) + } + fmt.Printf("%s", out) +} + +func GitConfig() { + args := [][]string{ + {"config", "--global", "--replace-all", "filter.strongbox.clean", "strongbox -clean %f"}, + {"config", "--global", "--replace-all", "filter.strongbox.smudge", "strongbox -smudge %f"}, + {"config", "--global", "--replace-all", "filter.strongbox.required", "true"}, + + {"config", "--global", "--replace-all", "diff.strongbox.textconv", "strongbox -diff"}, + } + for _, command := range args { + cmd := exec.Command("git", command...) + if out, err := cmd.CombinedOutput(); err != nil { + log.Fatal(string(out)) + } + } + log.Println("git global configuration updated successfully") +} + +func GenKey(desc string) { + err := kr.Load() + if err != nil && !os.IsNotExist(err) { + log.Fatal(err) + } + + key := make([]byte, 32) + _, err = rand.Read(key) + if err != nil { + log.Fatal(err) + } + + keyID := sha256.Sum256(key) + + kr.AddKey(desc, keyID[:], key) + + err = kr.Save() + if err != nil { + log.Fatal(err) + } +} + +func Diff(filename string) { + f, err := os.Open(filename) + if err != nil { + log.Fatal(err) + } + defer func() { + if err = f.Close(); err != nil { + log.Fatal(err) + } + }() + _, err = io.Copy(os.Stdout, f) + if err != nil { + log.Fatal(err) + } +} + +func Clean(r io.Reader, w io.Writer, filename string) { + // Read the file, fail on error + in, err := io.ReadAll(r) + if err != nil { + log.Fatal(err) + } + // Check the file is plaintext, if its an encrypted strongbox file, copy as is, and exit 0 + if bytes.HasPrefix(in, prefix) { + _, err = io.Copy(w, bytes.NewReader(in)) + if err != nil { + log.Fatal(err) + } + return + } + // File is plaintext and needs to be encrypted, get the key, fail on error + key, err := keyLoader(filename) + if err != nil { + log.Fatal(err) + } + // encrypt the file, fail on error + out, err := encrypt(in, key) + if err != nil { + log.Fatal(err) + } + // write out encrypted file, fail on error + _, err = io.Copy(w, bytes.NewReader(out)) + if err != nil { + log.Fatal(err) + } +} + +// Called by git on `git checkout` +func Smudge(r io.Reader, w io.Writer, filename string) { + in, err := io.ReadAll(r) + if err != nil { + log.Fatal(err) + } + + // file is a non-strongbox file, copy as is and exit + if !bytes.HasPrefix(in, prefix) { + _, err = io.Copy(w, bytes.NewReader(in)) + if err != nil { + log.Fatal(err) + } + return + } + + key, err := keyLoader(filename) + if err != nil { + // don't log error if its keyNotFound + switch err { + case errKeyNotFound: + default: + log.Println(err) + } + // Couldn't load the key, just copy as is and return + if _, err = io.Copy(w, bytes.NewReader(in)); err != nil { + log.Println(err) + } + return + } + + out, err := decrypt(in, key) + if err != nil { + log.Println(err) + out = in + } + if _, err := io.Copy(w, bytes.NewReader(out)); err != nil { + log.Println(err) + } +} + +// recursiveDecrypt will try and recursively decrypt files +// if 'key' is provided then it will decrypt all encrypted files with given key +// otherwise it will find key based on file location +// if error is generated in finding key or in decryption then it will continue with next file +// function will only return early if it failed to read/write files +func RecursiveDecrypt(target string, givenKey []byte) error { + var decErrors []string + err := filepath.WalkDir(target, func(path string, entry fs.DirEntry, err error) error { + // always return on error + if err != nil { + return err + } + + // only process files + if entry.IsDir() { + // skip .git directory + if entry.Name() == ".git" { + return fs.SkipDir + } + return nil + } + + file, err := os.OpenFile(path, os.O_RDWR, 0) + if err != nil { + return err + } + defer file.Close() + + // for optimisation only read required chunk of the file and verify if encrypted + chunk := make([]byte, len(defaultPrefix)) + _, err = file.Read(chunk) + if err != nil && err != io.EOF { + return err + } + + if !bytes.HasPrefix(chunk, prefix) { + return nil + } + + key := givenKey + if len(key) == 0 { + key, err = keyLoader(path) + if err != nil { + // continue with next file + decErrors = append(decErrors, fmt.Sprintf("unable to find key file:%s err:%s", path, err)) + return nil + } + } + + // read entire file from the beginning + file.Seek(0, io.SeekStart) + in, err := io.ReadAll(file) + if err != nil { + return err + } + + out, err := decrypt(in, key) + if err != nil { + // continue with next file + decErrors = append(decErrors, fmt.Sprintf("unable to decrypt file:%s err:%s", path, err)) + return nil + } + + if err := file.Truncate(0); err != nil { + return err + } + if _, err := file.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := file.Write(out); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + if len(decErrors) > 0 { + for _, e := range decErrors { + log.Println(e) + } + return fmt.Errorf("unable to decrypt some files") + } + + return nil +} + +func encrypt(b, key []byte) ([]byte, error) { + b = compress(b) + out, err := siv.Encrypt(nil, key, b, nil) + if err != nil { + return nil, err + } + var buf []byte + buf = append(buf, defaultPrefix...) + b64 := Encode(out) + for len(b64) > 0 { + l := 76 + if len(b64) < 76 { + l = len(b64) + } + buf = append(buf, b64[0:l]...) + buf = append(buf, '\n') + b64 = b64[l:] + } + return buf, nil +} + +func compress(b []byte) []byte { + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + _, err := zw.Write(b) + if err != nil { + log.Fatal(err) + } + if err := zw.Close(); err != nil { + log.Fatal(err) + } + return buf.Bytes() +} + +func decompress(b []byte) []byte { + zr, err := gzip.NewReader(bytes.NewReader(b)) + if err != nil { + log.Fatal(err) + } + b, err = io.ReadAll(zr) + if err != nil { + log.Fatal(err) + } + if err := zr.Close(); err != nil { + log.Fatal(err) + } + return b +} + +func Encode(decoded []byte) []byte { + b64 := make([]byte, base64.StdEncoding.EncodedLen(len(decoded))) + base64.StdEncoding.Encode(b64, decoded) + return b64 +} + +func Decode(encoded []byte) ([]byte, error) { + decoded := make([]byte, len(encoded)) + i, err := base64.StdEncoding.Decode(decoded, encoded) + if err != nil { + return nil, err + } + return decoded[0:i], nil +} + +func decrypt(enc []byte, priv []byte) ([]byte, error) { + // strip prefix and any comment up to end of line + spl := bytes.SplitN(enc, []byte("\n"), 2) + if len(spl) != 2 { + return nil, errors.New("couldn't split on end of line") + } + b64encoded := spl[1] + b64decoded, err := Decode(b64encoded) + if err != nil { + return nil, err + } + decrypted, err := siv.Decrypt(priv, b64decoded, nil) + if err != nil { + return nil, err + } + decrypted = decompress(decrypted) + return decrypted, nil +} + +// key returns private key and error +func Key(filename string) ([]byte, error) { + keyID, err := findKey(filename) + if err != nil { + return []byte{}, err + } + + err = kr.Load() + if err != nil { + return []byte{}, err + } + + key, err := kr.Key(keyID) + if err != nil { + return []byte{}, err + } + + return key, nil +} + +func findKey(filename string) ([]byte, error) { + path := filepath.Dir(filename) + for { + if fi, err := os.Stat(path); err == nil && fi.IsDir() { + keyFilename := filepath.Join(path, ".strongbox-keyid") + if keyFile, err := os.Stat(keyFilename); err == nil && !keyFile.IsDir() { + return readKey(keyFilename) + } + } + if path == "." { + break + } + path = filepath.Dir(path) + } + return []byte{}, fmt.Errorf("failed to find key id for file %s", filename) +} + +func readKey(filename string) ([]byte, error) { + fp, err := os.ReadFile(filename) + if err != nil { + return []byte{}, err + } + + b64 := strings.TrimSpace(string(fp)) + b, err := Decode([]byte(b64)) + if err != nil { + return []byte{}, err + } + if len(b) != 32 { + return []byte{}, fmt.Errorf("unexpected key length %d", len(b)) + } + return b, nil +} diff --git a/strongbox.go b/strongbox.go index dfbdc1a..1f12f4d 100644 --- a/strongbox.go +++ b/strongbox.go @@ -1,24 +1,14 @@ package main import ( - "bytes" - "compress/gzip" - "crypto/rand" - "crypto/sha256" - "encoding/base64" "errors" "flag" "fmt" - "io" - "io/fs" "log" "os" - "os/exec" - "os/user" "path/filepath" - "strings" - "github.com/jacobsa/crypto/siv" + "github.com/uw-labs/strongbox/strongbox-lib" ) var ( @@ -27,7 +17,7 @@ var ( date = "unknown" builtBy = "unknown" - keyLoader = key + keyLoader = strongbox.Key kr keyRing prefix = []byte("# STRONGBOX ENCRYPTED RESOURCE ;") defaultPrefix = []byte("# STRONGBOX ENCRYPTED RESOURCE ; See https://github.com/uw-labs/strongbox\n") @@ -74,17 +64,17 @@ func main() { } if *flagGitConfig { - gitConfig() + strongbox.GitConfig() return } if *flagDiff != "" { - diff(*flagDiff) + strongbox.Diff(*flagDiff) return } // Set up keyring file name - home := deriveHome() + home := strongbox.DeriveHome() kr = &fileKeyRing{fileName: filepath.Join(home, ".strongbox_keyring")} // if keyring flag is set replace default keyRing @@ -97,7 +87,7 @@ func main() { } if *flagGenKey != "" { - genKey(*flagGenKey) + strongbox.GenKey(*flagGenKey) return } @@ -115,12 +105,12 @@ func main() { } // for recursive decryption 'key' flag is optional but if provided // it should be valid and all encrypted file will be decrypted using it - dk, err := decode([]byte(*flagKey)) + dk, err := strongbox.Decode([]byte(*flagKey)) if err != nil && *flagKey != "" { log.Fatalf("Unable to decode given private key %v", err) } - if err = recursiveDecrypt(target, dk); err != nil { + if err = strongbox.RecursiveDecrypt(target, dk); err != nil { log.Fatalln(err) } return @@ -129,7 +119,7 @@ func main() { if *flagKey == "" { log.Fatalf("Must provide a `-key` when using -decrypt") } - decryptCLI() + strongbox.DecryptCLI(*flagKey) return } @@ -139,405 +129,12 @@ func main() { } if *flagClean != "" { - clean(os.Stdin, os.Stdout, *flagClean) + strongbox.Clean(os.Stdin, os.Stdout, *flagClean) return } if *flagSmudge != "" { - smudge(os.Stdin, os.Stdout, *flagSmudge) + strongbox.Smudge(os.Stdin, os.Stdout, *flagSmudge) return } } -func deriveHome() string { - // try explicitly set STRONGBOX_HOME - if home := os.Getenv("STRONGBOX_HOME"); home != "" { - return home - } - // Try user.Current which works in most cases, but may not work with CGO disabled. - u, err := user.Current() - if err == nil && u.HomeDir != "" { - return u.HomeDir - } - // try HOME env var - if home := os.Getenv("HOME"); home != "" { - return home - } - - log.Fatal("Could not call os/user.Current() or find $STRONGBOX_HOME or $HOME. Please recompile with CGO enabled or set $STRONGBOX_HOME or $HOME") - // not reached - return "" -} - -func decryptCLI() { - var fn string - if flag.Arg(0) == "" { - // no file passed, try to read stdin - fn = "/dev/stdin" - } else { - fn = flag.Arg(0) - } - fb, err := os.ReadFile(fn) - if err != nil { - log.Fatalf("Unable to read file to decrypt %v", err) - } - dk, err := decode([]byte(*flagKey)) - if err != nil { - log.Fatalf("Unable to decode private key %v", err) - } - out, err := decrypt(fb, dk) - if err != nil { - log.Fatalf("Unable to decrypt %v", err) - } - fmt.Printf("%s", out) -} - -func gitConfig() { - args := [][]string{ - {"config", "--global", "--replace-all", "filter.strongbox.clean", "strongbox -clean %f"}, - {"config", "--global", "--replace-all", "filter.strongbox.smudge", "strongbox -smudge %f"}, - {"config", "--global", "--replace-all", "filter.strongbox.required", "true"}, - - {"config", "--global", "--replace-all", "diff.strongbox.textconv", "strongbox -diff"}, - } - for _, command := range args { - cmd := exec.Command("git", command...) - if out, err := cmd.CombinedOutput(); err != nil { - log.Fatal(string(out)) - } - } - log.Println("git global configuration updated successfully") -} - -func genKey(desc string) { - err := kr.Load() - if err != nil && !os.IsNotExist(err) { - log.Fatal(err) - } - - key := make([]byte, 32) - _, err = rand.Read(key) - if err != nil { - log.Fatal(err) - } - - keyID := sha256.Sum256(key) - - kr.AddKey(desc, keyID[:], key) - - err = kr.Save() - if err != nil { - log.Fatal(err) - } -} - -func diff(filename string) { - f, err := os.Open(filename) - if err != nil { - log.Fatal(err) - } - defer func() { - if err = f.Close(); err != nil { - log.Fatal(err) - } - }() - _, err = io.Copy(os.Stdout, f) - if err != nil { - log.Fatal(err) - } -} - -func clean(r io.Reader, w io.Writer, filename string) { - // Read the file, fail on error - in, err := io.ReadAll(r) - if err != nil { - log.Fatal(err) - } - // Check the file is plaintext, if its an encrypted strongbox file, copy as is, and exit 0 - if bytes.HasPrefix(in, prefix) { - _, err = io.Copy(w, bytes.NewReader(in)) - if err != nil { - log.Fatal(err) - } - return - } - // File is plaintext and needs to be encrypted, get the key, fail on error - key, err := keyLoader(filename) - if err != nil { - log.Fatal(err) - } - // encrypt the file, fail on error - out, err := encrypt(in, key) - if err != nil { - log.Fatal(err) - } - // write out encrypted file, fail on error - _, err = io.Copy(w, bytes.NewReader(out)) - if err != nil { - log.Fatal(err) - } -} - -// Called by git on `git checkout` -func smudge(r io.Reader, w io.Writer, filename string) { - in, err := io.ReadAll(r) - if err != nil { - log.Fatal(err) - } - - // file is a non-strongbox file, copy as is and exit - if !bytes.HasPrefix(in, prefix) { - _, err = io.Copy(w, bytes.NewReader(in)) - if err != nil { - log.Fatal(err) - } - return - } - - key, err := keyLoader(filename) - if err != nil { - // don't log error if its keyNotFound - switch err { - case errKeyNotFound: - default: - log.Println(err) - } - // Couldn't load the key, just copy as is and return - if _, err = io.Copy(w, bytes.NewReader(in)); err != nil { - log.Println(err) - } - return - } - - out, err := decrypt(in, key) - if err != nil { - log.Println(err) - out = in - } - if _, err := io.Copy(w, bytes.NewReader(out)); err != nil { - log.Println(err) - } -} - -// recursiveDecrypt will try and recursively decrypt files -// if 'key' is provided then it will decrypt all encrypted files with given key -// otherwise it will find key based on file location -// if error is generated in finding key or in decryption then it will continue with next file -// function will only return early if it failed to read/write files -func recursiveDecrypt(target string, givenKey []byte) error { - var decErrors []string - err := filepath.WalkDir(target, func(path string, entry fs.DirEntry, err error) error { - // always return on error - if err != nil { - return err - } - - // only process files - if entry.IsDir() { - // skip .git directory - if entry.Name() == ".git" { - return fs.SkipDir - } - return nil - } - - file, err := os.OpenFile(path, os.O_RDWR, 0) - if err != nil { - return err - } - defer file.Close() - - // for optimisation only read required chunk of the file and verify if encrypted - chunk := make([]byte, len(defaultPrefix)) - _, err = file.Read(chunk) - if err != nil && err != io.EOF { - return err - } - - if !bytes.HasPrefix(chunk, prefix) { - return nil - } - - key := givenKey - if len(key) == 0 { - key, err = keyLoader(path) - if err != nil { - // continue with next file - decErrors = append(decErrors, fmt.Sprintf("unable to find key file:%s err:%s", path, err)) - return nil - } - } - - // read entire file from the beginning - file.Seek(0, io.SeekStart) - in, err := io.ReadAll(file) - if err != nil { - return err - } - - out, err := decrypt(in, key) - if err != nil { - // continue with next file - decErrors = append(decErrors, fmt.Sprintf("unable to decrypt file:%s err:%s", path, err)) - return nil - } - - if err := file.Truncate(0); err != nil { - return err - } - if _, err := file.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := file.Write(out); err != nil { - return err - } - return nil - }) - if err != nil { - return err - } - if len(decErrors) > 0 { - for _, e := range decErrors { - log.Println(e) - } - return fmt.Errorf("unable to decrypt some files") - } - - return nil -} - -func encrypt(b, key []byte) ([]byte, error) { - b = compress(b) - out, err := siv.Encrypt(nil, key, b, nil) - if err != nil { - return nil, err - } - var buf []byte - buf = append(buf, defaultPrefix...) - b64 := encode(out) - for len(b64) > 0 { - l := 76 - if len(b64) < 76 { - l = len(b64) - } - buf = append(buf, b64[0:l]...) - buf = append(buf, '\n') - b64 = b64[l:] - } - return buf, nil -} - -func compress(b []byte) []byte { - var buf bytes.Buffer - zw := gzip.NewWriter(&buf) - _, err := zw.Write(b) - if err != nil { - log.Fatal(err) - } - if err := zw.Close(); err != nil { - log.Fatal(err) - } - return buf.Bytes() -} - -func decompress(b []byte) []byte { - zr, err := gzip.NewReader(bytes.NewReader(b)) - if err != nil { - log.Fatal(err) - } - b, err = io.ReadAll(zr) - if err != nil { - log.Fatal(err) - } - if err := zr.Close(); err != nil { - log.Fatal(err) - } - return b -} - -func encode(decoded []byte) []byte { - b64 := make([]byte, base64.StdEncoding.EncodedLen(len(decoded))) - base64.StdEncoding.Encode(b64, decoded) - return b64 -} - -func decode(encoded []byte) ([]byte, error) { - decoded := make([]byte, len(encoded)) - i, err := base64.StdEncoding.Decode(decoded, encoded) - if err != nil { - return nil, err - } - return decoded[0:i], nil -} - -func decrypt(enc []byte, priv []byte) ([]byte, error) { - // strip prefix and any comment up to end of line - spl := bytes.SplitN(enc, []byte("\n"), 2) - if len(spl) != 2 { - return nil, errors.New("couldn't split on end of line") - } - b64encoded := spl[1] - b64decoded, err := decode(b64encoded) - if err != nil { - return nil, err - } - decrypted, err := siv.Decrypt(priv, b64decoded, nil) - if err != nil { - return nil, err - } - decrypted = decompress(decrypted) - return decrypted, nil -} - -// key returns private key and error -func key(filename string) ([]byte, error) { - keyID, err := findKey(filename) - if err != nil { - return []byte{}, err - } - - err = kr.Load() - if err != nil { - return []byte{}, err - } - - key, err := kr.Key(keyID) - if err != nil { - return []byte{}, err - } - - return key, nil -} - -func findKey(filename string) ([]byte, error) { - path := filepath.Dir(filename) - for { - if fi, err := os.Stat(path); err == nil && fi.IsDir() { - keyFilename := filepath.Join(path, ".strongbox-keyid") - if keyFile, err := os.Stat(keyFilename); err == nil && !keyFile.IsDir() { - return readKey(keyFilename) - } - } - if path == "." { - break - } - path = filepath.Dir(path) - } - return []byte{}, fmt.Errorf("failed to find key id for file %s", filename) -} - -func readKey(filename string) ([]byte, error) { - fp, err := os.ReadFile(filename) - if err != nil { - return []byte{}, err - } - - b64 := strings.TrimSpace(string(fp)) - b, err := decode([]byte(b64)) - if err != nil { - return []byte{}, err - } - if len(b) != 32 { - return []byte{}, fmt.Errorf("unexpected key length %d", len(b)) - } - return b, nil -} diff --git a/strongbox_test.go b/strongbox_test.go index 12bd854..1d209b3 100644 --- a/strongbox_test.go +++ b/strongbox_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/uw-labs/strongbox/strongbox-lib" ) var ( @@ -59,10 +60,10 @@ func TestMultipleClean(t *testing.T) { assert := assert.New(t) var cleaned bytes.Buffer - clean(bytes.NewReader(plain), &cleaned, "") + strongbox.Clean(bytes.NewReader(plain), &cleaned, "") var doubleCleaned bytes.Buffer - clean(bytes.NewReader(cleaned.Bytes()), &doubleCleaned, "") + strongbox.Clean(bytes.NewReader(cleaned.Bytes()), &doubleCleaned, "") assert.Equal(cleaned.String(), doubleCleaned.String()) } @@ -71,7 +72,7 @@ func TestSmudgeAlreadyPlaintext(t *testing.T) { assert := assert.New(t) var smudged bytes.Buffer - smudge(bytes.NewReader(plain), &smudged, "") + strongbox.Smudge(bytes.NewReader(plain), &smudged, "") assert.Equal(string(plain), smudged.String()) } @@ -80,14 +81,14 @@ func TestRoundTrip(t *testing.T) { assert := assert.New(t) var cleaned bytes.Buffer - clean(bytes.NewReader(plain), &cleaned, "") + strongbox.Clean(bytes.NewReader(plain), &cleaned, "") fmt.Printf("%s", string(cleaned.String())) assert.NotEqual(plain, cleaned.Bytes()) var smudged bytes.Buffer - smudge(bytes.NewReader(cleaned.Bytes()), &smudged, "") + strongbox.Smudge(bytes.NewReader(cleaned.Bytes()), &smudged, "") assert.Equal(string(plain), smudged.String()) } @@ -96,10 +97,10 @@ func TestDeterministic(t *testing.T) { assert := assert.New(t) var cleaned1 bytes.Buffer - clean(bytes.NewReader(plain), &cleaned1, "") + strongbox.Clean(bytes.NewReader(plain), &cleaned1, "") var cleaned2 bytes.Buffer - clean(bytes.NewReader(plain), &cleaned2, "") + strongbox.Clean(bytes.NewReader(plain), &cleaned2, "") assert.Equal(cleaned1.String(), cleaned2.String()) } @@ -107,9 +108,9 @@ func TestDeterministic(t *testing.T) { func BenchmarkRoundTripPlain(b *testing.B) { for n := 0; n < b.N; n++ { var cleaned bytes.Buffer - clean(bytes.NewReader(plain), &cleaned, "") + strongbox.Clean(bytes.NewReader(plain), &cleaned, "") var smudged bytes.Buffer - smudge(bytes.NewReader(cleaned.Bytes()), &smudged, "") + strongbox.Smudge(bytes.NewReader(cleaned.Bytes()), &smudged, "") } }