-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* filepath is used instead of more hacky methods * Getting reader from uri is now abstracted to a separate function * More errors are handled, and more gracefully
- Loading branch information
Showing
1 changed file
with
76 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,113 @@ | ||
package main | ||
|
||
import ( | ||
"flag" | ||
"fmt" | ||
"github.com/pkg/errors" | ||
"io" | ||
"net/http" | ||
"roob.re/omemo-wget/aesgcm" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"os" | ||
"path/filepath" | ||
"roob.re/omemo-wget/aesgcm" | ||
"strings" | ||
) | ||
|
||
func main() { | ||
if len(os.Args) < 2 { | ||
_, _ = fmt.Fprintf(os.Stderr, "Usage: %s <path/to/file|https://url#hash> [hash]\n", os.Args[0]) | ||
os.Exit(1) | ||
} | ||
|
||
str := os.Args[1] | ||
var in io.ReadCloser | ||
var hash string | ||
var filename string | ||
outfile := flag.String("o", "", "out file. Use '-' for stdout. Defaults to guess from input uri/path") | ||
flag.Parse() | ||
|
||
parts := strings.Split(str, "/") | ||
filename = parts[len(parts)-1] | ||
|
||
if strings.HasPrefix(str, "aesgcm://") { | ||
str = strings.Replace(str, "aesgcm://", "https://", 1) | ||
if flag.NArg() < 1 { | ||
stderrExit(errors.New(fmt.Sprintf("Usage: %s <path/to/file|uri#hash> [hash] [-o out]\n", os.Args[0])), 1) | ||
return | ||
} | ||
|
||
if strings.HasPrefix(str, "https://") { | ||
resp, err := http.Get(str) | ||
if err != nil { | ||
stderrExit(err, 2) | ||
} | ||
uri := flag.Args()[0] | ||
|
||
in = resp.Body | ||
parts := strings.Split(filename, "#") | ||
if len(parts) < 2 { | ||
stderrExit(errors.New("Malformed aesgcm url"), 5) | ||
} | ||
hash = parts[len(parts)-1] | ||
filename = parts[len(parts) - 2] | ||
var hash string | ||
parts := strings.Split(uri, "#") | ||
path := parts[0] | ||
if len(parts) >= 2 { | ||
hash = parts[1] | ||
} else if flag.NArg() >= 2 { | ||
hash = flag.Args()[1] | ||
} else { | ||
if len(os.Args) < 3 { | ||
_, _ = fmt.Fprintf(os.Stderr, "Hash is mandatory if not included in the url\n\nUsage: %s <path/to/file|https://url#hash> [hash]\n", os.Args[0]) | ||
os.Exit(3) | ||
} | ||
hash = os.Args[2] | ||
|
||
var err error | ||
in, err = os.Open(str) | ||
if err != nil { | ||
stderrExit(err, 2) | ||
} | ||
stderrExit(errors.New("hash must be either included in the url (after the # character) or provided as a second argument"), 2) | ||
return | ||
} | ||
|
||
outfile := filename | ||
pos := strings.IndexAny(filename, ".") | ||
if pos != -1 { | ||
outfile = filename[:pos] + "_decrypted" + filename[pos:] | ||
} else { | ||
outfile += "_decrypted" | ||
in, err := open(path) | ||
if err != nil { | ||
stderrExit(err, 3) | ||
return | ||
} | ||
|
||
fileContents, err := ioutil.ReadAll(in) | ||
if err != nil { | ||
panic(err) | ||
stderrExit(err, 4) | ||
return | ||
} | ||
in.Close() | ||
_ = in.Close() | ||
|
||
decryptedContents, err := aesgcm.Decrypt(fileContents, hash) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
out, err := os.Create(outfile) | ||
var out io.WriteCloser | ||
switch *outfile { | ||
case "-": | ||
out = os.Stdout | ||
case "": | ||
// Generate a suitable name | ||
basename := filepath.Base(path) | ||
ext := filepath.Ext(basename) | ||
|
||
for _, err := os.Stat(basename); err == nil; _, err = os.Stat(basename) { | ||
basename = strings.Replace(basename, ext, "_decrypted"+ext, 1) | ||
} | ||
*outfile = basename | ||
fallthrough | ||
default: | ||
f, err := os.Create(*outfile) | ||
if err != nil { | ||
stderrExit(errors.New("error creating output file: "+err.Error()), 6) | ||
return | ||
} | ||
out = f | ||
} | ||
|
||
_, err = out.Write(decryptedContents) | ||
if err != nil { | ||
stderrExit(err, 5) | ||
stderrExit(errors.New("error writing to output file: "+err.Error()), 6) | ||
return | ||
} | ||
out.Write(decryptedContents) | ||
out.Close() | ||
_ = out.Close() | ||
} | ||
|
||
func open(uri string) (io.ReadCloser, error) { | ||
switch true { | ||
case strings.HasPrefix(uri, "aesgcm"): | ||
uri = strings.Replace(uri, "aesgcm", "https", 1) | ||
fallthrough | ||
case strings.HasPrefix(uri, "https"): | ||
resp, err := http.Get(uri) | ||
if err != nil { | ||
return nil, errors.New(fmt.Sprintf("error fetching '%s': %s", uri, err.Error())) | ||
} | ||
|
||
return resp.Body, nil | ||
default: | ||
file, err := os.Open(uri) | ||
if err != nil { | ||
return nil, errors.New(fmt.Sprintf("could not open file '%s': %s", uri, err.Error())) | ||
} | ||
|
||
return file, nil | ||
} | ||
} | ||
|
||
func stderrExit(e error, code int) { | ||
_, _ = fmt.Fprint(os.Stderr, e) | ||
_, _ = fmt.Fprint(os.Stderr, e.Error()+"\n") | ||
os.Exit(code) | ||
} |