diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..717ac71
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+
+__debug_bin
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..39c6f9d
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,20 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Launch Package",
+            "type": "go",
+            "request": "launch",
+            "mode": "debug",
+            "program": "${workspaceFolder}",
+            "args": [
+                "-k", "xxx",
+                "-s", "xxx",
+                "-a", "xxx"
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0b36c2a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,8 @@
+# FTX AUTO LENDING
+
+This is a CLI tool that automatically compounds payouts earned from lending coins.
+It checks for newly available funds every hour (5 minutes past the full hour, to be precise) and automatically updates the lending offer to the maximum size that can be lent out on the account.
+
+## Example
+`ftx-auto-lend --key xxxxxxx --secret yyyyyyyy --subaccount mylendingsubacc --coin USD`
+
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..1c4d3eb
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,14 @@
+module github.com/trading-peter/ftx-auto-lend
+
+go 1.16
+
+require (
+	github.com/akamensky/argparse v1.2.2
+	github.com/avast/retry-go v3.0.0+incompatible
+	github.com/grishinsana/goftx v1.2.0
+	github.com/robfig/cron/v3 v3.0.0
+	github.com/shopspring/decimal v1.2.0
+	go.uber.org/ratelimit v0.2.0
+)
+
+replace github.com/grishinsana/goftx => ../goftx
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..ca36fff
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,51 @@
+github.com/akamensky/argparse v1.2.2 h1:P17T0ZjlUNJuWTPPJ2A5dM1wxarHgHqfYH+AZTo2xQA=
+github.com/akamensky/argparse v1.2.2/go.mod h1:S5kwC7IuDcEr5VeXtGPRVZ5o/FdhcMlQz4IZQuw64xA=
+github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 h1:MzBOUgng9orim59UnfUTLRjMpd09C5uEVQ6RPGeCaVI=
+github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129/go.mod h1:rFgpPQZYZ8vdbc+48xibu8ALc3yeyd64IhHS+PU6Yyg=
+github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
+github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0=
+github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/go-numb/go-ftx v0.0.0-20200829181514-3144aa68f505/go.mod h1:rjG/Mg/la6U9w0NN/oaMZkgCpEQgseKPOl6EkvYkjCw=
+github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
+github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
+github.com/joho/godotenv v1.3.0/go.mod 
h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron/v3 v3.0.0 h1:kQ6Cb7aHOHTSzNVNEhmp8EcWKLb4CbiMW9h9VyIhO4E= +github.com/robfig/cron/v3 v3.0.0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.16.0/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA= +github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/ratelimit v0.2.0 h1:UQE2Bgi7p2B85uP5dC2bbRtig0C+OeNRnNEafLjsLPA= +go.uber.org/ratelimit v0.2.0/go.mod h1:YYBV4e4naJvhpitQrWJu1vCpgB7CboMe0qhltKt6mUg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..f6be2e5 --- /dev/null +++ b/logger.go @@ -0,0 +1,32 @@ +package main + +import ( + "io/ioutil" + "log" + "os" +) + +var ( + Trace *log.Logger + Info *log.Logger + Warning *log.Logger + Error *log.Logger +) + +func init() { + Trace = log.New(ioutil.Discard, + "TRACE: ", + 
log.LUTC|log.Ldate|log.Ltime|log.Lshortfile)
+
+	Info = log.New(os.Stdout,
+		"INFO: ",
+		log.LUTC|log.Ldate|log.Ltime|log.Lshortfile)
+
+	Warning = log.New(os.Stdout,
+		"WARNING: ",
+		log.LUTC|log.Ldate|log.Ltime|log.Lshortfile)
+
+	Error = log.New(os.Stderr,
+		"ERROR: ",
+		log.LUTC|log.Ldate|log.Ltime|log.Lshortfile)
+}
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..1c9b489
--- /dev/null
+++ b/main.go
@@ -0,0 +1,133 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"strings"
+	"time"
+
+	"github.com/akamensky/argparse"
+	"github.com/avast/retry-go"
+	"github.com/grishinsana/goftx"
+	"github.com/robfig/cron/v3"
+	"github.com/shopspring/decimal"
+	"go.uber.org/ratelimit"
+)
+
+var (
+	limiter ratelimit.Limiter = ratelimit.New(30)
+	client  *goftx.Client
+)
+
+func main() {
+	job := cron.New()
+	parser := argparse.NewParser("ftx-auto-lend", "Automatically compounds lending payouts.")
+	apiKey := parser.String("k", "key", &argparse.Options{Required: true, Help: "API key"})
+	apiSecret := parser.String("s", "secret", &argparse.Options{Required: true, Help: "API secret"})
+	subAcc := parser.String("a", "subaccount", &argparse.Options{Required: false, Help: "Subaccount"})
+	coinArg := parser.String("c", "coin", &argparse.Options{Required: false, Default: "USD", Help: "Coin to lend"})
+	rateArg := parser.String("r", "min-rate", &argparse.Options{Required: false, Default: "0.000001", Help: "Minimum lending rate"})
+	err := parser.Parse(os.Args)
+
+	if err != nil {
+		fmt.Print(parser.Usage(err))
+		return
+	}
+
+	// Dereference the argument pointers only after Parse has run,
+	// otherwise the values (and defaults) are not populated yet.
+	coin := strings.ToUpper(*coinArg)
+	strRate := *rateArg
+
+	if coin == "" {
+		coin = "USD"
+	}
+
+	if strRate == "" {
+		strRate = "0.000001"
+	}
+
+	minRate, err := decimal.NewFromString(strRate)
+
+	if err != nil {
+		Error.Fatal("Min Rate: Invalid number")
+	}
+
+	client = goftx.New(
+		goftx.WithAuth(*apiKey, *apiSecret),
+		goftx.WithSubaccount(*subAcc),
+	)
+
+	_, err = client.GetAccountInformation()
+
+	if err != nil {
+		Error.Fatalln("It seems like the supplied API key is wrong. Please check and try again")
+	}
+
+	job.Start()
+
+	job.AddFunc("5 * * * *", func() {
+		lendable, delta, err := getMaxLendingAmount(coin)
+
+		if err != nil {
+			Error.Println(err)
+			return
+		}
+
+		if delta.LessThanOrEqual(decimal.Zero) {
+			Info.Println("No increase in funds to update lending offer.")
+			return
+		}
+
+		Info.Printf("New lendable amount of %s is %s (+%s).", coin, lendable, delta)
+		Info.Printf("Updating lending offer with a minimum rate of %s.", minRate)
+
+		updateLendingOffer(coin, lendable, minRate)
+	})
+
+	fmt.Printf("Will attempt to update your lending offer for %s each hour.\nPress Enter to stop and exit the program.", coin)
+	fmt.Scanln()
+	fmt.Println("Bye!")
+}
+
+func updateLendingOffer(coin string, amount decimal.Decimal, minRate decimal.Decimal) (err error) {
+	err = retry.Do(
+		func() error {
+			limiter.Take()
+			err := client.SubmitLendingOffer(coin, amount, minRate)
+
+			if err != nil {
+				Error.Println(err)
+				return err
+			}
+
+			return nil
+		},
+		retry.Delay(time.Minute),
+		retry.Attempts(10),
+		retry.DelayType(retry.FixedDelay),
+	)
+
+	return
+}
+
+func getMaxLendingAmount(coin string) (lendable decimal.Decimal, delta decimal.Decimal, err error) {
+	err = retry.Do(
+		func() error {
+			limiter.Take()
+			resp, err := client.GetLendingInfo()
+
+			if err != nil {
+				return err
+			}
+
+			for i := range resp {
+				if resp[i].Coin == coin {
+					lendable = resp[i].Lendable
+					delta = resp[i].Lendable.Sub(resp[i].Offered)
+				}
+			}
+
+			return nil
+		},
+		retry.Delay(time.Minute),
+		retry.Attempts(10),
+		retry.DelayType(retry.FixedDelay),
+	)
+
+	return
+}
diff --git a/vendor/github.com/akamensky/argparse/.gitignore b/vendor/github.com/akamensky/argparse/.gitignore
new file mode 100644
index 0000000..f479089
--- /dev/null
+++ b/vendor/github.com/akamensky/argparse/.gitignore
@@ -0,0 +1,16 @@
+# Binaries for programs and plugins
+*.exe
+*.dll
+*.so
+*.dylib
+
+# Test binary, build with `go test -c`
+*.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
+
+# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
+.glide/
+
+.idea/
\ No newline at end of file
diff --git a/vendor/github.com/akamensky/argparse/.travis.yml b/vendor/github.com/akamensky/argparse/.travis.yml
new file mode 100644
index 0000000..ff0fcf7
--- /dev/null
+++ b/vendor/github.com/akamensky/argparse/.travis.yml
@@ -0,0 +1,9 @@
+language: go
+sudo: false
+go:
+  - "1.x"
+before_install:
+  - go get github.com/mattn/goveralls
+script:
+  - go test -v .
+  - $GOPATH/bin/goveralls -service=travis-ci
diff --git a/vendor/github.com/akamensky/argparse/LICENSE b/vendor/github.com/akamensky/argparse/LICENSE
new file mode 100644
index 0000000..f1831c5
--- /dev/null
+++ b/vendor/github.com/akamensky/argparse/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2017 Alexey Kamenskiy
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/akamensky/argparse/README.md b/vendor/github.com/akamensky/argparse/README.md new file mode 100644 index 0000000..9b7c3e7 --- /dev/null +++ b/vendor/github.com/akamensky/argparse/README.md @@ -0,0 +1,199 @@ +# Golang argparse + +[![GoDoc](https://godoc.org/github.com/akamensky/argparse?status.svg)](https://godoc.org/github.com/akamensky/argparse) [![Go Report Card](https://goreportcard.com/badge/github.com/akamensky/argparse)](https://goreportcard.com/report/github.com/akamensky/argparse) [![Coverage Status](https://coveralls.io/repos/github/akamensky/argparse/badge.svg?branch=master)](https://coveralls.io/github/akamensky/argparse?branch=master) [![Build Status](https://travis-ci.org/akamensky/argparse.svg?branch=master)](https://travis-ci.org/akamensky/argparse) + +Let's be honest -- Go's standard command line arguments parser `flag` terribly sucks. +It cannot come anywhere close to the Python's `argparse` module. This is why this project exists. + +The goal of this project is to bring ease of use and flexibility of `argparse` to Go. +Which is where the name of this package comes from. + +#### Installation + +To install and start using argparse simply do: + +``` +$ go get -u -v github.com/akamensky/argparse +``` + +You are good to go to write your first command line tool! +See Usage and Examples sections for information how you can use it + +#### Usage + +To start using argparse in Go see above instructions on how to install. +From here on you can start writing your first program. +Please check out examples from `examples/` directory to see how to use it in various ways. + +Here is basic example of print command (from `examples/print/` directory): +```go +package main + +import ( + "fmt" + "github.com/akamensky/argparse" + "os" +) + +func main() { + // Create new parser object + parser := argparse.NewParser("print", "Prints provided string to stdout") + // Create string flag + s := parser.String("s", "string", &argparse.Options{Required: true, Help: "String to print"}) + // Parse input + err := parser.Parse(os.Args) + if err != nil { + // In case of error print error and print usage + // This can also be done by passing -h or --help flags + fmt.Print(parser.Usage(err)) + } + // Finally print the collected string + fmt.Println(*s) +} +``` + +#### Basic options + +Create your parser instance and pass it program name and program description. +Program name if empty will be taken from `os.Args[0]` (which is okay in most cases). +Description can be as long as you wish and will be used in `--help` output +```go +parser := argparse.NewParser("progname", "Description of my awesome program. It can be as long as I wish it to be") +``` + +String will allow you to get a string from arguments, such as `$ progname --string "String content"` +```go +var myString *string = parser.String("s", "string", ...) +``` + +Selector works same as a string, except that it will only allow specific values. 
+For example like this `$ progname --debug-level WARN` +```go +var mySelector *string = parser.Selector("d", "debug-level", []string{"INFO", "DEBUG", "WARN"}, ...) +``` + +StringList allows to collect multiple string values into the slice of strings by repeating same flag multiple times. +Such as `$ progname --string hostname1 --string hostname2 -s hostname3` +```go +var myStringList *[]string = parser.StringList("s", "string", ...) +``` + +List allows to collect multiple values into the slice of strings by repeating same flag multiple times +(at fact - it is an Alias of StringList). +Such as `$ progname --host hostname1 --host hostname2 -H hostname3` +```go +var myList *[]string = parser.List("H", "hostname", ...) +``` + +Flag will tell you if a simple flag was set on command line (true is set, false is not). +For example `$ progname --force` +```go +var myFlag *bool = parser.Flag("f", "force", ...) +``` + +FlagCounter will tell you the number of times that simple flag was set on command line +(integer greater than or equal to 1 or 0 if not set). +For example `$ progname -vv --verbose` +```go +var myFlagCounter *int = parser.FlagCounter("v", "verbose", ...) +``` + +Int will allow you to get a decimal integer from arguments, such as `$ progname --integer "42"` +```go +var myInteger *int = parser.Int("i", "integer", ...) +``` + +IntList allows to collect multiple decimal integer values into the slice of integers by repeating same flag multiple times. +Such as `$ progname --integer 42 --integer +51 -i -1` +```go +var myIntegerList *[]int = parser.IntList("i", "integer", ...) +``` + +Float will allow you to get a floating point number from arguments, such as `$ progname --float "37.2"` +```go +var myFloat *float64 = parser.Float("f", "float", ...) +``` + +FloatList allows to collect multiple floating point number values into the slice of floats by repeating same flag multiple times. +Such as `$ progname --float 42 --float +37.2 -f -1.0` +```go +var myFloatList *[]float64 = parser.FloatList("f", "float", ...) +``` + +File will validate that file exists and will attempt to open it with provided privileges. +To be used like this `$ progname --log-file /path/to/file.log` +```go +var myLogFile *os.File = parser.File("l", "log-file", os.O_RDWR, 0600, ...) +``` + +FileList allows to collect files into the slice of files by repeating same flag multiple times. +FileList will validate that files exists and will attempt to open them with provided privileges. +To be used like this `$ progname --log-file /path/to/file.log --log-file /path/to/file_cpy.log -l /another/path/to/file.log` +```go +var myLogFiles *[]os.File = parser.FileList("l", "log-file", os.O_RDWR, 0600, ...) +``` + +You can implement sub-commands in your CLI using `parser.NewCommand()` or go even deeper with `command.NewCommand()`. +Since parser inherits from command, every command supports exactly same options as parser itself, +thus allowing to add arguments specific to that command or more global arguments added on parser itself! + +#### Basic Option Structure + +The `Option` structure is declared at `argparse.go`: +```go +type Options struct { + Required bool + Validate func(args []string) error + Help string + Default interface{} +} +``` + +You can Set `Required` to let it know if it should ask for arguments. +Or you can set `Validate` as a lambda function to make it know while value is valid. +Or you can set `Help` for your beautiful help document. +Or you can set `Default` will set the default value if user does not provide a value. 
+ +Example: +``` +dirpath := parser.String("d", "dirpath", + &argparse.Options{ + Require: false, + Help: "the input files' folder path", + Default: "input", + }) +``` + +#### Caveats + +There are a few caveats (or more like design choices) to know about: +* Shorthand arguments MUST be a single character. Shorthand arguments are prepended with single dash `"-"` +* If not convenient shorthand argument can be completely skipped by passing empty string `""` as first argument +* Shorthand arguments ONLY for `parser.Flag()` and `parser.FlagCounter()` can be combined into single argument same as `ps -aux`, `rm -rf` or `lspci -vvk` +* Long arguments must be specified and cannot be empty. They are prepended with double dash `"--"` +* You cannot define two same arguments. Only first one will be used. For example doing `parser.Flag("t", "test", nil)` followed by `parser.String("t", "test2", nil)` will not work as second `String` argument will be ignored (note that both have `"t"` as shorthand argument). However since it is case-sensitive library, you can work arounf it by capitalizing one of the arguments +* There is a pre-defined argument for `-h|--help`, so from above attempting to define any argument using `h` as shorthand will fail +* `parser.Parse()` returns error in case of something going wrong, but it is not expected to cover ALL cases +* Any arguments that left un-parsed will be regarded as error + + +#### Contributing + +Can you write in Go? Then this projects needs your help! + +Take a look at open issues, specially the ones tagged as `help-wanted`. +If you have any improvements to offer, please open an issue first to ensure this improvement is discussed. + +There are following tasks to be done: +* Add more examples +* Improve code quality (it is messy right now and could use a major revamp to improve gocyclo report) +* Add more argument options (such as numbers parsing) +* Improve test coverage +* Write a wiki for this project + +However note that the logic outlined in method comments must be preserved +as the the library must stick with backward compatibility promise! + +#### Acknowledgments + +Thanks to Python developers for making a great `argparse` which inspired this package to match for greatness of Go diff --git a/vendor/github.com/akamensky/argparse/argparse.go b/vendor/github.com/akamensky/argparse/argparse.go new file mode 100644 index 0000000..efaf9d5 --- /dev/null +++ b/vendor/github.com/akamensky/argparse/argparse.go @@ -0,0 +1,696 @@ +// Package argparse provides users with more flexible and configurable option for command line arguments parsing. +package argparse + +import ( + "errors" + "fmt" + "os" + "strings" +) + +// DisableDescription can be assigned as a command or arguments description to hide it from the Usage output +const DisableDescription = "DISABLEDDESCRIPTIONWILLNOTSHOWUP" + +//disable help can be invoked from the parse and then needs to be propogated to subcommands +var disableHelp = false + +// Command is a basic type for this package. It represents top level Parser as well as any commands and sub-commands +// Command MUST NOT ever be created manually. Instead one should call NewCommand method of Parser or Command, +// which will setup appropriate fields and call methods that have to be called when creating new command. 
+type Command struct { + name string + description string + args []*arg + commands []*Command + parsed bool + happened bool + parent *Command + HelpFunc func(c *Command, msg interface{}) string + exitOnHelp bool +} + +// GetName exposes Command's name field +func (o Command) GetName() string { + return o.name +} + +// GetDescription exposes Command's description field +func (o Command) GetDescription() string { + return o.description +} + +// GetArgs exposes Command's args field +func (o Command) GetArgs() (args []Arg) { + for _, arg := range o.args { + args = append(args, arg) + } + return +} + +// GetCommands exposes Command's commands field +func (o Command) GetCommands() []*Command { + return o.commands +} + +// GetParent exposes Command's parent field +func (o Command) GetParent() *Command { + return o.parent +} + +// Help calls the overriddable Command.HelpFunc on itself, +// called when the help argument strings are passed via CLI +func (o *Command) Help(msg interface{}) string { + tempC := o + for tempC.HelpFunc == nil { + if tempC.parent == nil { + return "" + } + tempC = tempC.parent + } + return tempC.HelpFunc(o, msg) +} + +// Parser is a top level object of argparse. It MUST NOT ever be created manually. Instead one should use +// argparse.NewParser() method that will create new parser, propagate necessary private fields and call needed +// functions. +type Parser struct { + Command +} + +// Options are specific options for every argument. They can be provided if necessary. +// Possible fields are: +// +// Options.Required - tells Parser that this argument is required to be provided. +// useful when specific Command requires some data provided. +// +// Options.Validate - is a validation function. Using this field anyone can implement a custom validation for argument. +// If provided and argument is present, then function is called. If argument also consumes any following values +// (e.g. as String does), then these are provided as args to function. If validation fails the error must be returned, +// which will be the output of `Parser.Parse` method. +// +// Options.Help - A help message to be displayed in Usage output. Can be of any length as the message will be +// formatted to fit max screen width of 100 characters. +// +// Options.Default - A default value for an argument. This value will be assigned to the argument at the end of parsing +// in case if this argument was not supplied on command line. File default value is a string which it will be open with +// provided options. In case if provided value type does not match expected, the error will be returned on run-time. +type Options struct { + Required bool + Validate func(args []string) error + Help string + Default interface{} +} + +// NewParser creates new Parser object that will allow to add arguments for parsing +// It takes program name and description which will be used as part of Usage output +// Returns pointer to Parser object +func NewParser(name string, description string) *Parser { + p := &Parser{} + + p.name = name + p.description = description + + p.args = make([]*arg, 0) + p.commands = make([]*Command, 0) + + p.help("h", "help") + p.exitOnHelp = true + p.HelpFunc = (*Command).Usage + + return p +} + +// NewCommand will create a sub-command and propagate all necessary fields. +// All commands are always at the beginning of the arguments. +// Parser can have commands and those commands can have sub-commands, +// which allows for very flexible workflow. 
+// All commands are considered as required and all commands can have their own argument set. +// Commands are processed Parser -> Command -> sub-Command. +// Arguments will be processed in order of sub-Command -> Command -> Parser. +func (o *Command) NewCommand(name string, description string) *Command { + c := new(Command) + c.name = name + c.description = description + c.parsed = false + c.parent = o + if !disableHelp { + c.help("h", "help") + c.exitOnHelp = true + c.HelpFunc = (*Command).Usage + } + + if o.commands == nil { + o.commands = make([]*Command, 0) + } + + o.commands = append(o.commands, c) + + return c +} + +// DisableHelp removes any help arguments from the commands list of arguments +// This prevents prevents help from being parsed or invoked from the argument list +func (o *Parser) DisableHelp() { + disableHelp = true + for i, arg := range o.args { + if _, ok := arg.result.(*help); ok { + o.args = append(o.args[:i], o.args[i+1:]...) + } + } + for _, com := range o.commands { + for i, comArg := range com.args { + if _, ok := comArg.result.(*help); ok { + com.args = append(com.args[:i], com.args[i+1:]...) + } + } + } +} + +// ExitOnHelp sets the exitOnHelp variable of Parser +func (o *Command) ExitOnHelp(b bool) { + o.exitOnHelp = b + for _, c := range o.commands { + c.ExitOnHelp(b) + } +} + +// SetHelp removes the previous help argument, and creates a new one with the desired sname/lname +func (o *Parser) SetHelp(sname, lname string) { + o.DisableHelp() + o.help(sname, lname) +} + +// Flag Creates new flag type of argument, which is boolean value showing if argument was provided or not. +// Takes short name, long name and pointer to options (optional). +// Short name must be single character, but can be omitted by giving empty string. +// Long name is required. +// Returns pointer to boolean with starting value `false`. If Parser finds the flag +// provided on Command line arguments, then the value is changed to true. +// Set of Flag and FlagCounter shorthand arguments can be combined together such as `tar -cvaf foo.tar foo` +func (o *Command) Flag(short string, long string, opts *Options) *bool { + var result bool + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 1, + opts: opts, + unique: true, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add Flag: %s", err.Error())) + } + + return &result +} + +// FlagCounter Creates new flagCounter type of argument, which is integer value showing the number of times the argument has been provided. +// Takes short name, long name and pointer to options (optional). +// Short name must be single character, but can be omitted by giving empty string. +// Long name is required. +// Returns pointer to integer with starting value `0`. Each time Parser finds the flag +// provided on Command line arguments, the value is incremented by 1. +// Set of FlagCounter and Flag shorthand arguments can be combined together such as `tar -cvaf foo.tar foo` +func (o *Command) FlagCounter(short string, long string, opts *Options) *int { + var result int + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 1, + opts: opts, + unique: false, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add FlagCounter: %s", err.Error())) + } + + return &result +} + +// String creates new string argument, which will return whatever follows the argument on CLI. 
+// Takes as arguments short name (must be single character or an empty string) +// long name and (optional) options +func (o *Command) String(short string, long string, opts *Options) *string { + var result string + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: true, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add String: %s", err.Error())) + } + + return &result +} + +// Int creates new int argument, which will attempt to parse following argument as int. +// Takes as arguments short name (must be single character or an empty string) +// long name and (optional) options. +// If parsing fails parser.Parse() will return an error. +func (o *Command) Int(short string, long string, opts *Options) *int { + var result int + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: true, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add Int: %s", err.Error())) + } + + return &result +} + +// Float creates new float argument, which will attempt to parse following argument as float64. +// Takes as arguments short name (must be single character or an empty string) +// long name and (optional) options. +// If parsing fails parser.Parse() will return an error. +func (o *Command) Float(short string, long string, opts *Options) *float64 { + var result float64 + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: true, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add Float: %s", err.Error())) + } + + return &result +} + +// File creates new file argument, which is when provided will check if file exists or attempt to create it +// depending on provided flags (same as for os.OpenFile). +// It takes same as all other arguments short and long names, additionally it takes flags that specify +// in which mode the file should be open (see os.OpenFile for details on that), file permissions that +// will be applied to a file and argument options. +// Returns a pointer to os.File which will be set to opened file on success. On error the Parser.Parse +// will return error and the pointer might be nil. +func (o *Command) File(short string, long string, flag int, perm os.FileMode, opts *Options) *os.File { + var result os.File + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: true, + fileFlag: flag, + filePerm: perm, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add File: %s", err.Error())) + } + + return &result +} + +// List creates new list argument. This is the argument that is allowed to be present multiple times on CLI. +// All appearances of this argument on CLI will be collected into the list of default type values ​​which is strings. If no argument +// provided, then the list is empty. Takes same parameters as String +// Returns a pointer the list of strings. +func (o *Command) List(short string, long string, opts *Options) *[]string { + return o.StringList(short, long, opts) +} + +// StringList creates new string list argument. This is the argument that is allowed to be present multiple times on CLI. +// All appearances of this argument on CLI will be collected into the list of strings. If no argument +// provided, then the list is empty. Takes same parameters as String +// Returns a pointer the list of strings. 
+func (o *Command) StringList(short string, long string, opts *Options) *[]string { + result := make([]string, 0) + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: false, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add StringList: %s", err.Error())) + } + + return &result +} + +// IntList creates new integer list argument. This is the argument that is allowed to be present multiple times on CLI. +// All appearances of this argument on CLI will be collected into the list of integers. If no argument +// provided, then the list is empty. Takes same parameters as Int +// Returns a pointer the list of integers. +func (o *Command) IntList(short string, long string, opts *Options) *[]int { + result := make([]int, 0) + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: false, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add IntList: %s", err.Error())) + } + + return &result +} + +// FloatList creates new float list argument. This is the argument that is allowed to be present multiple times on CLI. +// All appearances of this argument on CLI will be collected into the list of float64 values. If no argument +// provided, then the list is empty. Takes same parameters as Float +// Returns a pointer the list of float64 values. +func (o *Command) FloatList(short string, long string, opts *Options) *[]float64 { + result := make([]float64, 0) + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: false, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add FloatList: %s", err.Error())) + } + + return &result +} + +// FileList creates new file list argument. This is the argument that is allowed to be present multiple times on CLI. +// All appearances of this argument on CLI will be collected into the list of os.File values. If no argument +// provided, then the list is empty. Takes same parameters as File +// Returns a pointer the list of os.File values. +func (o *Command) FileList(short string, long string, flag int, perm os.FileMode, opts *Options) *[]os.File { + result := make([]os.File, 0) + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: false, + fileFlag: flag, + filePerm: perm, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add FileList: %s", err.Error())) + } + + return &result +} + +// Selector creates a selector argument. Selector argument works in the same way as String argument, with +// the difference that the string value must be from the list of options provided by the program. +// Takes short and long names, argument options and a slice of strings which are allowed values +// for CLI argument. +// Returns a pointer to a string. If argument is not required (as in argparse.Options.Required), +// and argument was not provided, then the string is empty. 
+func (o *Command) Selector(short string, long string, options []string, opts *Options) *string { + var result string + + a := &arg{ + result: &result, + sname: short, + lname: long, + size: 2, + opts: opts, + unique: true, + selector: &options, + } + + if err := o.addArg(a); err != nil { + panic(fmt.Errorf("unable to add Selector: %s", err.Error())) + } + + return &result +} + +// message2String puts msg in result string +// done boolean indicates if result is ready to be returned +// Accepts an interface that can be error, string or fmt.Stringer that will be prepended to a message. +// All other interface types will be ignored +func message2String(msg interface{}) (string, bool) { + var result string + if msg != nil { + switch msg.(type) { + case subCommandError: + result = fmt.Sprintf("%s\n", msg.(error).Error()) + if msg.(subCommandError).cmd != nil { + result += msg.(subCommandError).cmd.Usage(nil) + } + return result, true + case error: + result = fmt.Sprintf("%s\n", msg.(error).Error()) + case string: + result = fmt.Sprintf("%s\n", msg.(string)) + case fmt.Stringer: + result = fmt.Sprintf("%s\n", msg.(fmt.Stringer).String()) + } + } + return result, false +} + +// getPrecedingCommands - collects info on command chain from root to current (o *Command) and all arguments in this chain +func (o *Command) getPrecedingCommands(chain *[]string, arguments *[]*arg) { + current := o + // Also add arguments + // Get line of commands until root + for current != nil { + *chain = append(*chain, current.name) + if current.args != nil { + *arguments = append(*arguments, current.args...) + } + current = current.parent + } + + // Reverse the slice + last := len(*chain) - 1 + for i := 0; i < len(*chain)/2; i++ { + (*chain)[i], (*chain)[last-i] = (*chain)[last-i], (*chain)[i] + } +} + +// getSubCommands - collects info on subcommands of current command +func (o *Command) getSubCommands(chain *[]string) []Command { + commands := make([]Command, 0) + if o.commands != nil && len(o.commands) > 0 { + *chain = append(*chain, "") + for _, v := range o.commands { + // Skip hidden commands + if v.description == DisableDescription { + continue + } + commands = append(commands, *v) + } + } + return commands +} + +// precedingCommands2Result - puts info about command chain from root to current (o *Command) into result string buffer +func (o *Command) precedingCommands2Result(result string, chain []string, arguments []*arg, maxWidth int) string { + usedHelp := false + leftPadding := len("usage: " + chain[0] + "") + // Add preceding commands + for _, v := range chain { + result = addToLastLine(result, v, maxWidth, leftPadding, true) + } + // Add arguments from this and all preceding commands + for _, v := range arguments { + // Skip arguments that are hidden + if v.opts.Help == DisableDescription { + continue + } + if v.lname == "help" && usedHelp { + } else { + result = addToLastLine(result, v.usage(), maxWidth, leftPadding, true) + } + if v.lname == "help" || v.sname == "h" { + usedHelp = true + } + } + // Add program/Command description to the result + result = result + "\n\n" + strings.Repeat(" ", leftPadding) + result = addToLastLine(result, o.description, maxWidth, leftPadding, true) + result = result + "\n\n" + + return result +} + +// subCommands2Result - puts info about subcommands of current command into result string buffer +func subCommands2Result(result string, commands []Command, maxWidth int) string { + // Add list of sub-commands to the result + if len(commands) > 0 { + cmdContent := "Commands:\n\n" 
+ // Get biggest padding + var cmdPadding int + for _, com := range commands { + if com.description == DisableDescription { + continue + } + if len(" "+com.name+" ") > cmdPadding { + cmdPadding = len(" " + com.name + " ") + } + } + // Now add commands with known padding + for _, com := range commands { + if com.description == DisableDescription { + continue + } + cmd := " " + com.name + cmd = cmd + strings.Repeat(" ", cmdPadding-len(cmd)-1) + cmd = addToLastLine(cmd, com.description, maxWidth, cmdPadding, true) + cmdContent = cmdContent + cmd + "\n" + } + result = result + cmdContent + "\n" + } + return result +} + +// arguments2Result - puts info about all arguments of current command into result string buffer +func arguments2Result(result string, arguments []*arg, maxWidth int) string { + usedHelp := false + if len(arguments) > 0 { + argContent := "Arguments:\n\n" + // Get biggest padding + var argPadding int + // Find biggest padding + for _, argument := range arguments { + if argument.opts.Help == DisableDescription { + continue + } + if len(argument.lname)+9 > argPadding { + argPadding = len(argument.lname) + 9 + } + } + // Now add args with padding + for _, argument := range arguments { + if argument.opts.Help == DisableDescription { + continue + } + if argument.lname == "help" && usedHelp { + } else { + arg := " " + if argument.sname != "" { + arg = arg + "-" + argument.sname + " " + } else { + arg = arg + " " + } + arg = arg + "--" + argument.lname + arg = arg + strings.Repeat(" ", argPadding-len(arg)) + if argument.opts != nil && argument.opts.Help != "" { + arg = addToLastLine(arg, argument.getHelpMessage(), maxWidth, argPadding, true) + } + argContent = argContent + arg + "\n" + } + if argument.lname == "help" || argument.sname == "h" { + usedHelp = true + } + } + result = result + argContent + "\n" + } + return result +} + +// Happened shows whether Command was specified on CLI arguments or not. If Command did not "happen", then +// all its descendant commands and arguments are not parsed. Returns a boolean value. +func (o *Command) Happened() bool { + return o.happened +} + +// Usage returns a multiline string that is the same as a help message for this Parser or Command. +// Since Parser is a Command as well, they work in exactly same way. Meaning that usage string +// can be retrieved for any level of commands. It will only include information about this Command, +// its sub-commands, current Command arguments and arguments of all preceding commands (if any) +// +// Accepts an interface that can be error, string or fmt.Stringer that will be prepended to a message. 
+// All other interface types will be ignored +func (o *Command) Usage(msg interface{}) string { + for _, cmd := range o.commands { + if cmd.Happened() { + return cmd.Usage(msg) + } + } + + // Stay classy + maxWidth := 80 + // List of arguments from all preceding commands + arguments := make([]*arg, 0) + // Line of commands until root + var chain []string + + // Put message in result + result, done := message2String(msg) + if done { + return result + } + + //collect info about Preceding Commands into chain and arguments + o.getPrecedingCommands(&chain, &arguments) + // If this Command has sub-commands we need their list + commands := o.getSubCommands(&chain) + + // Build usage description from description of preceding commands chain and each of subcommands + result += "usage:" + result = o.precedingCommands2Result(result, chain, arguments, maxWidth) + result = subCommands2Result(result, commands, maxWidth) + // Add list of arguments to the result + result = arguments2Result(result, arguments, maxWidth) + + return result +} + +// Parse method can be applied only on Parser. It takes a slice of strings (as in os.Args) +// and it will process this slice as arguments of CLI (the original slice is not modified). +// Returns error on any failure. In case of failure recommended course of action is to +// print received error alongside with usage information (might want to check which Command +// was active when error happened and print that specific Command usage). +// In case no error returned all arguments should be safe to use. Safety of using arguments +// before Parse operation is complete is not guaranteed. +func (o *Parser) Parse(args []string) error { + subargs := make([]string, len(args)) + copy(subargs, args) + + result := o.parse(&subargs) + unparsed := make([]string, 0) + for _, v := range subargs { + if v != "" { + unparsed = append(unparsed, v) + } + } + if result == nil && len(unparsed) > 0 { + return errors.New("unknown arguments " + strings.Join(unparsed, " ")) + } + + return result +} diff --git a/vendor/github.com/akamensky/argparse/argument.go b/vendor/github.com/akamensky/argparse/argument.go new file mode 100644 index 0000000..9b5aed3 --- /dev/null +++ b/vendor/github.com/akamensky/argparse/argument.go @@ -0,0 +1,516 @@ +package argparse + +import ( + "fmt" + "os" + "reflect" + "strconv" + "strings" +) + +type arg struct { + result interface{} // Pointer to the resulting value + opts *Options // Options + sname string // Short name (in Parser will start with "-" + lname string // Long name (in Parser will start with "--" + size int // Size defines how many args after match will need to be consumed + unique bool // Specifies whether flag should be present only ones + parsed bool // Specifies whether flag has been parsed already + fileFlag int // File mode to open file with + filePerm os.FileMode // File permissions to set a file + selector *[]string // Used in Selector type to allow to choose only one from list of options + parent *Command // Used to get access to specific Command + eqChar bool // This is used if the command is passed in with an equals char as a seperator +} + +// Arg interface provides exporting of arg structure, while exposing it +type Arg interface { + GetOpts() *Options + GetSname() string + GetLname() string +} + +func (o arg) GetOpts() *Options { + return o.opts +} + +func (o arg) GetSname() string { + return o.sname +} + +func (o arg) GetLname() string { + return o.lname +} + +type help struct{} + +// checkLongName if long argumet present. 
+// checkLongName - returns the argumet's long name number of occurrences and error. +// For long name return value is 0 or 1. +func (o *arg) checkLongName(argument string) int { + // Check for long name only if not empty + if o.lname != "" { + // If argument begins with "--" and next is not "-" then it is a long name + if len(argument) > 2 && strings.HasPrefix(argument, "--") && argument[2] != '-' { + if argument[2:] == o.lname { + return 1 + } + } + } + + return 0 +} + +// checkShortName if argumet present. +// checkShortName - returns the argumet's short name number of occurrences and error. +// For shorthand argument - 0 if there is no occurrences, or count of occurrences. +// Shorthand argument with parametr, mast be the only or last in the argument string. +func (o *arg) checkShortName(argument string) (int, error) { + // Check for short name only if not empty + if o.sname != "" { + + // If argument begins with "-" and next is not "-" then it is a short name + if len(argument) > 1 && strings.HasPrefix(argument, "-") && argument[1] != '-' { + count := strings.Count(argument[1:], o.sname) + switch { + // For args with size 1 (Flag,FlagCounter) multiple shorthand in one argument are allowed + case o.size == 1: + return count, nil + // For args with o.size > 1, shorthand argument is allowed only to complete the sequence of arguments combined into one + case o.size > 1: + if count > 1 { + return count, fmt.Errorf("[%s] argument: The parameter must follow", o.name()) + } + if strings.HasSuffix(argument[1:], o.sname) { + return count, nil + } + //if o.size < 1 - it is an error + default: + return 0, fmt.Errorf("Argument's size < 1 is not allowed") + } + } + } + + return 0, nil +} + +// check if argumet present. +// check - returns the argumet's number of occurrences and error. +// For long name return value is 0 or 1. +// For shorthand argument - 0 if there is no occurrences, or count of occurrences. +// Shorthand argument with parametr, mast be the only or last in the argument string. 
+func (o *arg) check(argument string) (int, error) { + rez := o.checkLongName(argument) + if rez > 0 { + return rez, nil + } + + return o.checkShortName(argument) +} + +func (o *arg) reduceLongName(position int, args *[]string) { + argument := (*args)[position] + // Check for long name only if not empty + if o.lname != "" { + // If argument begins with "--" and next is not "-" then it is a long name + if len(argument) > 2 && strings.HasPrefix(argument, "--") && argument[2] != '-' { + if o.eqChar { + splitInd := strings.LastIndex(argument, "=") + equalArg := []string{argument[:splitInd], argument[splitInd+1:]} + argument = equalArg[0] + } + if argument[2:] == o.lname { + for i := position; i < position+o.size; i++ { + (*args)[i] = "" + } + } + } + } +} + +func (o *arg) reduceShortName(position int, args *[]string) { + argument := (*args)[position] + // Check for short name only if not empty + if o.sname != "" { + // If argument begins with "-" and next is not "-" then it is a short name + if len(argument) > 1 && strings.HasPrefix(argument, "-") && argument[1] != '-' { + // For args with size 1 (Flag,FlagCounter) we allow multiple shorthand in one + if o.size == 1 { + if strings.Contains(argument[1:], o.sname) { + (*args)[position] = strings.Replace(argument, o.sname, "", -1) + if (*args)[position] == "-" { + (*args)[position] = "" + } + if o.eqChar { + (*args)[position] = "" + } + } + // For all other types it must be separate argument + } else { + if argument[1:] == o.sname { + for i := position; i < position+o.size; i++ { + (*args)[i] = "" + } + } + } + } + } +} + +// clear out already used argument from args at position +func (o *arg) reduce(position int, args *[]string) { + o.reduceLongName(position, args) + o.reduceShortName(position, args) +} + +func (o *arg) parseInt(args []string, argCount int) error { + //data of integer type is for + switch { + //FlagCounter argument + case len(args) < 1: + if o.size > 1 { + return fmt.Errorf("[%s] must be followed by an integer", o.name()) + } + *o.result.(*int) += argCount + case len(args) > 1: + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + //or Int argument with one integer parameter + default: + val, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("[%s] bad integer value [%s]", o.name(), args[0]) + } + *o.result.(*int) = val + } + o.parsed = true + return nil +} + +func (o *arg) parseBool(args []string) error { + //data of bool type is for Flag argument + *o.result.(*bool) = true + o.parsed = true + return nil +} + +func (o *arg) parseFloat(args []string) error { + //data of float64 type is for Float argument with one float parameter + if len(args) < 1 { + return fmt.Errorf("[%s] must be followed by a floating point number", o.name()) + } + if len(args) > 1 { + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + + val, err := strconv.ParseFloat(args[0], 64) + if err != nil { + return fmt.Errorf("[%s] bad floating point value [%s]", o.name(), args[0]) + } + + *o.result.(*float64) = val + o.parsed = true + return nil +} + +func (o *arg) parseString(args []string) error { + //data of string type is for String argument with one string parameter + if len(args) < 1 { + return fmt.Errorf("[%s] must be followed by a string", o.name()) + } + if len(args) > 1 { + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + + // Selector case + if o.selector != nil { + match := false + for _, v := range *o.selector { + if args[0] == v { + match = true + } + } + if !match { + 
return fmt.Errorf("bad value for [%s]. Allowed values are %v", o.name(), *o.selector) + } + } + + *o.result.(*string) = args[0] + o.parsed = true + return nil +} + +func (o *arg) parseFile(args []string) error { + //data of os.File type is for File argument with one file name parameter + if len(args) < 1 { + return fmt.Errorf("[%s] must be followed by a path to file", o.name()) + } + if len(args) > 1 { + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + + f, err := os.OpenFile(args[0], o.fileFlag, o.filePerm) + if err != nil { + return err + } + + *o.result.(*os.File) = *f + o.parsed = true + return nil +} + +func (o *arg) parseStringList(args []string) error { + //data of []string type is for List and StringList argument with set of string parameters + if len(args) < 1 { + return fmt.Errorf("[%s] must be followed by a string", o.name()) + } + if len(args) > 1 { + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + + *o.result.(*[]string) = append(*o.result.(*[]string), args[0]) + o.parsed = true + return nil +} + +func (o *arg) parseIntList(args []string) error { + //data of []int type is for IntList argument with set of int parameters + switch { + case len(args) < 1: + return fmt.Errorf("[%s] must be followed by an integer", o.name()) + case len(args) > 1: + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + + val, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("[%s] bad integer value [%s]", o.name(), args[0]) + } + *o.result.(*[]int) = append(*o.result.(*[]int), val) + o.parsed = true + return nil +} + +func (o *arg) parseFloatList(args []string) error { + //data of []float64 type is for FloatList argument with set of int parameters + switch { + case len(args) < 1: + return fmt.Errorf("[%s] must be followed by a floating point number", o.name()) + case len(args) > 1: + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + + val, err := strconv.ParseFloat(args[0], 64) + if err != nil { + return fmt.Errorf("[%s] bad floating point value [%s]", o.name(), args[0]) + } + *o.result.(*[]float64) = append(*o.result.(*[]float64), val) + o.parsed = true + return nil +} + +func (o *arg) parseFileList(args []string) error { + //data of []os.File type is for FileList argument with set of int parameters + switch { + case len(args) < 1: + return fmt.Errorf("[%s] must be followed by a path to file", o.name()) + case len(args) > 1: + return fmt.Errorf("[%s] followed by too many arguments", o.name()) + } + f, err := os.OpenFile(args[0], o.fileFlag, o.filePerm) + if err != nil { + //if one of FileList's file opening have been failed, close all other in this list + errs := make([]string, 0, len(*o.result.(*[]os.File))) + for _, f := range *o.result.(*[]os.File) { + if err := f.Close(); err != nil { + //almost unreal, but what if another process closed this file + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + err = fmt.Errorf("while handling error: %v, other errors occured: %#v", err.Error(), errs) + } + *o.result.(*[]os.File) = []os.File{} + return err + } + *o.result.(*[]os.File) = append(*o.result.(*[]os.File), *f) + o.parsed = true + return nil +} + +// To overwrite while testing +// Possibly extend to allow user overriding +var exit func(int) = os.Exit +var print func(...interface{}) (int, error) = fmt.Println + +func (o *arg) parseSomeType(args []string, argCount int) error { + var err error + switch o.result.(type) { + case *help: + print(o.parent.Help(nil)) + if 
o.parent.exitOnHelp { + exit(0) + } + //data of bool type is for Flag argument + case *bool: + err = o.parseBool(args) + case *int: + err = o.parseInt(args, argCount) + case *float64: + err = o.parseFloat(args) + case *string: + err = o.parseString(args) + case *os.File: + err = o.parseFile(args) + case *[]string: + err = o.parseStringList(args) + case *[]int: + err = o.parseIntList(args) + case *[]float64: + err = o.parseFloatList(args) + case *[]os.File: + err = o.parseFileList(args) + default: + err = fmt.Errorf("unsupported type [%t]", o.result) + } + return err +} + +func (o *arg) parse(args []string, argCount int) error { + // If unique do not allow more than one time + if o.unique && (o.parsed || argCount > 1) { + return fmt.Errorf("[%s] can only be present once", o.name()) + } + + // If validation function provided -- execute, on error return it immediately + if o.opts != nil && o.opts.Validate != nil { + err := o.opts.Validate(args) + if err != nil { + return err + } + } + return o.parseSomeType(args, argCount) +} + +func (o *arg) name() string { + var name string + if o.lname == "" { + name = "-" + o.sname + } else if o.sname == "" { + name = "--" + o.lname + } else { + name = "-" + o.sname + "|" + "--" + o.lname + } + return name +} + +func (o *arg) usage() string { + var result string + result = o.name() + switch o.result.(type) { + case *bool: + break + case *int: + result = result + " " + case *float64: + result = result + " " + case *string: + if o.selector != nil { + result = result + " (" + strings.Join(*o.selector, "|") + ")" + } else { + result = result + " \"\"" + } + case *os.File: + result = result + " " + case *[]string: + result = result + " \"\"" + " [" + result + " \"\" ...]" + default: + break + } + if o.opts == nil || o.opts.Required == false { + result = "[" + result + "]" + } + return result +} + +func (o *arg) getHelpMessage() string { + message := "" + if len(o.opts.Help) > 0 { + message += o.opts.Help + if !o.opts.Required && o.opts.Default != nil { + message += fmt.Sprintf(". 
Default: %v", o.opts.Default) + } + } + return message +} + +// setDefaultFile - gets default os.File object based on provided default filename string +func (o *arg) setDefaultFile() error { + // In case of File we should get string as default value + if v, ok := o.opts.Default.(string); ok { + f, err := os.OpenFile(v, o.fileFlag, o.filePerm) + if err != nil { + return err + } + *o.result.(*os.File) = *f + } else { + return fmt.Errorf("cannot use default type [%T] as value of pointer with type [*string]", o.opts.Default) + } + return nil +} + +// setDefaultFiles - gets list of default os.File objects based on provided list of default filenames strings +func (o *arg) setDefaultFiles() error { + // In case of FileList we should get []string as default value + var files []os.File + if fileNames, ok := o.opts.Default.([]string); ok { + files = make([]os.File, 0, len(fileNames)) + for _, v := range fileNames { + f, err := os.OpenFile(v, o.fileFlag, o.filePerm) + if err != nil { + //if one of FileList's file opening have been failed, close all other in this list + errs := make([]string, 0, len(*o.result.(*[]os.File))) + for _, f := range *o.result.(*[]os.File) { + if err := f.Close(); err != nil { + //almost unreal, but what if another process closed this file + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + err = fmt.Errorf("while handling error: %v, other errors occured: %#v", err.Error(), errs) + } + *o.result.(*[]os.File) = []os.File{} + return err + } + files = append(files, *f) + } + } else { + return fmt.Errorf("cannot use default type [%T] as value of pointer with type [*[]string]", o.opts.Default) + } + *o.result.(*[]os.File) = files + return nil +} + +// setDefault - if no value getted for specific argument, set default value, if provided +func (o *arg) setDefault() error { + // Only set default if it was not parsed, and default value was defined + if !o.parsed && o.opts != nil && o.opts.Default != nil { + switch o.result.(type) { + case *bool, *int, *float64, *string, *[]bool, *[]int, *[]float64, *[]string: + if reflect.TypeOf(o.result) != reflect.PtrTo(reflect.TypeOf(o.opts.Default)) { + return fmt.Errorf("cannot use default type [%T] as value of pointer with type [%T]", o.opts.Default, o.result) + } + reflect.ValueOf(o.result).Elem().Set(reflect.ValueOf(o.opts.Default)) + + case *os.File: + if err := o.setDefaultFile(); err != nil { + return err + } + case *[]os.File: + if err := o.setDefaultFiles(); err != nil { + return err + } + } + } + + return nil +} diff --git a/vendor/github.com/akamensky/argparse/command.go b/vendor/github.com/akamensky/argparse/command.go new file mode 100644 index 0000000..7bccf3b --- /dev/null +++ b/vendor/github.com/akamensky/argparse/command.go @@ -0,0 +1,179 @@ +package argparse + +import ( + "fmt" + "strings" +) + +func (o *Command) help(sname, lname string) { + result := &help{} + + if lname == "" { + sname, lname = "h", "help" + } + + a := &arg{ + result: result, + sname: sname, + lname: lname, + size: 1, + opts: &Options{Help: "Print help information"}, + unique: true, + } + + o.addArg(a) +} + +func (o *Command) addArg(a *arg) error { + // long name should be provided + if a.lname == "" { + return fmt.Errorf("long name should be provided") + } + // short name could be provided and must not exceed 1 character + if len(a.sname) > 1 { + return fmt.Errorf("short name must not exceed 1 character") + } + // Search parents for overlapping commands and fail if any + current := o + for current != nil { + if current.args != nil { + for _, v 
:= range current.args { + if a.lname != "help" || a.sname != "h" { + if a.sname != "" && a.sname == v.sname { + return fmt.Errorf("short name %s occurs more than once", a.sname) + } + if a.lname == v.lname { + return fmt.Errorf("long name %s occurs more than once", a.lname) + } + } + } + } + current = current.parent + } + a.parent = o + o.args = append(o.args, a) + return nil +} + +//parseSubCommands - Parses subcommands if any +func (o *Command) parseSubCommands(args *[]string) error { + if o.commands != nil && len(o.commands) > 0 { + // If we have subcommands and 0 args left + // that is an error of SubCommandError type + if len(*args) < 1 { + return newSubCommandError(o) + } + for _, v := range o.commands { + err := v.parse(args) + if err != nil { + return err + } + } + } + return nil +} + +//parseArguments - Parses arguments +func (o *Command) parseArguments(args *[]string) error { + // Iterate over the args + for i := 0; i < len(o.args); i++ { + oarg := o.args[i] + for j := 0; j < len(*args); j++ { + arg := (*args)[j] + if arg == "" { + continue + } + if strings.Contains(arg, "=") { + splitInd := strings.LastIndex(arg, "=") + equalArg := []string{arg[:splitInd], arg[splitInd+1:]} + if cnt, err := oarg.check(equalArg[0]); err != nil { + return err + } else if cnt > 0 { + if equalArg[1] == "" { + return fmt.Errorf("not enough arguments for %s", oarg.name()) + } + oarg.eqChar = true + oarg.size = 1 + currArg := []string{equalArg[1]} + err := oarg.parse(currArg, cnt) + if err != nil { + return err + } + oarg.reduce(j, args) + continue + } + } + if cnt, err := oarg.check(arg); err != nil { + return err + } else if cnt > 0 { + if len(*args) < j+oarg.size { + return fmt.Errorf("not enough arguments for %s", oarg.name()) + } + err := oarg.parse((*args)[j+1:j+oarg.size], cnt) + if err != nil { + return err + } + oarg.reduce(j, args) + continue + } + } + + // Check if arg is required and not provided + if oarg.opts != nil && oarg.opts.Required && !oarg.parsed { + return fmt.Errorf("[%s] is required", oarg.name()) + } + + // Check for argument default value and if provided try to type cast and assign + if oarg.opts != nil && oarg.opts.Default != nil && !oarg.parsed { + err := oarg.setDefault() + if err != nil { + return err + } + } + } + return nil +} + +// Will parse provided list of arguments +// common usage would be to pass directly os.Args +func (o *Command) parse(args *[]string) error { + // If we already been parsed do nothing + if o.parsed { + return nil + } + + // If no arguments left to parse do nothing + if len(*args) < 1 { + return nil + } + + // Parse only matching commands + // But we always have to parse top level + if o.name == "" { + o.name = (*args)[0] + } else { + if o.name != (*args)[0] && o.parent != nil { + return nil + } + } + + // Set happened status to true when command happend + o.happened = true + + // Reduce arguments by removing Command name + *args = (*args)[1:] + + // Parse subcommands if any + if err := o.parseSubCommands(args); err != nil { + return err + } + + // Parse arguments if any + if err := o.parseArguments(args); err != nil { + return err + } + + // Set parsed status to true and return quietly + o.parsed = true + return nil +} diff --git a/vendor/github.com/akamensky/argparse/errors.go b/vendor/github.com/akamensky/argparse/errors.go new file mode 100644 index 0000000..3f5d461 --- /dev/null +++ b/vendor/github.com/akamensky/argparse/errors.go @@ -0,0 +1,14 @@ +package argparse + +type subCommandError struct { + error + cmd *Command +} + +func (e 
subCommandError) Error() string { + return "[sub]Command required" +} + +func newSubCommandError(cmd *Command) error { + return subCommandError{cmd: cmd} +} diff --git a/vendor/github.com/akamensky/argparse/extras.go b/vendor/github.com/akamensky/argparse/extras.go new file mode 100644 index 0000000..fbc7d5a --- /dev/null +++ b/vendor/github.com/akamensky/argparse/extras.go @@ -0,0 +1,25 @@ +package argparse + +import "strings" + +func getLastLine(input string) string { + slice := strings.Split(input, "\n") + return slice[len(slice)-1] +} + +func addToLastLine(base string, add string, width int, padding int, canSplit bool) string { + // If last line has less than 10% space left, do not try to fill in by splitting else just try to split + hasTen := (width - len(getLastLine(base))) > width/10 + if len(getLastLine(base)+" "+add) >= width { + if hasTen && canSplit { + adds := strings.Split(add, " ") + for _, v := range adds { + base = addToLastLine(base, v, width, padding, false) + } + return base + } + base = base + "\n" + strings.Repeat(" ", padding) + } + base = base + " " + add + return base +} diff --git a/vendor/github.com/akamensky/argparse/go.mod b/vendor/github.com/akamensky/argparse/go.mod new file mode 100644 index 0000000..189632f --- /dev/null +++ b/vendor/github.com/akamensky/argparse/go.mod @@ -0,0 +1,3 @@ +module github.com/akamensky/argparse + +go 1.13 diff --git a/vendor/github.com/andres-erbsen/clock/.travis.yml b/vendor/github.com/andres-erbsen/clock/.travis.yml new file mode 100644 index 0000000..ca785e5 --- /dev/null +++ b/vendor/github.com/andres-erbsen/clock/.travis.yml @@ -0,0 +1,7 @@ +language: go +go: + - 1.3 + - 1.4 + - release + - tip +sudo: false diff --git a/vendor/github.com/andres-erbsen/clock/LICENSE b/vendor/github.com/andres-erbsen/clock/LICENSE new file mode 100644 index 0000000..ddf4e00 --- /dev/null +++ b/vendor/github.com/andres-erbsen/clock/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Ben Johnson, Copyright (c) 2015 Yahoo Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/github.com/andres-erbsen/clock/README.md b/vendor/github.com/andres-erbsen/clock/README.md new file mode 100644 index 0000000..f744e76 --- /dev/null +++ b/vendor/github.com/andres-erbsen/clock/README.md @@ -0,0 +1,104 @@ +clock [![Build Status](https://travis-ci.org/andres-erbsen/clock.svg)](https://travis-ci.org/andres-erbsen/clock) [![Coverage Status](https://coveralls.io/repos/andres-erbsen/clock/badge.png?branch=master)](https://coveralls.io/r/andres-erbsen/clock?branch=master) [![GoDoc](https://godoc.org/github.com/andres-erbsen/clock?status.png)](https://godoc.org/github.com/andres-erbsen/clock) ![Project status](http://img.shields.io/status/experimental.png?color=red) +===== + +Clock is a small library for mocking time in Go. It provides an interface +around the standard library's [`time`][time] package so that the application +can use the realtime clock while tests can use the mock clock. + +[time]: http://golang.org/pkg/time/ + + +## Usage + +### Realtime Clock + +Your application can maintain a `Clock` variable that will allow realtime and +mock clocks to be interchangable. For example, if you had an `Application` type: + +```go +import "github.com/andres-erbsen/clock" + +type Application struct { + Clock clock.Clock +} +``` + +You could initialize it to use the realtime clock like this: + +```go +var app Application +app.Clock = clock.New() +... +``` + +Then all timers and time-related functionality should be performed from the +`Clock` variable. + + +### Mocking time + +In your tests, you will want to use a `Mock` clock: + +```go +import ( + "testing" + + "github.com/andres-erbsen/clock" +) + +func TestApplication_DoSomething(t *testing.T) { + mock := clock.NewMock() + app := Application{Clock: mock} + ... +} +``` + +Now that you've initialized your application to use the mock clock, you can +adjust the time programmatically. The mock clock always starts from the Unix +epoch (midnight, Jan 1, 1970 UTC). + + +### Controlling time + +The mock clock provides the same functions that the standard library's `time` +package provides. For example, to find the current time, you use the `Now()` +function: + +```go +mock := clock.NewMock() + +// Find the current time. +mock.Now().UTC() // 1970-01-01 00:00:00 +0000 UTC + +// Move the clock forward. +mock.Add(2 * time.Hour) + +// Check the time again. It's 2 hours later! +mock.Now().UTC() // 1970-01-01 02:00:00 +0000 UTC +``` + +Timers and Tickers are also controlled by this same mock clock. They will only +execute when the clock is moved forward: + +``` +mock := clock.NewMock() +count := 0 + +// Kick off a timer to increment every 1 mock second. +go func() { + ticker := clock.Ticker(1 * time.Second) + for { + <-ticker.C + count++ + } +}() +runtime.Gosched() + +// Move the clock forward 10 second. +mock.Add(10 * time.Second) + +// This prints 10. +fmt.Println(count) +``` + + diff --git a/vendor/github.com/andres-erbsen/clock/clock.go b/vendor/github.com/andres-erbsen/clock/clock.go new file mode 100644 index 0000000..b58b703 --- /dev/null +++ b/vendor/github.com/andres-erbsen/clock/clock.go @@ -0,0 +1,317 @@ +package clock + +import ( + "sort" + "sync" + "time" +) + +// Clock represents an interface to the functions in the standard library time +// package. Two implementations are available in the clock package. The first +// is a real-time clock which simply wraps the time package's functions. The +// second is a mock clock which will only make forward progress when +// programmatically adjusted. 
+type Clock interface { + After(d time.Duration) <-chan time.Time + AfterFunc(d time.Duration, f func()) *Timer + Now() time.Time + Sleep(d time.Duration) + Tick(d time.Duration) <-chan time.Time + Ticker(d time.Duration) *Ticker + Timer(d time.Duration) *Timer +} + +// New returns an instance of a real-time clock. +func New() Clock { + return &clock{} +} + +// clock implements a real-time clock by simply wrapping the time package functions. +type clock struct{} + +func (c *clock) After(d time.Duration) <-chan time.Time { return time.After(d) } + +func (c *clock) AfterFunc(d time.Duration, f func()) *Timer { + return &Timer{timer: time.AfterFunc(d, f)} +} + +func (c *clock) Now() time.Time { return time.Now() } + +func (c *clock) Sleep(d time.Duration) { time.Sleep(d) } + +func (c *clock) Tick(d time.Duration) <-chan time.Time { return time.Tick(d) } + +func (c *clock) Ticker(d time.Duration) *Ticker { + t := time.NewTicker(d) + return &Ticker{C: t.C, ticker: t} +} + +func (c *clock) Timer(d time.Duration) *Timer { + t := time.NewTimer(d) + return &Timer{C: t.C, timer: t} +} + +// Mock represents a mock clock that only moves forward programmically. +// It can be preferable to a real-time clock when testing time-based functionality. +type Mock struct { + mu sync.Mutex + now time.Time // current time + timers clockTimers // tickers & timers +} + +// NewMock returns an instance of a mock clock. +// The current time of the mock clock on initialization is the Unix epoch. +func NewMock() *Mock { + return &Mock{now: time.Unix(0, 0)} +} + +// Add moves the current time of the mock clock forward by the duration. +// This should only be called from a single goroutine at a time. +func (m *Mock) Add(d time.Duration) { + // Calculate the final current time. + t := m.now.Add(d) + + // Continue to execute timers until there are no more before the new time. + for { + if !m.runNextTimer(t) { + break + } + } + + // Ensure that we end with the new time. + m.mu.Lock() + m.now = t + m.mu.Unlock() + + // Give a small buffer to make sure the other goroutines get handled. + gosched() +} + +// Sets the current time of the mock clock to a specific one. +// This should only be called from a single goroutine at a time. +func (m *Mock) Set(t time.Time) { + // Continue to execute timers until there are no more before the new time. + for { + if !m.runNextTimer(t) { + break + } + } + + // Ensure that we end with the new time. + m.mu.Lock() + m.now = t + m.mu.Unlock() + + // Give a small buffer to make sure the other goroutines get handled. + gosched() +} + +// runNextTimer executes the next timer in chronological order and moves the +// current time to the timer's next tick time. The next time is not executed if +// it's next time if after the max time. Returns true if a timer is executed. +func (m *Mock) runNextTimer(max time.Time) bool { + m.mu.Lock() + + // Sort timers by time. + sort.Sort(m.timers) + + // If we have no more timers then exit. + if len(m.timers) == 0 { + m.mu.Unlock() + return false + } + + // Retrieve next timer. Exit if next tick is after new time. + t := m.timers[0] + if t.Next().After(max) { + m.mu.Unlock() + return false + } + + // Move "now" forward and unlock clock. + m.now = t.Next() + m.mu.Unlock() + + // Execute timer. + t.Tick(m.now) + return true +} + +// After waits for the duration to elapse and then sends the current time on the returned channel. 
+func (m *Mock) After(d time.Duration) <-chan time.Time { + return m.Timer(d).C +} + +// AfterFunc waits for the duration to elapse and then executes a function. +// A Timer is returned that can be stopped. +func (m *Mock) AfterFunc(d time.Duration, f func()) *Timer { + t := m.Timer(d) + t.C = nil + t.fn = f + return t +} + +// Now returns the current wall time on the mock clock. +func (m *Mock) Now() time.Time { + m.mu.Lock() + defer m.mu.Unlock() + return m.now +} + +// Sleep pauses the goroutine for the given duration on the mock clock. +// The clock must be moved forward in a separate goroutine. +func (m *Mock) Sleep(d time.Duration) { + <-m.After(d) +} + +// Tick is a convenience function for Ticker(). +// It will return a ticker channel that cannot be stopped. +func (m *Mock) Tick(d time.Duration) <-chan time.Time { + return m.Ticker(d).C +} + +// Ticker creates a new instance of Ticker. +func (m *Mock) Ticker(d time.Duration) *Ticker { + m.mu.Lock() + defer m.mu.Unlock() + ch := make(chan time.Time, 1) + t := &Ticker{ + C: ch, + c: ch, + mock: m, + d: d, + next: m.now.Add(d), + } + m.timers = append(m.timers, (*internalTicker)(t)) + return t +} + +// Timer creates a new instance of Timer. +func (m *Mock) Timer(d time.Duration) *Timer { + ch := make(chan time.Time, 1) + t := &Timer{ + C: ch, + c: ch, + mock: m, + next: m.Now().Add(d), + } + m.addTimer((*internalTimer)(t)) + return t +} + +func (m *Mock) addTimer(t *internalTimer) { + m.mu.Lock() + defer m.mu.Unlock() + m.timers = append(m.timers, t) +} + +func (m *Mock) removeClockTimer(t clockTimer) bool { + m.mu.Lock() + defer m.mu.Unlock() + ret := false + for i, timer := range m.timers { + if timer == t { + ret = true + copy(m.timers[i:], m.timers[i+1:]) + m.timers[len(m.timers)-1] = nil + m.timers = m.timers[:len(m.timers)-1] + break + } + } + sort.Sort(m.timers) + return ret +} + +// clockTimer represents an object with an associated start time. +type clockTimer interface { + Next() time.Time + Tick(time.Time) +} + +// clockTimers represents a list of sortable timers. +type clockTimers []clockTimer + +func (a clockTimers) Len() int { return len(a) } +func (a clockTimers) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a clockTimers) Less(i, j int) bool { return a[i].Next().Before(a[j].Next()) } + +// Timer represents a single event. +// The current time will be sent on C, unless the timer was created by AfterFunc. +type Timer struct { + C <-chan time.Time + c chan time.Time + timer *time.Timer // realtime impl, if set + next time.Time // next tick time + mock *Mock // mock clock, if set + fn func() // AfterFunc function, if set +} + +// Stop turns off the timer. +func (t *Timer) Stop() bool { + if t.timer != nil { + return t.timer.Stop() + } + return t.mock.removeClockTimer((*internalTimer)(t)) +} + +// Reset changes the timer to expire after duration d. It returns true if the +// timer had been active, false if the timer had expired or been stopped. 
+func (t *Timer) Reset(d time.Duration) bool { + if t.timer != nil { + return t.timer.Reset(d) + } + ret := t.mock.removeClockTimer((*internalTimer)(t)) + t.next = t.mock.Now().Add(d) + t.mock.addTimer((*internalTimer)(t)) + return ret +} + +type internalTimer Timer + +func (t *internalTimer) Next() time.Time { return t.next } +func (t *internalTimer) Tick(now time.Time) { + if t.fn != nil { + t.fn() + } else { + select { + case t.c <- now: + default: + } + } + t.mock.removeClockTimer((*internalTimer)(t)) + gosched() +} + +// Ticker holds a channel that receives "ticks" at regular intervals. +type Ticker struct { + C <-chan time.Time + c chan time.Time + ticker *time.Ticker // realtime impl, if set + next time.Time // next tick time + mock *Mock // mock clock, if set + d time.Duration // time between ticks +} + +// Stop turns off the ticker. +func (t *Ticker) Stop() { + if t.ticker != nil { + t.ticker.Stop() + } else { + t.mock.removeClockTimer((*internalTicker)(t)) + } +} + +type internalTicker Ticker + +func (t *internalTicker) Next() time.Time { return t.next } +func (t *internalTicker) Tick(now time.Time) { + select { + case t.c <- now: + default: + } + t.next = now.Add(t.d) + gosched() +} + +// Sleep momentarily so that other goroutines can process. +func gosched() { time.Sleep(1 * time.Millisecond) } diff --git a/vendor/github.com/avast/retry-go/.gitignore b/vendor/github.com/avast/retry-go/.gitignore new file mode 100644 index 0000000..c40eb23 --- /dev/null +++ b/vendor/github.com/avast/retry-go/.gitignore @@ -0,0 +1,21 @@ +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ + +# dep +vendor/ +Gopkg.lock + +# cover +coverage.txt diff --git a/vendor/github.com/avast/retry-go/.godocdown.tmpl b/vendor/github.com/avast/retry-go/.godocdown.tmpl new file mode 100644 index 0000000..6873edf --- /dev/null +++ b/vendor/github.com/avast/retry-go/.godocdown.tmpl @@ -0,0 +1,37 @@ +# {{ .Name }} + +[![Release](https://img.shields.io/github/release/avast/retry-go.svg?style=flat-square)](https://github.com/avast/retry-go/releases/latest) +[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg?style=flat-square)](LICENSE.md) +[![Travis](https://img.shields.io/travis/avast/retry-go.svg?style=flat-square)](https://travis-ci.org/avast/retry-go) +[![AppVeyor](https://ci.appveyor.com/api/projects/status/fieg9gon3qlq0a9a?svg=true)](https://ci.appveyor.com/project/JaSei/retry-go) +[![Go Report Card](https://goreportcard.com/badge/github.com/avast/retry-go?style=flat-square)](https://goreportcard.com/report/github.com/avast/retry-go) +[![GoDoc](https://godoc.org/github.com/avast/retry-go?status.svg&style=flat-square)](http://godoc.org/github.com/avast/retry-go) +[![codecov.io](https://codecov.io/github/avast/retry-go/coverage.svg?branch=master)](https://codecov.io/github/avast/retry-go?branch=master) +[![Sourcegraph](https://sourcegraph.com/github.com/avast/retry-go/-/badge.svg)](https://sourcegraph.com/github.com/avast/retry-go?badge) + +{{ .EmitSynopsis }} + +{{ .EmitUsage }} + +## Contributing + +Contributions are very much welcome. + +### Makefile + +Makefile provides several handy rules, like README.md `generator` , `setup` for prepare build/dev environment, `test`, `cover`, etc... + +Try `make help` for more information. 
+ +### Before pull request + +please try: +* run tests (`make test`) +* run linter (`make lint`) +* if your IDE don't automaticaly do `go fmt`, run `go fmt` (`make fmt`) + +### README + +README.md are generate from template [.godocdown.tmpl](.godocdown.tmpl) and code documentation via [godocdown](https://github.com/robertkrimen/godocdown). + +Never edit README.md direct, because your change will be lost. diff --git a/vendor/github.com/avast/retry-go/.travis.yml b/vendor/github.com/avast/retry-go/.travis.yml new file mode 100644 index 0000000..ae3e0b6 --- /dev/null +++ b/vendor/github.com/avast/retry-go/.travis.yml @@ -0,0 +1,20 @@ +language: go + +go: + - 1.8 + - 1.9 + - "1.10" + - 1.11 + - 1.12 + - 1.13 + - 1.14 + - 1.15 + +install: + - make setup + +script: + - make ci + +after_success: + - bash <(curl -s https://codecov.io/bash) diff --git a/vendor/github.com/avast/retry-go/Gopkg.toml b/vendor/github.com/avast/retry-go/Gopkg.toml new file mode 100644 index 0000000..cf8c9eb --- /dev/null +++ b/vendor/github.com/avast/retry-go/Gopkg.toml @@ -0,0 +1,3 @@ +[[constraint]] + name = "github.com/stretchr/testify" + version = "1.1.4" diff --git a/vendor/github.com/avast/retry-go/LICENSE b/vendor/github.com/avast/retry-go/LICENSE new file mode 100644 index 0000000..f63fca8 --- /dev/null +++ b/vendor/github.com/avast/retry-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Avast + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/avast/retry-go/Makefile b/vendor/github.com/avast/retry-go/Makefile new file mode 100644 index 0000000..769816d --- /dev/null +++ b/vendor/github.com/avast/retry-go/Makefile @@ -0,0 +1,65 @@ +SOURCE_FILES?=$$(go list ./... | grep -v /vendor/) +TEST_PATTERN?=. 
+TEST_OPTIONS?= +DEP?=$$(which dep) +VERSION?=$$(cat VERSION) +LINTER?=$$(which golangci-lint) +LINTER_VERSION=1.15.0 + +ifeq ($(OS),Windows_NT) + DEP_VERS=dep-windows-amd64 + LINTER_FILE=golangci-lint-$(LINTER_VERSION)-windows-amd64.zip + LINTER_UNPACK= >| app.zip; unzip -j app.zip -d $$GOPATH/bin; rm app.zip +else ifeq ($(OS), Darwin) + LINTER_FILE=golangci-lint-$(LINTER_VERSION)-darwin-amd64.tar.gz + LINTER_UNPACK= | tar xzf - -C $$GOPATH/bin --wildcards --strip 1 "**/golangci-lint" +else + DEP_VERS=dep-linux-amd64 + LINTER_FILE=golangci-lint-$(LINTER_VERSION)-linux-amd64.tar.gz + LINTER_UNPACK= | tar xzf - -C $$GOPATH/bin --wildcards --strip 1 "**/golangci-lint" +endif + +setup: + go get -u github.com/pierrre/gotestcover + go get -u golang.org/x/tools/cmd/cover + go get -u github.com/robertkrimen/godocdown/godocdown + @if [ "$(LINTER)" = "" ]; then\ + curl -L https://github.com/golangci/golangci-lint/releases/download/v$(LINTER_VERSION)/$(LINTER_FILE) $(LINTER_UNPACK) ;\ + chmod +x $$GOPATH/bin/golangci-lint;\ + fi + @if [ "$(DEP)" = "" ]; then\ + curl -L https://github.com/golang/dep/releases/download/v0.3.1/$(DEP_VERS) >| $$GOPATH/bin/dep;\ + chmod +x $$GOPATH/bin/dep;\ + fi + dep ensure + +generate: ## Generate README.md + godocdown >| README.md + +test: generate test_and_cover_report lint + +test_and_cover_report: + gotestcover $(TEST_OPTIONS) -covermode=atomic -coverprofile=coverage.txt $(SOURCE_FILES) -run $(TEST_PATTERN) -timeout=2m + +cover: test ## Run all the tests and opens the coverage report + go tool cover -html=coverage.txt + +fmt: ## gofmt and goimports all go files + find . -name '*.go' -not -wholename './vendor/*' | while read -r file; do gofmt -w -s "$$file"; goimports -w "$$file"; done + +lint: ## Run all the linters + golangci-lint run + +ci: test_and_cover_report ## Run all the tests but no linters - use https://golangci.com integration instead + +build: + go build + +release: ## Release new version + git tag | grep -q $(VERSION) && echo This version was released! Increase VERSION! 
|| git tag $(VERSION) && git push origin $(VERSION) && git tag v$(VERSION) && git push origin v$(VERSION) + +# Absolutely awesome: http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html +help: + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.DEFAULT_GOAL := build diff --git a/vendor/github.com/avast/retry-go/README.md b/vendor/github.com/avast/retry-go/README.md new file mode 100644 index 0000000..80fb73b --- /dev/null +++ b/vendor/github.com/avast/retry-go/README.md @@ -0,0 +1,361 @@ +# retry + +[![Release](https://img.shields.io/github/release/avast/retry-go.svg?style=flat-square)](https://github.com/avast/retry-go/releases/latest) +[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg?style=flat-square)](LICENSE.md) +[![Travis](https://img.shields.io/travis/avast/retry-go.svg?style=flat-square)](https://travis-ci.org/avast/retry-go) +[![AppVeyor](https://ci.appveyor.com/api/projects/status/fieg9gon3qlq0a9a?svg=true)](https://ci.appveyor.com/project/JaSei/retry-go) +[![Go Report Card](https://goreportcard.com/badge/github.com/avast/retry-go?style=flat-square)](https://goreportcard.com/report/github.com/avast/retry-go) +[![GoDoc](https://godoc.org/github.com/avast/retry-go?status.svg&style=flat-square)](http://godoc.org/github.com/avast/retry-go) +[![codecov.io](https://codecov.io/github/avast/retry-go/coverage.svg?branch=master)](https://codecov.io/github/avast/retry-go?branch=master) +[![Sourcegraph](https://sourcegraph.com/github.com/avast/retry-go/-/badge.svg)](https://sourcegraph.com/github.com/avast/retry-go?badge) + +Simple library for retry mechanism + +slightly inspired by +[Try::Tiny::Retry](https://metacpan.org/pod/Try::Tiny::Retry) + + +### SYNOPSIS + +http get with retry: + + url := "http://example.com" + var body []byte + + err := retry.Do( + func() error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + + return nil + }, + ) + + fmt.Println(body) + +[next examples](https://github.com/avast/retry-go/tree/master/examples) + + +### SEE ALSO + +* [giantswarm/retry-go](https://github.com/giantswarm/retry-go) - slightly +complicated interface. + +* [sethgrid/pester](https://github.com/sethgrid/pester) - only http retry for +http calls with retries and backoff + +* [cenkalti/backoff](https://github.com/cenkalti/backoff) - Go port of the +exponential backoff algorithm from Google's HTTP Client Library for Java. Really +complicated interface. + +* [rafaeljesus/retry-go](https://github.com/rafaeljesus/retry-go) - looks good, +slightly similar as this package, don't have 'simple' `Retry` method + +* [matryer/try](https://github.com/matryer/try) - very popular package, +nonintuitive interface (for me) + + +### BREAKING CHANGES + +3.0.0 + +* `DelayTypeFunc` accepts a new parameter `err` - this breaking change affects +only your custom Delay Functions. This change allow [make delay functions based +on error](examples/delay_based_on_error_test.go). 
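+
+A minimal sketch of such an error-aware delay function (the `callFlakyService` helper and the "rate limited"
+error text are hypothetical placeholders, and the usual `time` import is assumed):
+
+    retry.Do(
+        func() error {
+            return callFlakyService() // hypothetical operation being retried
+        },
+        retry.DelayType(func(n uint, err error, config *retry.Config) time.Duration {
+            if err.Error() == "rate limited" {
+                // wait noticeably longer when the service reports rate limiting
+                return 5 * time.Second
+            }
+            // otherwise keep the default exponential backoff behaviour
+            return retry.BackOffDelay(n, err, config)
+        }),
+    )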
+ +1.0.2 -> 2.0.0 + +* argument of `retry.Delay` is final delay (no multiplication by `retry.Units` +anymore) + +* function `retry.Units` are removed + +* [more about this breaking change](https://github.com/avast/retry-go/issues/7) + +0.3.0 -> 1.0.0 + +* `retry.Retry` function are changed to `retry.Do` function + +* `retry.RetryCustom` (OnRetry) and `retry.RetryCustomWithOpts` functions are +now implement via functions produces Options (aka `retry.OnRetry`) + +## Usage + +```go +var ( + DefaultAttempts = uint(10) + DefaultDelay = 100 * time.Millisecond + DefaultMaxJitter = 100 * time.Millisecond + DefaultOnRetry = func(n uint, err error) {} + DefaultRetryIf = IsRecoverable + DefaultDelayType = CombineDelay(BackOffDelay, RandomDelay) + DefaultLastErrorOnly = false + DefaultContext = context.Background() +) +``` + +#### func BackOffDelay + +```go +func BackOffDelay(n uint, _ error, config *Config) time.Duration +``` +BackOffDelay is a DelayType which increases delay between consecutive retries + +#### func Do + +```go +func Do(retryableFunc RetryableFunc, opts ...Option) error +``` + +#### func FixedDelay + +```go +func FixedDelay(_ uint, _ error, config *Config) time.Duration +``` +FixedDelay is a DelayType which keeps delay the same through all iterations + +#### func IsRecoverable + +```go +func IsRecoverable(err error) bool +``` +IsRecoverable checks if error is an instance of `unrecoverableError` + +#### func RandomDelay + +```go +func RandomDelay(_ uint, _ error, config *Config) time.Duration +``` +RandomDelay is a DelayType which picks a random delay up to config.maxJitter + +#### func Unrecoverable + +```go +func Unrecoverable(err error) error +``` +Unrecoverable wraps an error in `unrecoverableError` struct + +#### type Config + +```go +type Config struct { +} +``` + + +#### type DelayTypeFunc + +```go +type DelayTypeFunc func(n uint, err error, config *Config) time.Duration +``` + +DelayTypeFunc is called to return the next delay to wait after the retriable +function fails on `err` after `n` attempts. + +#### func CombineDelay + +```go +func CombineDelay(delays ...DelayTypeFunc) DelayTypeFunc +``` +CombineDelay is a DelayType the combines all of the specified delays into a new +DelayTypeFunc + +#### type Error + +```go +type Error []error +``` + +Error type represents list of errors in retry + +#### func (Error) Error + +```go +func (e Error) Error() string +``` +Error method return string representation of Error It is an implementation of +error interface + +#### func (Error) WrappedErrors + +```go +func (e Error) WrappedErrors() []error +``` +WrappedErrors returns the list of errors that this Error is wrapping. It is an +implementation of the `errwrap.Wrapper` interface in package +[errwrap](https://github.com/hashicorp/errwrap) so that `retry.Error` can be +used with that library. + +#### type OnRetryFunc + +```go +type OnRetryFunc func(n uint, err error) +``` + +Function signature of OnRetry function n = count of attempts + +#### type Option + +```go +type Option func(*Config) +``` + +Option represents an option for retry. 
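+
+Since every option is just a function over `*Config`, any number of them can be combined in a single `Do`
+call. A minimal sketch using several of the constructors documented below (`doSomething` is a hypothetical
+stand-in, and the usual `log` and `time` imports are assumed):
+
+    retry.Do(
+        func() error {
+            return doSomething() // hypothetical operation to retry
+        },
+        retry.Attempts(5),
+        retry.Delay(200*time.Millisecond),
+        retry.MaxDelay(2*time.Second),
+        retry.OnRetry(func(n uint, err error) {
+            log.Printf("#%d: %s\n", n, err)
+        }),
+    )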
+ +#### func Attempts + +```go +func Attempts(attempts uint) Option +``` +Attempts set count of retry default is 10 + +#### func Context + +```go +func Context(ctx context.Context) Option +``` +Context allow to set context of retry default are Background context + +example of immediately cancellation (maybe it isn't the best example, but it +describes behavior enough; I hope) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + retry.Do( + func() error { + ... + }, + retry.Context(ctx), + ) + +#### func Delay + +```go +func Delay(delay time.Duration) Option +``` +Delay set delay between retry default is 100ms + +#### func DelayType + +```go +func DelayType(delayType DelayTypeFunc) Option +``` +DelayType set type of the delay between retries default is BackOff + +#### func LastErrorOnly + +```go +func LastErrorOnly(lastErrorOnly bool) Option +``` +return the direct last error that came from the retried function default is +false (return wrapped errors with everything) + +#### func MaxDelay + +```go +func MaxDelay(maxDelay time.Duration) Option +``` +MaxDelay set maximum delay between retry does not apply by default + +#### func MaxJitter + +```go +func MaxJitter(maxJitter time.Duration) Option +``` +MaxJitter sets the maximum random Jitter between retries for RandomDelay + +#### func OnRetry + +```go +func OnRetry(onRetry OnRetryFunc) Option +``` +OnRetry function callback are called each retry + +log each retry example: + + retry.Do( + func() error { + return errors.New("some error") + }, + retry.OnRetry(func(n uint, err error) { + log.Printf("#%d: %s\n", n, err) + }), + ) + +#### func RetryIf + +```go +func RetryIf(retryIf RetryIfFunc) Option +``` +RetryIf controls whether a retry should be attempted after an error (assuming +there are any retry attempts remaining) + +skip retry if special error example: + + retry.Do( + func() error { + return errors.New("special error") + }, + retry.RetryIf(func(err error) bool { + if err.Error() == "special error" { + return false + } + return true + }) + ) + +By default RetryIf stops execution if the error is wrapped using +`retry.Unrecoverable`, so above example may also be shortened to: + + retry.Do( + func() error { + return retry.Unrecoverable(errors.New("special error")) + } + ) + +#### type RetryIfFunc + +```go +type RetryIfFunc func(error) bool +``` + +Function signature of retry if function + +#### type RetryableFunc + +```go +type RetryableFunc func() error +``` + +Function signature of retryable function + +## Contributing + +Contributions are very much welcome. + +### Makefile + +Makefile provides several handy rules, like README.md `generator` , `setup` for prepare build/dev environment, `test`, `cover`, etc... + +Try `make help` for more information. + +### Before pull request + +please try: +* run tests (`make test`) +* run linter (`make lint`) +* if your IDE don't automaticaly do `go fmt`, run `go fmt` (`make fmt`) + +### README + +README.md are generate from template [.godocdown.tmpl](.godocdown.tmpl) and code documentation via [godocdown](https://github.com/robertkrimen/godocdown). + +Never edit README.md direct, because your change will be lost. 
diff --git a/vendor/github.com/avast/retry-go/VERSION b/vendor/github.com/avast/retry-go/VERSION new file mode 100644 index 0000000..4a36342 --- /dev/null +++ b/vendor/github.com/avast/retry-go/VERSION @@ -0,0 +1 @@ +3.0.0 diff --git a/vendor/github.com/avast/retry-go/appveyor.yml b/vendor/github.com/avast/retry-go/appveyor.yml new file mode 100644 index 0000000..dc5234a --- /dev/null +++ b/vendor/github.com/avast/retry-go/appveyor.yml @@ -0,0 +1,19 @@ +version: "{build}" + +clone_folder: c:\Users\appveyor\go\src\github.com\avast\retry-go + +#os: Windows Server 2012 R2 +platform: x64 + +install: + - copy c:\MinGW\bin\mingw32-make.exe c:\MinGW\bin\make.exe + - set GOPATH=C:\Users\appveyor\go + - set PATH=%PATH%;c:\MinGW\bin + - set PATH=%PATH%;%GOPATH%\bin;c:\go\bin + - set GOBIN=%GOPATH%\bin + - go version + - go env + - make setup + +build_script: + - make ci diff --git a/vendor/github.com/avast/retry-go/options.go b/vendor/github.com/avast/retry-go/options.go new file mode 100644 index 0000000..a6c5720 --- /dev/null +++ b/vendor/github.com/avast/retry-go/options.go @@ -0,0 +1,198 @@ +package retry + +import ( + "context" + "math" + "math/rand" + "time" +) + +// Function signature of retry if function +type RetryIfFunc func(error) bool + +// Function signature of OnRetry function +// n = count of attempts +type OnRetryFunc func(n uint, err error) + +// DelayTypeFunc is called to return the next delay to wait after the retriable function fails on `err` after `n` attempts. +type DelayTypeFunc func(n uint, err error, config *Config) time.Duration + +type Config struct { + attempts uint + delay time.Duration + maxDelay time.Duration + maxJitter time.Duration + onRetry OnRetryFunc + retryIf RetryIfFunc + delayType DelayTypeFunc + lastErrorOnly bool + context context.Context + + maxBackOffN uint +} + +// Option represents an option for retry. +type Option func(*Config) + +// return the direct last error that came from the retried function +// default is false (return wrapped errors with everything) +func LastErrorOnly(lastErrorOnly bool) Option { + return func(c *Config) { + c.lastErrorOnly = lastErrorOnly + } +} + +// Attempts set count of retry +// default is 10 +func Attempts(attempts uint) Option { + return func(c *Config) { + c.attempts = attempts + } +} + +// Delay set delay between retry +// default is 100ms +func Delay(delay time.Duration) Option { + return func(c *Config) { + c.delay = delay + } +} + +// MaxDelay set maximum delay between retry +// does not apply by default +func MaxDelay(maxDelay time.Duration) Option { + return func(c *Config) { + c.maxDelay = maxDelay + } +} + +// MaxJitter sets the maximum random Jitter between retries for RandomDelay +func MaxJitter(maxJitter time.Duration) Option { + return func(c *Config) { + c.maxJitter = maxJitter + } +} + +// DelayType set type of the delay between retries +// default is BackOff +func DelayType(delayType DelayTypeFunc) Option { + return func(c *Config) { + c.delayType = delayType + } +} + +// BackOffDelay is a DelayType which increases delay between consecutive retries +func BackOffDelay(n uint, _ error, config *Config) time.Duration { + // 1 << 63 would overflow signed int64 (time.Duration), thus 62. 
+ const max uint = 62 + + if config.maxBackOffN == 0 { + if config.delay <= 0 { + config.delay = 1 + } + + config.maxBackOffN = max - uint(math.Floor(math.Log2(float64(config.delay)))) + } + + if n > config.maxBackOffN { + n = config.maxBackOffN + } + + return config.delay << n +} + +// FixedDelay is a DelayType which keeps delay the same through all iterations +func FixedDelay(_ uint, _ error, config *Config) time.Duration { + return config.delay +} + +// RandomDelay is a DelayType which picks a random delay up to config.maxJitter +func RandomDelay(_ uint, _ error, config *Config) time.Duration { + return time.Duration(rand.Int63n(int64(config.maxJitter))) +} + +// CombineDelay is a DelayType the combines all of the specified delays into a new DelayTypeFunc +func CombineDelay(delays ...DelayTypeFunc) DelayTypeFunc { + const maxInt64 = uint64(math.MaxInt64) + + return func(n uint, err error, config *Config) time.Duration { + var total uint64 + for _, delay := range delays { + total += uint64(delay(n, err, config)) + if total > maxInt64 { + total = maxInt64 + } + } + + return time.Duration(total) + } +} + +// OnRetry function callback are called each retry +// +// log each retry example: +// +// retry.Do( +// func() error { +// return errors.New("some error") +// }, +// retry.OnRetry(func(n uint, err error) { +// log.Printf("#%d: %s\n", n, err) +// }), +// ) +func OnRetry(onRetry OnRetryFunc) Option { + return func(c *Config) { + c.onRetry = onRetry + } +} + +// RetryIf controls whether a retry should be attempted after an error +// (assuming there are any retry attempts remaining) +// +// skip retry if special error example: +// +// retry.Do( +// func() error { +// return errors.New("special error") +// }, +// retry.RetryIf(func(err error) bool { +// if err.Error() == "special error" { +// return false +// } +// return true +// }) +// ) +// +// By default RetryIf stops execution if the error is wrapped using `retry.Unrecoverable`, +// so above example may also be shortened to: +// +// retry.Do( +// func() error { +// return retry.Unrecoverable(errors.New("special error")) +// } +// ) +func RetryIf(retryIf RetryIfFunc) Option { + return func(c *Config) { + c.retryIf = retryIf + } +} + +// Context allow to set context of retry +// default are Background context +// +// example of immediately cancellation (maybe it isn't the best example, but it describes behavior enough; I hope) +// +// ctx, cancel := context.WithCancel(context.Background()) +// cancel() +// +// retry.Do( +// func() error { +// ... +// }, +// retry.Context(ctx), +// ) +func Context(ctx context.Context) Option { + return func(c *Config) { + c.context = ctx + } +} diff --git a/vendor/github.com/avast/retry-go/retry.go b/vendor/github.com/avast/retry-go/retry.go new file mode 100644 index 0000000..af2d926 --- /dev/null +++ b/vendor/github.com/avast/retry-go/retry.go @@ -0,0 +1,225 @@ +/* +Simple library for retry mechanism + +slightly inspired by [Try::Tiny::Retry](https://metacpan.org/pod/Try::Tiny::Retry) + +SYNOPSIS + +http get with retry: + + url := "http://example.com" + var body []byte + + err := retry.Do( + func() error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + + return nil + }, + ) + + fmt.Println(body) + +[next examples](https://github.com/avast/retry-go/tree/master/examples) + + +SEE ALSO + +* [giantswarm/retry-go](https://github.com/giantswarm/retry-go) - slightly complicated interface. 
+ +* [sethgrid/pester](https://github.com/sethgrid/pester) - only http retry for http calls with retries and backoff + +* [cenkalti/backoff](https://github.com/cenkalti/backoff) - Go port of the exponential backoff algorithm from Google's HTTP Client Library for Java. Really complicated interface. + +* [rafaeljesus/retry-go](https://github.com/rafaeljesus/retry-go) - looks good, slightly similar as this package, don't have 'simple' `Retry` method + +* [matryer/try](https://github.com/matryer/try) - very popular package, nonintuitive interface (for me) + + +BREAKING CHANGES + +3.0.0 + +* `DelayTypeFunc` accepts a new parameter `err` - this breaking change affects only your custom Delay Functions. This change allow [make delay functions based on error](examples/delay_based_on_error_test.go). + + +1.0.2 -> 2.0.0 + +* argument of `retry.Delay` is final delay (no multiplication by `retry.Units` anymore) + +* function `retry.Units` are removed + +* [more about this breaking change](https://github.com/avast/retry-go/issues/7) + + +0.3.0 -> 1.0.0 + +* `retry.Retry` function are changed to `retry.Do` function + +* `retry.RetryCustom` (OnRetry) and `retry.RetryCustomWithOpts` functions are now implement via functions produces Options (aka `retry.OnRetry`) + + +*/ +package retry + +import ( + "context" + "fmt" + "strings" + "time" +) + +// Function signature of retryable function +type RetryableFunc func() error + +var ( + DefaultAttempts = uint(10) + DefaultDelay = 100 * time.Millisecond + DefaultMaxJitter = 100 * time.Millisecond + DefaultOnRetry = func(n uint, err error) {} + DefaultRetryIf = IsRecoverable + DefaultDelayType = CombineDelay(BackOffDelay, RandomDelay) + DefaultLastErrorOnly = false + DefaultContext = context.Background() +) + +func Do(retryableFunc RetryableFunc, opts ...Option) error { + var n uint + + //default + config := &Config{ + attempts: DefaultAttempts, + delay: DefaultDelay, + maxJitter: DefaultMaxJitter, + onRetry: DefaultOnRetry, + retryIf: DefaultRetryIf, + delayType: DefaultDelayType, + lastErrorOnly: DefaultLastErrorOnly, + context: DefaultContext, + } + + //apply opts + for _, opt := range opts { + opt(config) + } + + if err := config.context.Err(); err != nil { + return err + } + + var errorLog Error + if !config.lastErrorOnly { + errorLog = make(Error, config.attempts) + } else { + errorLog = make(Error, 1) + } + + lastErrIndex := n + for n < config.attempts { + err := retryableFunc() + + if err != nil { + errorLog[lastErrIndex] = unpackUnrecoverable(err) + + if !config.retryIf(err) { + break + } + + config.onRetry(n, err) + + // if this is last attempt - don't wait + if n == config.attempts-1 { + break + } + + delayTime := config.delayType(n, err, config) + if config.maxDelay > 0 && delayTime > config.maxDelay { + delayTime = config.maxDelay + } + + select { + case <-time.After(delayTime): + case <-config.context.Done(): + return config.context.Err() + } + + } else { + return nil + } + + n++ + if !config.lastErrorOnly { + lastErrIndex = n + } + } + + if config.lastErrorOnly { + return errorLog[lastErrIndex] + } + return errorLog +} + +// Error type represents list of errors in retry +type Error []error + +// Error method return string representation of Error +// It is an implementation of error interface +func (e Error) Error() string { + logWithNumber := make([]string, lenWithoutNil(e)) + for i, l := range e { + if l != nil { + logWithNumber[i] = fmt.Sprintf("#%d: %s", i+1, l.Error()) + } + } + + return fmt.Sprintf("All attempts fail:\n%s", 
strings.Join(logWithNumber, "\n")) +} + +func lenWithoutNil(e Error) (count int) { + for _, v := range e { + if v != nil { + count++ + } + } + + return +} + +// WrappedErrors returns the list of errors that this Error is wrapping. +// It is an implementation of the `errwrap.Wrapper` interface +// in package [errwrap](https://github.com/hashicorp/errwrap) so that +// `retry.Error` can be used with that library. +func (e Error) WrappedErrors() []error { + return e +} + +type unrecoverableError struct { + error +} + +// Unrecoverable wraps an error in `unrecoverableError` struct +func Unrecoverable(err error) error { + return unrecoverableError{err} +} + +// IsRecoverable checks if error is an instance of `unrecoverableError` +func IsRecoverable(err error) bool { + _, isUnrecoverable := err.(unrecoverableError) + return !isUnrecoverable +} + +func unpackUnrecoverable(err error) error { + if unrecoverable, isUnrecoverable := err.(unrecoverableError); isUnrecoverable { + return unrecoverable.error + } + + return err +} diff --git a/vendor/github.com/gorilla/websocket/.gitignore b/vendor/github.com/gorilla/websocket/.gitignore new file mode 100644 index 0000000..cd3fcd1 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/.gitignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +.idea/ +*.iml diff --git a/vendor/github.com/gorilla/websocket/AUTHORS b/vendor/github.com/gorilla/websocket/AUTHORS new file mode 100644 index 0000000..1931f40 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/AUTHORS @@ -0,0 +1,9 @@ +# This is the official list of Gorilla WebSocket authors for copyright +# purposes. +# +# Please keep the list sorted. + +Gary Burd +Google LLC (https://opensource.google.com/) +Joachim Bauch + diff --git a/vendor/github.com/gorilla/websocket/LICENSE b/vendor/github.com/gorilla/websocket/LICENSE new file mode 100644 index 0000000..9171c97 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/vendor/github.com/gorilla/websocket/README.md b/vendor/github.com/gorilla/websocket/README.md new file mode 100644 index 0000000..19aa2e7 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/README.md @@ -0,0 +1,64 @@ +# Gorilla WebSocket + +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) + +Gorilla WebSocket is a [Go](http://golang.org/) implementation of the +[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + +### Documentation + +* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) +* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) +* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) + +### Status + +The Gorilla WebSocket package provides a complete and tested implementation of +the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The +package API is stable. + +### Installation + + go get github.com/gorilla/websocket + +### Protocol Compliance + +The Gorilla WebSocket package passes the server tests in the [Autobahn Test +Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn +subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). + +### Gorilla WebSocket compared with other packages + + + + + + + + + + + + + + + + + + +
+|                                          | github.com/gorilla | golang.org/x/net |
+|------------------------------------------|--------------------|------------------|
+| RFC 6455 Features                        |                    |                  |
+| Passes Autobahn Test Suite               | Yes                | No               |
+| Receive fragmented message               | Yes                | No, see note 1   |
+| Send close message                       | Yes                | No               |
+| Send pings and receive pongs             | Yes                | No               |
+| Get the type of a received data message  | Yes                | Yes, see note 2  |
+| Other Features                           |                    |                  |
+| Compression Extensions                   | Experimental       | No               |
+| Read message using io.Reader             | Yes                | No, see note 3   |
+| Write message using io.WriteCloser       | Yes                | No, see note 3   |
+ +Notes: + +1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). +2. The application can get the type of a received data message by implementing + a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal) + function. +3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries. + Read returns when the input buffer is full or a frame boundary is + encountered. Each call to Write sends a single frame message. The Gorilla + io.Reader and io.WriteCloser operate on a single WebSocket message. + diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go new file mode 100644 index 0000000..962c06a --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client.go @@ -0,0 +1,395 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "strings" + "time" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + +// NewClient creates a new client connection using the given net connection. +// The URL u specifies the host and request URI. Use requestHeader to specify +// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies +// (Cookie). Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etc. +// +// Deprecated: Use Dialer instead. +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, + } + return d.Dial(u.String(), requestHeader) +} + +// A Dialer contains options for connecting to WebSocket server. +type Dialer struct { + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dial is used. + NetDial func(network, addr string) (net.Conn, error) + + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, net.DialContext is used. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // TLSClientConfig specifies the TLS configuration to use with tls.Client. + // If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. 
If a buffer + // size is zero, then a useful default size is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the client's requested subprotocols. + Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar http.CookieJar +} + +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(context.Background(), urlStr, requestHeader) +} + +var errMalformedURL = errors.New("malformed ws or wss URL") + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +// DefaultDialer is a dialer with all fields set to the default values. +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, +} + +// nilDialer is dialer to use when receiver is nil. +var nilDialer = *DefaultDialer + +// DialContext creates a new client connection. Use requestHeader to specify the +// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). +// Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// The context will be used in the request and in the Dialer. +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + if d == nil { + d = &nilDialer + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + + u, err := url.Parse(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. 
+ return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req = req.WithContext(ctx) + + // Set the cookies present in the cookie jar of the dialer + if d.Jar != nil { + for _, cookie := range d.Jar.Cookies(u) { + req.AddCookie(cookie) + } + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs + default: + req.Header[k] = vs + } + } + + if d.EnableCompression { + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} + } + + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } + + // Get network dial function. + var netDial func(network, add string) (net.Conn, error) + + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } else { + netDialer := &net.Dialer{} + netDial = func(network, addr string) (net.Conn, error) { + return netDialer.DialContext(ctx, network, addr) + } + } + + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + forwardDial := netDial + netDial = func(network, addr string) (net.Conn, error) { + c, err := forwardDial(network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } + } + + // If needed, wrap the dial function to connect through a proxy. 
+ if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + if err != nil { + return nil, nil, err + } + netDial = dialer.Dial + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + + netConn, err := netDial("tcp", hostPort) + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + if err != nil { + return nil, nil, err + } + + defer func() { + if netConn != nil { + netConn.Close() + } + }() + + if u.Scheme == "https" { + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + + var err error + if trace != nil { + err = doHandshakeWithTrace(trace, tlsConn, cfg) + } else { + err = doHandshake(tlsConn, cfg) + } + + if err != nil { + return nil, nil, err + } + } + + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + + resp, err := http.ReadResponse(conn.br, req) + if err != nil { + return nil, nil, err + } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } + + for _, ext := range parseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + + netConn.SetDeadline(time.Time{}) + netConn = nil // to avoid close in defer. + return conn, resp, nil +} + +func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/gorilla/websocket/client_clone.go b/vendor/github.com/gorilla/websocket/client_clone.go new file mode 100644 index 0000000..4f0d943 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client_clone.go @@ -0,0 +1,16 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build go1.8 + +package websocket + +import "crypto/tls" + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} diff --git a/vendor/github.com/gorilla/websocket/client_clone_legacy.go b/vendor/github.com/gorilla/websocket/client_clone_legacy.go new file mode 100644 index 0000000..babb007 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client_clone_legacy.go @@ -0,0 +1,38 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +import "crypto/tls" + +// cloneTLSConfig clones all public fields except the fields +// SessionTicketsDisabled and SessionTicketKey. This avoids copying the +// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a +// config in active use. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} diff --git a/vendor/github.com/gorilla/websocket/compression.go b/vendor/github.com/gorilla/websocket/compression.go new file mode 100644 index 0000000..813ffb1 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/compression.go @@ -0,0 +1,148 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" + "sync" +) + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} +} + +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. 
+type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go new file mode 100644 index 0000000..ca46d2f --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn.go @@ -0,0 +1,1201 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "sync" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. 
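+// The withheld bytes are the 0x00 0x00 0xff 0xff flush marker that flate
+// appends to each message; RFC 7692 requires that marker to be dropped before
+// the compressed frame is sent, and flateWriteWrapper.Close verifies that
+// these are in fact the bytes left behind.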
+ BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents a close message. +type CloseError struct { + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. 
+func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = &netError{msg: e.Error(), timeout: e.Timeout()} + } + return err +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + +// The Conn type represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + + // Read fields + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. 
+ readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + handleClose func(int, string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser +} + +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { + + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBufferSize += maxFrameHeaderSize + + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) + } + + mu := make(chan struct{}, 1) + mu <- struct{}{} + c := &Conn{ + isServer: isServer, + br: br, + conn: conn, + mu: mu, + readFinal: true, + writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, + enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, + } + c.SetCloseHandler(nil) + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + +// Subprotocol returns the negotiated protocol for the connection. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +// Close closes the underlying network connection without sending or waiting +// for a close message. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { + <-c.mu + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) + } + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return nil +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. 
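+//
+// WriteControl can be called concurrently with the other write methods; see
+// the Concurrency section of the package documentation.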
+func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := 1000 * time.Hour + if !deadline.IsZero() { + d = deadline.Sub(time.Now()) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return err +} + +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return nil, err + } + c.writer = &mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) endMessage(err error) error { + if w.err != nil { + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. 
+func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.endMessage(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.endMessage(err) + } + + if final { + w.endMessage(errWriteClosed) + return nil + } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. 
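+		// Pass the payload to flushFrame as the extra argument so it is
+		// written straight to the connection instead of being copied through
+		// writeBuf. Only the server may take this path: client frames must be
+		// masked in place inside writeBuf before they are sent.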
+ err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + return w.flushFrame(true, nil) +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return err + } + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. 
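+	// Byte 0 holds the FIN flag, the RSV1-3 bits and the opcode; byte 1 holds
+	// the MASK flag and the 7-bit payload length indicator (RFC 6455, Section 5.2).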
+ + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + mask := p[1]&maskBit != 0 + c.setReadRemaining(int64(p[1] & 0x7f)) + + c.readDecompress = false + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } + + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + return noFrame, c.handleProtocolError("control frame length > 125") + } + if !final { + return noFrame, c.handleProtocolError("control frame not final") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + return noFrame, c.handleProtocolError("message start before final message frame") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + return noFrame, c.handleProtocolError("continuation after final message frame") + } + c.readFinal = final + default: + return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } + } + + // 4. Handle frame masking. + + if mask != c.isServer { + return noFrame, c.handleProtocolError("incorrect mask flag") + } + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + c.setReadRemaining(0) + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. 
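+	// Ping, pong and close payloads are handed to the handlers installed with
+	// SetPingHandler, SetPongHandler and SetCloseHandler; a close frame is then
+	// surfaced to the reader as a *CloseError.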
+ + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("invalid close code") + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = hideTempErr(err) + break + } + + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + c.reader = c.messageReader + if c.readDecompress { + c.reader = c.newDecompressionReader(c.reader) + } + return frameType, c.reader, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. 
+ c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.br.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + rem := c.readRemaining + rem -= int64(n) + c.setReadRemaining(rem) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = hideTempErr(err) + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close +// message back to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. +// +// The connection read methods return a CloseError when a close message is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close message back to +// the peer. 
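+//
+// An illustrative custom handler (a sketch only; it assumes conn is the
+// application's *Conn and uses the standard library log package) that records
+// the close code and then echoes a close message, mirroring the default:
+//
+//	conn.SetCloseHandler(func(code int, text string) error {
+//		log.Printf("peer sent close %d: %q", code, text)
+//		message := websocket.FormatCloseMessage(code, "")
+//		conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
+//		return nil
+//	})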
+func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := FormatCloseMessage(code, "") + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING message application data. The default +// ping handler sends a pong to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. +func (c *Conn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + if err == ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG message application data. The default +// pong handler does nothing. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. +func (c *Conn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +func (c *Conn) UnderlyingConn() net.Conn { + return c.conn +} + +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. +func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. 
+ return []byte{} + } + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/vendor/github.com/gorilla/websocket/conn_write.go b/vendor/github.com/gorilla/websocket/conn_write.go new file mode 100644 index 0000000..a509a21 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_write.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package websocket + +import "net" + +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} diff --git a/vendor/github.com/gorilla/websocket/conn_write_legacy.go b/vendor/github.com/gorilla/websocket/conn_write_legacy.go new file mode 100644 index 0000000..37edaff --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_write_legacy.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +func (c *Conn) writeBufs(bufs ...[]byte) error { + for _, buf := range bufs { + if len(buf) > 0 { + if _, err := c.conn.Write(buf); err != nil { + return err + } + } + } + return nil +} diff --git a/vendor/github.com/gorilla/websocket/doc.go b/vendor/github.com/gorilla/websocket/doc.go new file mode 100644 index 0000000..8db0cef --- /dev/null +++ b/vendor/github.com/gorilla/websocket/doc.go @@ -0,0 +1,227 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements the WebSocket protocol defined in RFC 6455. +// +// Overview +// +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: +// +// var upgrader = websocket.Upgrader{ +// ReadBufferSize: 1024, +// WriteBufferSize: 1024, +// } +// +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := upgrader.Upgrade(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } +// +// Call the connection's WriteMessage and ReadMessage methods to send and +// receive messages as a slice of bytes. This snippet of code shows how to echo +// messages using these methods: +// +// for { +// messageType, p, err := conn.ReadMessage() +// if err != nil { +// log.Println(err) +// return +// } +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return +// } +// } +// +// In above snippet of code, p is a []byte and messageType is an int with value +// websocket.BinaryMessage or websocket.TextMessage. +// +// An application can also send and receive messages using the io.WriteCloser +// and io.Reader interfaces. To send a message, call the connection NextWriter +// method to get an io.WriteCloser, write the message to the writer and close +// the writer when done. To receive a message, call the connection NextReader +// method to get an io.Reader and read until io.EOF is returned. 
This snippet +// shows how to echo messages using the NextWriter and NextReader methods: +// +// for { +// messageType, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(messageType) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// Data Messages +// +// The WebSocket protocol distinguishes between text and binary data messages. +// Text messages are interpreted as UTF-8 encoded text. The interpretation of +// binary messages is left to the application. +// +// This package uses the TextMessage and BinaryMessage integer constants to +// identify the two data message types. The ReadMessage and NextReader methods +// return the type of the received message. The messageType argument to the +// WriteMessage and NextWriter methods specifies the type of a sent message. +// +// It is the application's responsibility to ensure that text messages are +// valid UTF-8 encoded text. +// +// Control Messages +// +// The WebSocket protocol defines three types of control messages: close, ping +// and pong. Call the connection WriteControl, WriteMessage or NextWriter +// methods to send a control message to the peer. +// +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. +// +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. +// +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. +// +// The application must read the connection to process close, ping and pong +// messages sent from the peer. If the application is not otherwise interested +// in messages from the peer, then the application should start a goroutine to +// read and discard messages from the peer. A simple example is: +// +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } +// +// Concurrency +// +// Connections support one concurrent reader and one concurrent writer. +// +// Applications are responsible for ensuring that no more than one goroutine +// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and +// that no more than one goroutine calls the read methods (NextReader, +// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// concurrently. +// +// The Close and WriteControl methods can be called concurrently with all other +// methods. +// +// Origin Considerations +// +// Web browsers allow Javascript applications to open a WebSocket connection to +// any host. 
It's up to the server to enforce an origin policy using the Origin +// request header sent by the browser. +// +// The Upgrader calls the function specified in the CheckOrigin field to check +// the origin. If the CheckOrigin function returns false, then the Upgrade +// method fails the WebSocket handshake with HTTP status 403. +// +// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail +// the handshake if the Origin request header is present and the Origin host is +// not equal to the Host request header. +// +// The deprecated package-level Upgrade function does not perform origin +// checking. The application is responsible for checking the Origin header +// before calling the Upgrade function. +// +// Buffers +// +// Connections buffer network input and output to reduce the number +// of system calls when reading or writing messages. +// +// Write buffers are also used for constructing WebSocket frames. See RFC 6455, +// Section 5 for a discussion of message framing. A WebSocket frame header is +// written to the network each time a write buffer is flushed to the network. +// Decreasing the size of the write buffer can increase the amount of framing +// overhead on the connection. +// +// The buffer sizes in bytes are specified by the ReadBufferSize and +// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default +// size of 4096 when a buffer size field is set to zero. The Upgrader reuses +// buffers created by the HTTP server when a buffer size field is set to zero. +// The HTTP server buffers have a size of 4096 at the time of this writing. +// +// The buffer sizes do not limit the size of a message that can be read or +// written by a connection. +// +// Buffers are held for the lifetime of the connection by default. If the +// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the +// write buffer only when writing a message. +// +// Applications should tune the buffer sizes to balance memory use and +// performance. Increasing the buffer size uses more memory, but can reduce the +// number of system calls to read or write the network. In the case of writing, +// increasing the buffer size can reduce the number of frame headers written to +// the network. +// +// Some guidelines for setting buffer parameters are: +// +// Limit the buffer sizes to the maximum expected message size. Buffers larger +// than the largest message do not provide any benefit. +// +// Depending on the distribution of message sizes, setting the buffer size to +// a value less than the maximum expected message size can greatly reduce memory +// use with a small impact on performance. Here's an example: If 99% of the +// messages are smaller than 256 bytes and the maximum message size is 512 +// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls +// than a buffer size of 512 bytes. The memory savings is 50%. +// +// A write buffer pool is useful when the application has a modest number +// writes over a large number of connections. when buffers are pooled, a larger +// buffer size has a reduced impact on total memory use and has the benefit of +// reducing system calls and frame overhead. +// +// Compression EXPERIMENTAL +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. 
+// +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// } +// +// If compression was successfully negotiated with the connection's peer, any +// message received in compressed form will be automatically decompressed. +// All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(false) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. +package websocket diff --git a/vendor/github.com/gorilla/websocket/go.mod b/vendor/github.com/gorilla/websocket/go.mod new file mode 100644 index 0000000..1a7afd5 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/go.mod @@ -0,0 +1,3 @@ +module github.com/gorilla/websocket + +go 1.12 diff --git a/vendor/github.com/gorilla/websocket/go.sum b/vendor/github.com/gorilla/websocket/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/vendor/github.com/gorilla/websocket/join.go b/vendor/github.com/gorilla/websocket/join.go new file mode 100644 index 0000000..c64f8c8 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/join.go @@ -0,0 +1,42 @@ +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "io" + "strings" +) + +// JoinMessages concatenates received messages to create a single io.Reader. +// The string term is appended to each message. The returned reader does not +// support concurrent calls to the Read method. +func JoinMessages(c *Conn, term string) io.Reader { + return &joinReader{c: c, term: term} +} + +type joinReader struct { + c *Conn + term string + r io.Reader +} + +func (r *joinReader) Read(p []byte) (int, error) { + if r.r == nil { + var err error + _, r.r, err = r.c.NextReader() + if err != nil { + return 0, err + } + if r.term != "" { + r.r = io.MultiReader(r.r, strings.NewReader(r.term)) + } + } + n, err := r.r.Read(p) + if err == io.EOF { + err = nil + r.r = nil + } + return n, err +} diff --git a/vendor/github.com/gorilla/websocket/json.go b/vendor/github.com/gorilla/websocket/json.go new file mode 100644 index 0000000..dc2c1f6 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/json.go @@ -0,0 +1,60 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "encoding/json" + "io" +) + +// WriteJSON writes the JSON encoding of v as a message. +// +// Deprecated: Use c.WriteJSON instead. +func WriteJSON(c *Conn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v as a message. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. 
+func (c *Conn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// Deprecated: Use c.ReadJSON instead. +func ReadJSON(c *Conn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *Conn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/vendor/github.com/gorilla/websocket/mask.go b/vendor/github.com/gorilla/websocket/mask.go new file mode 100644 index 0000000..577fce9 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask.go @@ -0,0 +1,54 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// +build !appengine + +package websocket + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/mask_safe.go b/vendor/github.com/gorilla/websocket/mask_safe.go new file mode 100644 index 0000000..2aac060 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask_safe.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// +build appengine + +package websocket + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/prepared.go b/vendor/github.com/gorilla/websocket/prepared.go new file mode 100644 index 0000000..c854225 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/prepared.go @@ -0,0 +1,102 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
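For orientation while reading the vendored json.go above: a minimal sketch of driving a connection with ReadJSON and WriteJSON. The `event` type and its field names are assumptions made for this sketch, not part of gorilla/websocket.

```go
package wsexample

import (
	"log"

	"github.com/gorilla/websocket"
)

// event is an illustrative payload shape; the type and its fields are
// assumptions made for this sketch, not part of the library.
type event struct {
	Name string `json:"name"`
	Data string `json:"data"`
}

// jsonEcho reads JSON-encoded messages and writes each one back until the
// connection errors out (including a received close message).
func jsonEcho(c *websocket.Conn) {
	defer c.Close()
	for {
		var ev event
		if err := c.ReadJSON(&ev); err != nil {
			log.Println("read:", err)
			return
		}
		if err := c.WriteJSON(ev); err != nil {
			log.Println("write:", err)
			return
		}
	}
}
```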
+ +package websocket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan struct{}, 1) + mu <- struct{}{} + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + if key.compress { + c.newCompressionWriter = compressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/vendor/github.com/gorilla/websocket/proxy.go b/vendor/github.com/gorilla/websocket/proxy.go new file mode 100644 index 0000000..e87a8c9 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/proxy.go @@ -0,0 +1,77 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
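A hedged illustration of the broadcast pattern prepared.go describes: encode the payload once with NewPreparedMessage, then reuse the cached wire representation across connections via WritePreparedMessage. The `conns` slice and its management are assumptions of this sketch, and each connection's writer must still not be used from more than one goroutine at a time.

```go
package wsexample

import (
	"log"

	"github.com/gorilla/websocket"
)

// broadcast sends one payload to many connections, paying the framing (and,
// if negotiated, compression) cost only once. A real application would guard
// the conns slice and each connection against concurrent writes.
func broadcast(conns []*websocket.Conn, payload []byte) {
	pm, err := websocket.NewPreparedMessage(websocket.TextMessage, payload)
	if err != nil {
		log.Println("prepare:", err)
		return
	}
	for _, c := range conns {
		if err := c.WritePreparedMessage(pm); err != nil {
			log.Println("write:", err)
			c.Close()
		}
	}
}
```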
+ +package websocket + +import ( + "bufio" + "encoding/base64" + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +type netDialerFunc func(network, addr string) (net.Conn, error) + +func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return fn(network, addr) +} + +func init() { + proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil + }) +} + +type httpProxyDialer struct { + proxyURL *url.URL + forwardDial func(network, addr string) (net.Conn, error) +} + +func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { + hostPort, _ := hostPortNoPort(hpd.proxyURL) + conn, err := hpd.forwardDial(network, hostPort) + if err != nil { + return nil, err + } + + connectHeader := make(http.Header) + if user := hpd.proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: connectHeader, + } + + if err := connectReq.Write(conn); err != nil { + conn.Close() + return nil, err + } + + // Read response. It's OK to use and discard buffered reader here becaue + // the remote server does not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } + + if resp.StatusCode != 200 { + conn.Close() + f := strings.SplitN(resp.Status, " ", 2) + return nil, errors.New(f[1]) + } + return conn, nil +} diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go new file mode 100644 index 0000000..887d558 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/server.go @@ -0,0 +1,363 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "errors" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +// Upgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +type Upgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. 
+ WriteBufferPool BufferPool + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is not nil, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. If there's no match, then no protocol is + // negotiated (the Sec-Websocket-Protocol header is not included in the + // handshake response). + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(w http.ResponseWriter, r *http.Request, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, then a safe default is used: return false if the + // Origin request header is present and the origin host is not equal to + // request Host header. + // + // A CheckOrigin function should carefully validate the request origin to + // prevent cross-site request forgery. + CheckOrigin func(r *http.Request) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool +} + +func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(w, r, status, err) + } else { + w.Header().Set("Sec-Websocket-Version", "13") + http.Error(w, http.StatusText(status), status) + } + return nil, err +} + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return equalASCIIFold(u.Host, r.Host) +} + +func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(r) + for _, serverProtocol := range u.Subprotocols { + for _, clientProtocol := range clientProtocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return responseHeader.Get("Sec-Websocket-Protocol") + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// application negotiated subprotocol (Sec-WebSocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. 
+func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + const badHandshake = "websocket: the client is not using the websocket protocol: " + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") + } + + if r.Method != "GET" { + return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(r) { + return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") + } + + challengeKey := r.Header.Get("Sec-Websocket-Key") + if challengeKey == "" { + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") + } + + subprotocol := u.selectSubprotocol(r, responseHeader) + + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + + h, ok := w.(http.Hijacker) + if !ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") + } + var brw *bufio.ReadWriter + netConn, brw, err := h.Hijack() + if err != nil { + return u.returnError(w, r, http.StatusInternalServerError, err.Error()) + } + + if brw.Reader.Buffered() > 0 { + netConn.Close() + return nil, errors.New("websocket: client sent data before handshake is complete") + } + + var br *bufio.Reader + if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { + // Reuse hijacked buffered reader as connection reader. + br = brw.Reader + } + + buf := bufioWriterBuffer(netConn, brw.Writer) + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) + c.subprotocol = subprotocol + + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(c.writeBuf) > len(p) { + p = c.writeBuf + } + p = p[:0] + + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) + p = append(p, computeAcceptKey(challengeKey)...) + p = append(p, "\r\n"...) + if c.subprotocol != "" { + p = append(p, "Sec-WebSocket-Protocol: "...) + p = append(p, c.subprotocol...) + p = append(p, "\r\n"...) 
+ } + if compress { + p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } + for k, vs := range responseHeader { + if k == "Sec-Websocket-Protocol" { + continue + } + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + // Clear deadlines set by HTTP server. + netConn.SetDeadline(time.Time{}) + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + } + if _, err = netConn.Write(p); err != nil { + netConn.Close() + return nil, err + } + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Time{}) + } + + return c, nil +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// Deprecated: Use websocket.Upgrader instead. +// +// Upgrade does not perform origin checking. The application is responsible for +// checking the Origin header before calling Upgrade. An example implementation +// of the same origin policy check is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", http.StatusForbidden) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} + u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { + // don't return errors to maintain backwards compatibility + } + u.CheckOrigin = func(r *http.Request) bool { + // allow all connections by default + return true + } + return u.Upgrade(w, r, responseHeader) +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-Websocket-Protocol header. +func Subprotocols(r *http.Request) []string { + h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) + if h == "" { + return nil + } + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(r *http.Request) bool { + return tokenListContainsValue(r.Header, "Connection", "upgrade") && + tokenListContainsValue(r.Header, "Upgrade", "websocket") +} + +// bufioReaderSize size returns the size of a bufio.Reader. 
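To tie the Upgrader pieces above together, a minimal server-side sketch: an HTTP handler that upgrades the request and echoes data messages back. The listen address, endpoint path, allowed origin, and subprotocol name are placeholder assumptions for illustration only.

```go
package main

import (
	"log"
	"net/http"

	"github.com/gorilla/websocket"
)

var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Placeholder policy: accept only a single, hypothetical origin.
	CheckOrigin: func(r *http.Request) bool {
		return r.Header.Get("Origin") == "https://example.com"
	},
	// Hypothetical subprotocol offered by this sketch.
	Subprotocols: []string{"chat"},
}

func echo(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("upgrade:", err)
		return
	}
	defer conn.Close()

	for {
		mt, msg, err := conn.ReadMessage()
		if err != nil {
			log.Println("read:", err)
			return
		}
		if err := conn.WriteMessage(mt, msg); err != nil {
			log.Println("write:", err)
			return
		}
	}
}

func main() {
	http.HandleFunc("/ws", echo)
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}
```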
+func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int { + // This code assumes that peek on a reset reader returns + // bufio.Reader.buf[:0]. + // TODO: Use bufio.Reader.Size() after Go 1.10 + br.Reset(originalReader) + if p, err := br.Peek(0); err == nil { + return cap(p) + } + return 0 +} + +// writeHook is an io.Writer that records the last slice passed to it vio +// io.Writer.Write. +type writeHook struct { + p []byte +} + +func (wh *writeHook) Write(p []byte) (int, error) { + wh.p = p + return len(p), nil +} + +// bufioWriterBuffer grabs the buffer from a bufio.Writer. +func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { + // This code assumes that bufio.Writer.buf[:1] is passed to the + // bufio.Writer's underlying writer. + var wh writeHook + bw.Reset(&wh) + bw.WriteByte(0) + bw.Flush() + + bw.Reset(originalWriter) + + return wh.p[:cap(wh.p)] +} diff --git a/vendor/github.com/gorilla/websocket/trace.go b/vendor/github.com/gorilla/websocket/trace.go new file mode 100644 index 0000000..834f122 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/trace.go @@ -0,0 +1,19 @@ +// +build go1.8 + +package websocket + +import ( + "crypto/tls" + "net/http/httptrace" +) + +func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + if trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(tlsConn, cfg) + if trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + return err +} diff --git a/vendor/github.com/gorilla/websocket/trace_17.go b/vendor/github.com/gorilla/websocket/trace_17.go new file mode 100644 index 0000000..77d05a0 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/trace_17.go @@ -0,0 +1,12 @@ +// +build !go1.8 + +package websocket + +import ( + "crypto/tls" + "net/http/httptrace" +) + +func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + return doHandshake(tlsConn, cfg) +} diff --git a/vendor/github.com/gorilla/websocket/util.go b/vendor/github.com/gorilla/websocket/util.go new file mode 100644 index 0000000..7bf2f66 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/util.go @@ -0,0 +1,283 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" + "unicode/utf8" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Token octets per RFC 2616. 
+var isTokenOctet = [256]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +// skipSpace returns a slice of the string s with all leading RFC 2616 linear +// whitespace removed. +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if b := s[i]; b != ' ' && b != '\t' { + break + } + } + return s[i:] +} + +// nextToken returns the leading RFC 2616 token of s and the string following +// the token. +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if !isTokenOctet[s[i]] { + break + } + } + return s[:i], s[i:] +} + +// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 +// and the string following the token or quoted string. +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// equalASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains a token equal to value with ASCII case folding. +func tokenListContainsValue(header http.Header, name string, value string) bool { +headers: + for _, s := range header[name] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + continue headers + } + if equalASCIIFold(t, value) { + return true + } + if s == "" { + continue headers + } + s = s[1:] + } + } + return false +} + +// parseExtensions parses WebSocket extensions from a header. 
+func parseExtensions(header http.Header) []map[string]string { + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} diff --git a/vendor/github.com/gorilla/websocket/x_net_proxy.go b/vendor/github.com/gorilla/websocket/x_net_proxy.go new file mode 100644 index 0000000..2e668f6 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/x_net_proxy.go @@ -0,0 +1,473 @@ +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. +//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy + +// Package proxy provides support for a variety of protocols to proxy network +// data. +// + +package websocket + +import ( + "errors" + "io" + "net" + "net/url" + "os" + "strconv" + "strings" + "sync" +) + +type proxy_direct struct{} + +// Direct is a direct proxy: one that makes network connections directly. +var proxy_Direct = proxy_direct{} + +func (proxy_direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type proxy_PerHost struct { + def, bypass proxy_Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { + return &proxy_PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. 
+func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *proxy_PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *proxy_PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *proxy_PerHost) AddZone(zone string) { + if strings.HasSuffix(zone, ".") { + zone = zone[:len(zone)-1] + } + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *proxy_PerHost) AddHost(host string) { + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + p.bypassHosts = append(p.bypassHosts, host) +} + +// A Dialer is a means to establish a connection. +type proxy_Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type proxy_Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy related variables in +// the environment. 
+func proxy_FromEnvironment() proxy_Dialer { + allProxy := proxy_allProxyEnv.Get() + if len(allProxy) == 0 { + return proxy_Direct + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return proxy_Direct + } + proxy, err := proxy_FromURL(proxyURL, proxy_Direct) + if err != nil { + return proxy_Direct + } + + noProxy := proxy_noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := proxy_NewPerHost(proxy, proxy_Direct) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { + if proxy_proxySchemes == nil { + proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) + } + proxy_proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { + var auth *proxy_Auth + if u.User != nil { + auth = new(proxy_Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5": + return proxy_SOCKS5("tcp", u.Host, auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxy_proxySchemes != nil { + if f, ok := proxy_proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + proxy_allProxyEnv = &proxy_envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + proxy_noProxyEnv = &proxy_envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +// (Borrowed from net/http/transport.go) +type proxy_envOnce struct { + names []string + once sync.Once + val string +} + +func (e *proxy_envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *proxy_envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address +// with an optional username and password. See RFC 1928 and RFC 1929. 
+func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { + s := &proxy_socks5{ + network: network, + addr: addr, + forward: forward, + } + if auth != nil { + s.user = auth.User + s.password = auth.Password + } + + return s, nil +} + +type proxy_socks5 struct { + user, password string + network, addr string + forward proxy_Dialer +} + +const proxy_socks5Version = 5 + +const ( + proxy_socks5AuthNone = 0 + proxy_socks5AuthPassword = 2 +) + +const proxy_socks5Connect = 1 + +const ( + proxy_socks5IP4 = 1 + proxy_socks5Domain = 3 + proxy_socks5IP6 = 4 +) + +var proxy_socks5Errors = []string{ + "", + "general failure", + "connection forbidden", + "network unreachable", + "host unreachable", + "connection refused", + "TTL expired", + "command not supported", + "address type not supported", +} + +// Dial connects to the address addr on the given network via the SOCKS5 proxy. +func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) + } + + conn, err := s.forward.Dial(s.network, s.addr) + if err != nil { + return nil, err + } + if err := s.connect(conn, addr); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +// connect takes an existing connection to a socks5 proxy server, +// and commands the server to extend that connection to target, +// which must be a canonical address with a host and port. +func (s *proxy_socks5) connect(conn net.Conn, target string) error { + host, portStr, err := net.SplitHostPort(target) + if err != nil { + return err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return errors.New("proxy: failed to parse port number: " + portStr) + } + if port < 1 || port > 0xffff { + return errors.New("proxy: port number out of range: " + portStr) + } + + // the size here is just an estimate + buf := make([]byte, 0, 6+len(host)) + + buf = append(buf, proxy_socks5Version) + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { + buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) + } else { + buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) + } + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + if buf[0] != 5 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + } + if buf[1] == 0xff { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + } + + // See RFC 1929 + if buf[1] == proxy_socks5AuthPassword { + buf = buf[:0] + buf = append(buf, 1 /* password protocol version */) + buf = append(buf, uint8(len(s.user))) + buf = append(buf, s.user...) + buf = append(buf, uint8(len(s.password))) + buf = append(buf, s.password...) 
+ + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if buf[1] != 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + } + } + + buf = buf[:0] + buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) + + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, proxy_socks5IP4) + ip = ip4 + } else { + buf = append(buf, proxy_socks5IP6) + } + buf = append(buf, ip...) + } else { + if len(host) > 255 { + return errors.New("proxy: destination host name too long: " + host) + } + buf = append(buf, proxy_socks5Domain) + buf = append(buf, byte(len(host))) + buf = append(buf, host...) + } + buf = append(buf, byte(port>>8), byte(port)) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:4]); err != nil { + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + failure := "unknown error" + if int(buf[1]) < len(proxy_socks5Errors) { + failure = proxy_socks5Errors[buf[1]] + } + + if len(failure) > 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + } + + bytesToDiscard := 0 + switch buf[3] { + case proxy_socks5IP4: + bytesToDiscard = net.IPv4len + case proxy_socks5IP6: + bytesToDiscard = net.IPv6len + case proxy_socks5Domain: + _, err := io.ReadFull(conn, buf[:1]) + if err != nil { + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + bytesToDiscard = int(buf[0]) + default: + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + } + + if cap(buf) < bytesToDiscard { + buf = make([]byte, bytesToDiscard) + } else { + buf = buf[:bytesToDiscard] + } + if _, err := io.ReadFull(conn, buf); err != nil { + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + // Also need to discard the port number + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + return nil +} diff --git a/vendor/github.com/grishinsana/goftx/.gitignore b/vendor/github.com/grishinsana/goftx/.gitignore new file mode 100644 index 0000000..be00e21 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/.gitignore @@ -0,0 +1,19 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.idea + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Enviroment variables +.env \ No newline at end of file diff --git a/vendor/github.com/grishinsana/goftx/LICENSE b/vendor/github.com/grishinsana/goftx/LICENSE new file mode 100644 index 0000000..8120836 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 grishinsana + +Permission is hereby granted, free of charge, to any 
person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/grishinsana/goftx/README.md b/vendor/github.com/grishinsana/goftx/README.md new file mode 100644 index 0000000..39060f7 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/README.md @@ -0,0 +1,116 @@ +# goftx +FTX exchange golang library + +### Install +```shell script +go get github.com/grishinsana/goftx +``` + +### Usage + +> See examples directory and test cases for more examples + +### TODO +- Private Streams +- Orders +- Futures +- Wallet +- Converts +- Fills +- Funding Payments +- Leveraged Tokens +- Options +- SRM Staking + +#### REST +```go +package main + +import ( + "fmt" + "net/http" + "time" + + "github.com/grishinsana/goftx" +) + +func main() { + client := goftx.New( + goftx.WithAuth("API-KEY", "API-SECRET"), + goftx.WithHTTPClient(&http.Client{ + Timeout: 5 * time.Second, + }), + ) + + info, err := client.Account.GetAccountInformation() + if err != nil { + panic(err) + } + fmt.Println(info) +} +``` + +#### WebSocket +```go +package main + +import ( + "context" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/grishinsana/goftx" +) + +func main() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancel := context.WithCancel(context.Background()) + + client := goftx.New() + client.Stream.SetDebugMode(true) + + data, err := client.Stream.SubscribeToTickers(ctx, "ETH/BTC") + if err != nil { + log.Fatalf("%+v", err) + } + + go func() { + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-data: + if !ok { + return + } + log.Printf("%+v\n", msg) + } + } + }() + + <-sigs + cancel() + time.Sleep(time.Second) +} +``` + +### Websocket Debug Mode +If need, it is possible to set debug mode to look error and system messages in stream methods +```go + client := goftx.New() + client.Stream.SetDebugMode(true) +``` + +### No Logged In Error +"Not logged in" errors usually come from a wrong signatures. 
FTX released an article on how to authenticate https://blog.ftx.com/blog/api-authentication/ + +If you have unauthorized error to private methods, then you need to use SetServerTimeDiff() +```go +ftx := New() +ftx.SetServerTimeDiff() +``` diff --git a/vendor/github.com/grishinsana/goftx/account.go b/vendor/github.com/grishinsana/goftx/account.go new file mode 100644 index 0000000..e723b13 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/account.go @@ -0,0 +1,121 @@ +package goftx + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" + "github.com/shopspring/decimal" + + "github.com/grishinsana/goftx/models" +) + +const ( + apiGetAccountInformation = "/account" + apiGetPositions = "/positions" + apiGetBalances = "/wallet/balances" + apiPostLeverage = "/account/leverage" +) + +type Account struct { + client *Client +} + +func (a *Account) GetAccountInformation() (*models.AccountInformation, error) { + request, err := a.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetAccountInformation), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := a.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result *models.AccountInformation + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (s *Account) GetBalances() ([]*models.Balance, error) { + request, err := s.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetBalances), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := s.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Balance + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (a *Account) GetPositions() ([]*models.Position, error) { + request, err := a.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetPositions), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := a.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Position + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (a *Account) ChangeAccountLeverage(leverage decimal.Decimal) error { + body, err := json.Marshal(struct { + Leverage decimal.Decimal `json:"leverage"` + }{Leverage: leverage}) + if err != nil { + return errors.WithStack(err) + } + + request, err := a.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiPostLeverage), + Body: body, + }) + if err != nil { + return errors.WithStack(err) + } + + _, err = a.client.do(request) + if err != nil { + return errors.WithStack(err) + } + + return nil +} diff --git a/vendor/github.com/grishinsana/goftx/client.go b/vendor/github.com/grishinsana/goftx/client.go new file mode 100644 index 0000000..8203b14 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/client.go @@ -0,0 +1,229 @@ +package goftx + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "reflect" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +const ( + apiUrl = 
"https://ftx.com/api" + apiOtcUrl = "https://otc.ftx.com/api" + + keyHeader = "FTX-KEY" + signHeader = "FTX-SIGN" + tsHeader = "FTX-TS" + subAccountHeader = "FTX-SUBACCOUNT" +) + +type Option func(c *Client) + +func WithHTTPClient(client *http.Client) Option { + return func(c *Client) { + c.client = client + } +} + +func WithAuth(key, secret string) Option { + return func(c *Client) { + c.apiKey = key + c.secret = secret + c.Stream.apiKey = key + c.Stream.secret = secret + } +} + +func WithSubaccount(subAccount string) Option { + return func(c *Client) { + c.subAccount = subAccount + } +} + +type Client struct { + client *http.Client + apiKey string + secret string + subAccount string + serverTimeDiff time.Duration + SubAccounts + Markets + Account + Stream + Orders + SpotMargin +} + +func New(opts ...Option) *Client { + client := &Client{ + client: http.DefaultClient, + } + + for _, opt := range opts { + opt(client) + } + + client.SubAccounts = SubAccounts{client: client} + client.Markets = Markets{client: client} + client.Account = Account{client: client} + client.Orders = Orders{client: client} + client.SpotMargin = SpotMargin{client: client} + client.Stream = Stream{ + apiKey: client.apiKey, + secret: client.secret, + subAccount: client.subAccount, + mu: &sync.Mutex{}, + url: wsUrl, + dialer: websocket.DefaultDialer, + wsReconnectionCount: reconnectCount, + wsReconnectionInterval: reconnectInterval, + wsTimeout: streamTimeout, + } + + return client +} + +func (c *Client) SetServerTimeDiff() error { + serverTime, err := c.GetServerTime() + if err != nil { + return errors.WithStack(err) + } + c.serverTimeDiff = serverTime.Sub(time.Now().UTC()) + return nil +} + +type Response struct { + Success bool `json:"success"` + Result json.RawMessage `json:"result"` + Error string `json:"error,omitempty"` +} + +type Request struct { + Auth bool + Method string + URL string + Headers map[string]string + Params map[string]string + Body []byte +} + +func (c *Client) prepareRequest(request Request) (*http.Request, error) { + req, err := http.NewRequest(request.Method, request.URL, bytes.NewBuffer(request.Body)) + if err != nil { + return nil, errors.WithStack(err) + } + + query := req.URL.Query() + for k, v := range request.Params { + query.Add(k, v) + } + req.URL.RawQuery = query.Encode() + + if request.Auth { + nonce := strconv.FormatInt(time.Now().UTC().Add(c.serverTimeDiff).Unix()*1000, 10) + payload := nonce + req.Method + req.URL.Path + if req.URL.RawQuery != "" { + payload += "?" 
+ req.URL.RawQuery + } + if len(request.Body) > 0 { + payload += string(request.Body) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set(keyHeader, c.apiKey) + req.Header.Set(signHeader, c.signture(payload)) + req.Header.Set(tsHeader, nonce) + + if c.subAccount != "" { + req.Header.Set(subAccountHeader, c.subAccount) + } + } + + for k, v := range request.Headers { + req.Header.Set(k, v) + } + + return req, nil +} + +func (c *Client) do(req *http.Request) ([]byte, error) { + resp, err := c.client.Do(req) + if resp != nil { + defer resp.Body.Close() + } + if err != nil { + return nil, errors.WithStack(err) + } + + res, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.WithStack(err) + } + + var response Response + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.WithStack(err) + } + + if !response.Success { + return nil, errors.Errorf("Status Code: %d Error: %v", resp.StatusCode, response.Error) + } + + return response.Result, nil +} + +func (c *Client) prepareQueryParams(params interface{}) map[string]string { + result := make(map[string]string) + + val := reflect.ValueOf(params).Elem() + for i := 0; i < val.NumField(); i++ { + valueField := val.Field(i) + typeField := val.Type().Field(i) + tag := typeField.Tag + + result[tag.Get("json")] = valueField.String() + } + + return result +} + +func (c *Client) signture(payload string) string { + mac := hmac.New(sha256.New, []byte(c.secret)) + mac.Write([]byte(payload)) + return hex.EncodeToString(mac.Sum(nil)) +} + +func (c *Client) GetServerTime() (*time.Time, error) { + request, err := c.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s/time", apiOtcUrl), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := c.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result time.Time + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return &result, nil +} diff --git a/vendor/github.com/grishinsana/goftx/go.mod b/vendor/github.com/grishinsana/goftx/go.mod new file mode 100644 index 0000000..7997afe --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/go.mod @@ -0,0 +1,18 @@ +module github.com/grishinsana/goftx + +go 1.14 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/go-numb/go-ftx v0.0.0-20200829181514-3144aa68f505 // indirect + github.com/google/go-querystring v1.0.0 // indirect + github.com/gorilla/websocket v1.4.2 + github.com/joho/godotenv v1.3.0 + github.com/json-iterator/go v1.1.10 // indirect + github.com/pkg/errors v0.9.1 + github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.6.1 + github.com/valyala/fasthttp v1.16.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect +) diff --git a/vendor/github.com/grishinsana/goftx/go.sum b/vendor/github.com/grishinsana/goftx/go.sum new file mode 100644 index 0000000..2a3ae6e --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/go.sum @@ -0,0 +1,51 @@ +github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4= +github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/go-numb/go-ftx v0.0.0-20200829181514-3144aa68f505 h1:iOIIPP+XjnNbEH7N139XWTRdjTK3ZmM/R4RFwUZc9Cc= +github.com/go-numb/go-ftx v0.0.0-20200829181514-3144aa68f505/go.mod h1:rjG/Mg/la6U9w0NN/oaMZkgCpEQgseKPOl6EkvYkjCw= +github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg= +github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.16.0 h1:9zAqOYLl8Tuy3E5R6ckzGDJ1g8+pw15oQp2iL9Jl6gQ= +github.com/valyala/fasthttp v1.16.0/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA= +github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/grishinsana/goftx/markets.go b/vendor/github.com/grishinsana/goftx/markets.go new file mode 100644 index 0000000..7ed9199 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/markets.go @@ -0,0 +1,190 @@ +package goftx + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" + + "github.com/grishinsana/goftx/models" +) + +const ( + apiGetMarkets = "/markets" + apiGetOrderBook = "/markets/%s/orderbook" + apiGetTrades = "/markets/%s/trades" + apiGetHistoricalPrices = "/markets/%s/candles" + apiGetLastCandle = "/markets/%s/candles/last" +) + +type Markets struct { + client *Client +} + +func (m *Markets) GetMarkets() ([]*models.Market, error) { + request, err := m.client.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetMarkets), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Market + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (m *Markets) GetMarketByName(name string) (*models.Market, error) { + request, err := m.client.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s/%s", apiUrl, apiGetMarkets, name), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result models.Market + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return &result, nil +} + +func (m *Markets) GetOrderBook(marketName string, depth *int) (*models.OrderBook, error) { + params := map[string]string{} + if depth != nil { + params["depth"] = fmt.Sprintf("%d", *depth) + } + + path := fmt.Sprintf(apiGetOrderBook, marketName) + + request, err := m.client.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, path), + Params: params, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result models.OrderBook + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return &result, nil +} + +func (m *Markets) GetTrades(marketName string, params *models.GetTradesParams) ([]*models.Trade, error) { + queryParams, err := PrepareQueryParams(params) + if err != nil { + return nil, errors.WithStack(err) + } + + path := 
fmt.Sprintf(apiGetTrades, marketName) + request, err := m.client.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, path), + Params: queryParams, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Trade + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (m *Markets) GetHistoricalPrices(marketName string, params *models.GetHistoricalPricesParams) ([]*models.HistoricalPrice, error) { + queryParams, err := PrepareQueryParams(params) + if err != nil { + return nil, errors.WithStack(err) + } + + path := fmt.Sprintf(apiGetHistoricalPrices, marketName) + request, err := m.client.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, path), + Params: queryParams, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.HistoricalPrice + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (m *Markets) GetLastCandle(marketName string, params *models.GetLastCandleParams) (*models.HistoricalPrice, error) { + queryParams, err := PrepareQueryParams(params) + if err != nil { + return nil, errors.WithStack(err) + } + + path := fmt.Sprintf(apiGetLastCandle, marketName) + request, err := m.client.prepareRequest(Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, path), + Params: queryParams, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result *models.HistoricalPrice + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} diff --git a/vendor/github.com/grishinsana/goftx/models/account.go b/vendor/github.com/grishinsana/goftx/models/account.go new file mode 100644 index 0000000..01e200c --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/models/account.go @@ -0,0 +1,39 @@ +package models + +import "github.com/shopspring/decimal" + +type AccountInformation struct { + BackstopProvider bool `json:"backstopProvider"` + Collateral decimal.Decimal `json:"collateral"` + FreeCollateral decimal.Decimal `json:"freeCollateral"` + InitialMarginRequirement decimal.Decimal `json:"initialMarginRequirement"` + Liquidating bool `json:"liquidating"` + MaintenanceMarginRequirement decimal.Decimal `json:"maintenanceMarginRequirement"` + MakerFee decimal.Decimal `json:"makerFee"` + MarginFraction decimal.Decimal `json:"marginFraction"` + OpenMarginFraction decimal.Decimal `json:"openMarginFraction"` + TakerFee decimal.Decimal `json:"takerFee"` + TotalAccountValue decimal.Decimal `json:"totalAccountValue"` + TotalPositionSize decimal.Decimal `json:"totalPositionSize"` + Username string `json:"username"` + Leverage decimal.Decimal `json:"leverage"` + Positions []Position `json:"positions"` +} + +type Position struct { + Cost decimal.Decimal `json:"cost"` + EntryPrice decimal.Decimal `json:"entryPrice"` + EstimatedLiquidationPrice decimal.Decimal `json:"estimatedLiquidationPrice"` + Future string `json:"future"` + InitialMarginRequirement decimal.Decimal `json:"initialMarginRequirement"` + LongOrderSize decimal.Decimal 
`json:"longOrderSize"`
+ MaintenanceMarginRequirement decimal.Decimal `json:"maintenanceMarginRequirement"`
+ NetSize decimal.Decimal `json:"netSize"`
+ OpenSize decimal.Decimal `json:"openSize"`
+ RealizedPnl decimal.Decimal `json:"realizedPnl"`
+ ShortOrderSize decimal.Decimal `json:"shortOrderSize"`
+ Side string `json:"side"`
+ Size decimal.Decimal `json:"size"`
+ UnrealizedPnl decimal.Decimal `json:"unrealizedPnl"`
+ CollateralUsed decimal.Decimal `json:"collateralUsed"`
+}
diff --git a/vendor/github.com/grishinsana/goftx/models/markets.go b/vendor/github.com/grishinsana/goftx/models/markets.go
new file mode 100644
index 0000000..46249ef
--- /dev/null
+++ b/vendor/github.com/grishinsana/goftx/models/markets.go
@@ -0,0 +1,117 @@
+package models
+
+import (
+ "time"
+
+ "github.com/shopspring/decimal"
+)
+
+type Market struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ Underlying string `json:"underlying"`
+ BaseCurrency string `json:"baseCurrency"`
+ QuoteCurrency string `json:"quoteCurrency"`
+ Enabled bool `json:"enabled"`
+ Ask decimal.Decimal `json:"ask"`
+ Bid decimal.Decimal `json:"bid"`
+ Last decimal.Decimal `json:"last"`
+ PostOnly bool `json:"postOnly"`
+ PriceIncrement decimal.Decimal `json:"priceIncrement"`
+ SizeIncrement decimal.Decimal `json:"sizeIncrement"`
+ Restricted bool `json:"restricted"`
+}
+
+// The bids and asks are formatted like so:
+// [[best price, size at price], [next best price, size at price], ...]
+//
+// Checksum
+// Every message contains a signed 32-bit integer checksum of the orderbook.
+// You can run the same checksum on your client orderbook state and compare it to the checksum field.
+// If they are the same, your client's state is correct.
+// If not, you have likely lost or mishandled a packet and should re-subscribe to receive the initial snapshot.
+//
+// The checksum operates on a string that represents the first 100 orders on the orderbook on either side. The format of the string is:
+//
+// <best_bid_price>:<best_bid_size>:<best_ask_price>:<best_ask_size>:<second_best_bid_price>:<second_best_bid_size>:...
+// For example, if the orderbook was comprised of the following two bids and asks:
+//
+// bids: [[5000.5, 10], [4995.0, 5]]
+// asks: [[5001.0, 6], [5002.0, 7]]
+// The string would be '5000.5:10:5001.0:6:4995.0:5:5002.0:7'
+//
+// If there are more orders on one side of the book than the other, then simply omit the information about orders that don't exist.
+//
+// For example, if the orderbook had the following bids and asks:
+//
+// bids: [[5000.5, 10], [4995.0, 5]]
+// asks: [[5001.0, 6]]
+// The string would be '5000.5:10:5001.0:6:4995.0:5'
+//
+// The final checksum is the crc32 value of this string.
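The checksum rule above is easiest to see in code. Here is a minimal, self-contained sketch (not part of the vendored library) that builds the colon-separated string for the example book and runs CRC-32 over it; it assumes prices and sizes are kept as the exact strings received from the exchange:

```go
package main

import (
	"fmt"
	"hash/crc32"
	"strings"
)

// orderBookChecksum interleaves up to 100 bid and ask levels as price:size
// pairs (bid first, then ask, per level) and returns the CRC-32 (IEEE) of
// the joined string, as described in the comment above.
func orderBookChecksum(bids, asks [][2]string) uint32 {
	var parts []string
	for i := 0; i < 100; i++ {
		if i < len(bids) {
			parts = append(parts, bids[i][0], bids[i][1]) // price, size
		}
		if i < len(asks) {
			parts = append(parts, asks[i][0], asks[i][1])
		}
	}
	return crc32.ChecksumIEEE([]byte(strings.Join(parts, ":")))
}

func main() {
	bids := [][2]string{{"5000.5", "10"}, {"4995.0", "5"}}
	asks := [][2]string{{"5001.0", "6"}, {"5002.0", "7"}}
	// Checksum of "5000.5:10:5001.0:6:4995.0:5:5002.0:7"; compare against
	// the Checksum field of the OrderBook snapshot.
	fmt.Println(orderBookChecksum(bids, asks))
}
```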
+type OrderBook struct { + Asks [][]decimal.Decimal `json:"asks"` + Bids [][]decimal.Decimal `json:"bids"` + Checksum int64 `json:"checksum,omitempty"` + Time FTXTime `json:"time"` +} + +type Trade struct { + ID int64 `json:"id"` + Liquidation bool `json:"liquidation"` + Price decimal.Decimal `json:"price"` + Side string `json:"side"` + Size decimal.Decimal `json:"size"` + Time time.Time `json:"time"` +} + +type HistoricalPrice struct { + StartTime time.Time `json:"startTime"` + Open decimal.Decimal `json:"open"` + Close decimal.Decimal `json:"close"` + High decimal.Decimal `json:"high"` + Low decimal.Decimal `json:"low"` + Volume decimal.Decimal `json:"volume"` +} + +type Ticker struct { + Bid decimal.Decimal `json:"bid"` + Ask decimal.Decimal `json:"ask"` + BidSize decimal.Decimal `json:"bidSize"` + AskSize decimal.Decimal `json:"askSize"` + Last decimal.Decimal `json:"last"` + Time FTXTime `json:"time"` +} + +type Fill struct { + Fee decimal.Decimal `json:"fee"` + FeeRate decimal.Decimal `json:"feeRate"` + Future string `json:"future"` + ID int64 `json:"id"` + Liquidity string `json:"liquidity"` + Market string `json:"market"` + OrderID int64 `json:"orderId"` + TradeID int64 `json:"tradeId"` + Price decimal.Decimal `json:"price"` + Side string `json:"side"` + Size decimal.Decimal `json:"size"` + Time FTXTime `json:"time"` + Type string `json:"type"` +} + +type GetTradesParams struct { + Limit *int `json:"limit"` + StartTime *int `json:"start_time"` + EndTime *int `json:"end_time"` +} + +type GetHistoricalPricesParams struct { + Resolution Resolution `json:"resolution"` + Limit *int `json:"limit"` + StartTime *int `json:"start_time"` + EndTime *int `json:"end_time"` +} + +type GetLastCandleParams struct { + Resolution Resolution `json:"resolution"` +} diff --git a/vendor/github.com/grishinsana/goftx/models/orders.go b/vendor/github.com/grishinsana/goftx/models/orders.go new file mode 100644 index 0000000..01cbec8 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/models/orders.go @@ -0,0 +1,109 @@ +package models + +import ( + "time" + + "github.com/shopspring/decimal" +) + +type Order struct { + ID int64 `json:"id"` + Market string `json:"market"` + Type OrderType `json:"type"` + Side Side `json:"side"` + Price decimal.Decimal `json:"price"` + Size decimal.Decimal `json:"size"` + FilledSize decimal.Decimal `json:"filledSize"` + RemainingSize decimal.Decimal `json:"remainingSize"` + AvgFillPrice decimal.Decimal `json:"avgFillPrice"` + Status Status `json:"status"` + CreatedAt time.Time `json:"createdAt"` + ReduceOnly bool `json:"reduceOnly"` + Ioc bool `json:"ioc"` + PostOnly bool `json:"postOnly"` + Future string `json:"future"` + ClientID string `json:"clientId"` +} + +type PlaceOrderParams struct { + Market string `json:"market"` + Type OrderType `json:"type"` + Side Side `json:"side"` + Price decimal.Decimal `json:"price"` + Size decimal.Decimal `json:"size"` + ReduceOnly bool `json:"reduceOnly"` + Ioc bool `json:"ioc"` + PostOnly bool `json:"postOnly"` +} + +type PlaceStopLossParams struct { + Market string `json:"market"` + Side Side `json:"side"` + Size decimal.Decimal `json:"size"` + ReduceOnly bool `json:"reduceOnly"` + Type TriggerOrderType `json:"type"` + TriggerPrice decimal.Decimal `json:"triggerPrice"` +} + +type PlaceStopLimitParams struct { + Market string `json:"market"` + Side Side `json:"side"` + Size decimal.Decimal `json:"size"` + ReduceOnly bool `json:"reduceOnly"` + Type TriggerOrderType `json:"type"` + TriggerPrice decimal.Decimal `json:"triggerPrice"` + 
OrderPrice decimal.Decimal `json:"orderPrice"` +} + +type PlaceTrailingStopParams struct { + Market string `json:"market"` + Side Side `json:"side"` + Size decimal.Decimal `json:"size"` + ReduceOnly bool `json:"reduceOnly"` + Type TriggerOrderType `json:"type"` + TrailValue decimal.Decimal `json:"trailValue"` +} + +type GetOrdersHistoryParams struct { + Market *string `json:"market"` + Limit *int `json:"limit"` + StartTime *int `json:"start_time"` + EndTime *int `json:"end_time"` +} + +type TriggerOrder struct { + ID int64 `json:"id"` + OrderID int64 `json:"orderId"` + Market string `json:"market"` + CreatedAt time.Time `json:"createdAt"` + Error string `json:"error"` + Future string `json:"future"` + OrderPrice decimal.Decimal `json:"orderPrice"` + ReduceOnly bool `json:"reduceOnly"` + Side Side `json:"side"` + Size decimal.Decimal `json:"size"` + Status Status `json:"status"` + TrailStart decimal.Decimal `json:"trailStart"` + TrailValue decimal.Decimal `json:"trailValue"` + TriggerPrice decimal.Decimal `json:"triggerPrice"` + TriggeredAt time.Time `json:"triggeredAt"` + Type TriggerOrderType `json:"type"` + OrderType OrderType `json:"orderType"` + FilledSize decimal.Decimal `json:"filledSize"` + AvgFillPrice decimal.Decimal `json:"avgFillPrice"` + OrderStatus string `json:"orderStatus"` + RetryUntilFilled bool `json:"retryUntilFilled"` +} + +type GetOpenTriggerOrdersParams struct { + Market *string `json:"market"` + Type *TriggerOrderType `json:"type"` +} + +type Trigger struct { + Error string `json:"error"` + FilledSize float64 `json:"filledSize"` + OrderSize float64 `json:"orderSize"` + OrderID int64 `json:"orderId"` + Time time.Time `json:"time"` +} diff --git a/vendor/github.com/grishinsana/goftx/models/spotmargin.go b/vendor/github.com/grishinsana/goftx/models/spotmargin.go new file mode 100644 index 0000000..811fb30 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/models/spotmargin.go @@ -0,0 +1,17 @@ +package models + +import "github.com/shopspring/decimal" + +type LendingInfo struct { + Coin string `json:"coin"` + Lendable decimal.Decimal `json:"lendable"` + Locked decimal.Decimal `json:"locked"` + MinRate decimal.Decimal `json:"minRate"` + Offered decimal.Decimal `json:"offered"` +} + +type LendingRate struct { + Coin string `json:"coin"` + Estimate decimal.Decimal `json:"estimate"` + Previous decimal.Decimal `json:"previous"` +} diff --git a/vendor/github.com/grishinsana/goftx/models/subaccounts.go b/vendor/github.com/grishinsana/goftx/models/subaccounts.go new file mode 100644 index 0000000..d9cb515 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/models/subaccounts.go @@ -0,0 +1,38 @@ +package models + +import ( + "time" + + "github.com/shopspring/decimal" +) + +type SubAccount struct { + Nickname string `json:"nickname"` + Deletable bool `json:"deletable"` + Editable bool `json:"editable"` + Competition bool `json:"competition,omitempty"` +} + +type Balance struct { + Coin string `json:"coin"` + Free decimal.Decimal `json:"free"` + Total decimal.Decimal `json:"total"` + SpotBorrow decimal.Decimal `json:"spotBorrow"` + AvailableWithoutBorrow decimal.Decimal `json:"availableWithoutBorrow"` +} + +type TransferPayload struct { + Coin string `json:"coin"` + Size decimal.Decimal `json:"size"` + Source *string `json:"source"` + Destination *string `json:"destination"` +} + +type TransferResponse struct { + ID int64 `json:"id"` + Coin string `json:"coin"` + Size decimal.Decimal `json:"size"` + Time time.Time `json:"time"` + Notes string `json:"notes"` + Status 
TransferStatus `json:"status"` +} diff --git a/vendor/github.com/grishinsana/goftx/models/types.go b/vendor/github.com/grishinsana/goftx/models/types.go new file mode 100644 index 0000000..6785115 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/models/types.go @@ -0,0 +1,113 @@ +package models + +import ( + "encoding/json" + "math" + "time" +) + +type Resolution int + +const ( + Sec15 = 15 + Minute = 60 + Minute5 = 300 + Minute15 = 900 + Hour = 3600 + Hour4 = 14400 + Day = 86400 +) + +type Channel string + +const ( + OrderBookChannel = Channel("orderbook") + TradesChannel = Channel("trades") + TickerChannel = Channel("ticker") + MarketsChannel = Channel("markets") + FillsChannel = Channel("fills") + OrdersChannel = Channel("orders") +) + +type Operation string + +const ( + Subscribe = Operation("subscribe") + UnSubscribe = Operation("unsubscribe") + Login = Operation("login") +) + +type ResponseType string + +const ( + Error = ResponseType("error") + Subscribed = ResponseType("subscribed") + UnSubscribed = ResponseType("unsubscribed") + Info = ResponseType("info") + Partial = ResponseType("partial") + Update = ResponseType("update") +) + +type TransferStatus string + +const Complete = TransferStatus("complete") + +type OrderType string + +const ( + LimitOrder = OrderType("limit") + MarketOrder = OrderType("market") +) + +type Side string + +const ( + Sell = Side("sell") + Buy = Side("buy") +) + +type Status string + +const ( + New = Status("new") + Open = Status("open") + Closed = Status("closed") +) + +type TriggerOrderType string + +const ( + Stop = TriggerOrderType("stop") + TrailingStop = TriggerOrderType("trailing_stop") + TakeProfit = TriggerOrderType("take_profit") +) + +type FTXTime struct { + Time time.Time +} + +func (f *FTXTime) UnmarshalJSON(data []byte) error { + var t float64 + err := json.Unmarshal(data, &t) + + // FTX uses ISO format sometimes so we have to detect and handle that differently. 
+ if err != nil { + var iso time.Time + errIso := json.Unmarshal(data, &iso) + + if errIso != nil { + return err + } + + f.Time = iso + return nil + } + + sec, nsec := math.Modf(t) + f.Time = time.Unix(int64(sec), int64(nsec)) + return nil +} + +func (f FTXTime) MarshalJSON() ([]byte, error) { + return json.Marshal(float64(f.Time.UnixNano()) / float64(1000000000)) +} diff --git a/vendor/github.com/grishinsana/goftx/models/websocket.go b/vendor/github.com/grishinsana/goftx/models/websocket.go new file mode 100644 index 0000000..8549e4a --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/models/websocket.go @@ -0,0 +1,138 @@ +package models + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +type BaseResponse struct { + Type ResponseType + Symbol string +} + +type TickerResponse struct { + Ticker + BaseResponse +} + +type TradesResponse struct { + Trades []Trade + BaseResponse +} + +type TradeResponse struct { + Trade + BaseResponse +} + +type OrderBookResponse struct { + OrderBook + BaseResponse +} + +type FillResponse struct { + Fill + BaseResponse +} + +type OrderResponse struct { + Order + BaseResponse +} + +type WSRequest struct { + Channel Channel `json:"channel"` + Market string `json:"market"` + Op Operation `json:"op"` + Args map[string]interface{} `json:"args"` +} + +type WsResponse struct { + Channel Channel `json:"channel"` + Market string `json:"market"` + Type ResponseType `json:"type"` + Code int `json:"code"` + Message string `json:"msg"` + Data json.RawMessage `json:"data"` +} + +func (wr *WsResponse) MapToTradesResponse() (*TradesResponse, error) { + var trades []Trade + err := json.Unmarshal(wr.Data, &trades) + if err != nil { + return nil, errors.WithStack(err) + } + + return &TradesResponse{ + Trades: trades, + BaseResponse: BaseResponse{ + Type: wr.Type, + Symbol: wr.Market, + }, + }, nil +} + +func (wr *WsResponse) MapToTickerResponse() (*TickerResponse, error) { + ticker := Ticker{} + err := json.Unmarshal(wr.Data, &ticker) + if err != nil { + return nil, errors.WithStack(err) + } + + return &TickerResponse{ + Ticker: ticker, + BaseResponse: BaseResponse{ + Type: wr.Type, + Symbol: wr.Market, + }, + }, nil +} + +func (wr *WsResponse) MapToOrderBookResponse() (*OrderBookResponse, error) { + book := OrderBook{} + err := json.Unmarshal(wr.Data, &book) + if err != nil { + return nil, errors.WithStack(err) + } + + return &OrderBookResponse{ + OrderBook: book, + BaseResponse: BaseResponse{ + Type: wr.Type, + Symbol: wr.Market, + }, + }, nil +} + +func (wr *WsResponse) MapToFillResponse() (*FillResponse, error) { + fill := Fill{} + err := json.Unmarshal(wr.Data, &fill) + if err != nil { + return nil, errors.WithStack(err) + } + + return &FillResponse{ + Fill: fill, + BaseResponse: BaseResponse{ + Type: wr.Type, + Symbol: wr.Market, + }, + }, nil +} + +func (wr *WsResponse) MapToOrderResponse() (*OrderResponse, error) { + order := Order{} + err := json.Unmarshal(wr.Data, &order) + if err != nil { + return nil, errors.WithStack(err) + } + + return &OrderResponse{ + Order: order, + BaseResponse: BaseResponse{ + Type: wr.Type, + Symbol: wr.Market, + }, + }, nil +} diff --git a/vendor/github.com/grishinsana/goftx/orders.go b/vendor/github.com/grishinsana/goftx/orders.go new file mode 100644 index 0000000..c55fe54 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/orders.go @@ -0,0 +1,273 @@ +package goftx + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/grishinsana/goftx/models" + "github.com/pkg/errors" +) + +const ( + apiGetOpenOrders = 
"/orders" + apiGetOrderStatus = "/orders/%d" + apiGetOrdersHistory = "/orders/history" + apiGetTriggerOrders = "/conditional_orders" + apiGetOrderTriggers = "/conditional_orders/%d/triggers" + apiPlaceTriggerOrder = "/conditional_orders" + apiPlaceOrder = "/orders" + apiCancelOrders = "/orders" +) + +type Orders struct { + client *Client +} + +func (o *Orders) GetOpenOrders(market string) ([]*models.Order, error) { + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetOpenOrders), + Params: map[string]string{ + "market": market, + }, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Order + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) GetOrderStatus(orderID int64) (*models.Order, error) { + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, fmt.Sprintf(apiGetOrderStatus, orderID)), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result *models.Order + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) GetOrdersHistory(params *models.GetOrdersHistoryParams) ([]*models.Order, error) { + queryParams, err := PrepareQueryParams(params) + if err != nil { + return nil, errors.WithStack(err) + } + + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetOrdersHistory), + Params: queryParams, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Order + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) GetOpenTriggerOrders(params *models.GetOpenTriggerOrdersParams) ([]*models.TriggerOrder, error) { + queryParams, err := PrepareQueryParams(params) + if err != nil { + return nil, errors.WithStack(err) + } + + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetTriggerOrders), + Params: queryParams, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.TriggerOrder + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) GetOrderTriggers(orderID int64) ([]*models.Trigger, error) { + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, fmt.Sprintf(apiGetOrderTriggers, orderID)), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Trigger + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) PlaceOrder(orderParams 
models.PlaceOrderParams) (*models.Order, error) { + body, err := json.Marshal(orderParams) + if err != nil { + return nil, errors.WithStack(err) + } + + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiPlaceOrder), + Body: body, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result *models.Order + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) PlaceTriggerOrder(orderParams interface{}) (*models.TriggerOrder, error) { + body, err := json.Marshal(orderParams) + if err != nil { + return nil, errors.WithStack(err) + } + + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiPlaceTriggerOrder), + Body: body, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := o.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result *models.TriggerOrder + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (o *Orders) CancelAllOrders(market string) error { + return o.cancelOrders(struct { + Market string `json:"market"` + }{ + Market: market, + }) +} + +func (o *Orders) CancelAllLimitOrders(market string) error { + return o.cancelOrders(struct { + Market string `json:"market"` + LimitOrdersOnly bool `json:"limitOrdersOnly"` + }{ + Market: market, + LimitOrdersOnly: true, + }) +} + +func (o *Orders) CancelAllConditionalOrders(market string) error { + return o.cancelOrders(struct { + Market string `json:"market"` + ConditionalOrdersOnly bool `json:"conditionalOrdersOnly"` + }{ + Market: market, + ConditionalOrdersOnly: true, + }) +} + +func (o *Orders) cancelOrders(req interface{}) error { + body, err := json.Marshal(req) + + if err != nil { + return errors.WithStack(err) + } + + request, err := o.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodDelete, + URL: fmt.Sprintf("%s%s", apiUrl, apiCancelOrders), + Body: body, + }) + if err != nil { + return errors.WithStack(err) + } + + _, err = o.client.do(request) + if err != nil { + return errors.WithStack(err) + } + + return nil +} diff --git a/vendor/github.com/grishinsana/goftx/spotmargin.go b/vendor/github.com/grishinsana/goftx/spotmargin.go new file mode 100644 index 0000000..a512ecc --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/spotmargin.go @@ -0,0 +1,102 @@ +package goftx + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" + "github.com/shopspring/decimal" + + "github.com/grishinsana/goftx/models" +) + +const ( + apiGetLendingInfo = "/spot_margin/lending_info" + apiGetLendingRates = "/spot_margin/lending_rates" + apiSubmitLendingOffer = "/spot_margin/offers" +) + +type SpotMargin struct { + client *Client +} + +func (m *SpotMargin) GetLendingInfo() ([]*models.LendingInfo, error) { + request, err := m.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetLendingInfo), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.LendingInfo + err = json.Unmarshal(response, &result) + if err != nil { + return nil, 
errors.WithStack(err) + } + + return result, nil +} + +func (m *SpotMargin) GetLendingRates() ([]*models.LendingRate, error) { + request, err := m.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiGetLendingRates), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := m.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.LendingRate + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (m *SpotMargin) SubmitLendingOffer(coin string, size decimal.Decimal, rate decimal.Decimal) error { + body, err := json.Marshal(struct { + Coin string `json:"coin"` + Size decimal.Decimal `json:"size"` + Rate decimal.Decimal `json:"rate"` + }{ + Coin: coin, + Size: size, + Rate: rate, + }) + if err != nil { + return errors.WithStack(err) + } + + request, err := m.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiSubmitLendingOffer), + Body: body, + }) + if err != nil { + return errors.WithStack(err) + } + + _, err = m.client.do(request) + if err != nil { + return errors.WithStack(err) + } + + return nil +} diff --git a/vendor/github.com/grishinsana/goftx/subaccounts.go b/vendor/github.com/grishinsana/goftx/subaccounts.go new file mode 100644 index 0000000..0c0c761 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/subaccounts.go @@ -0,0 +1,183 @@ +package goftx + +import ( + "encoding/json" + "fmt" + "github.com/grishinsana/goftx/models" + "github.com/pkg/errors" + "net/http" +) + +const ( + apiSubaccounts = "/subaccounts" + apiChangeSubaccountName = "/subaccounts/update_name" + apiGetSubaccountBalances = "/subaccounts/%s/balances" + apiTransfer = "/subaccounts/transfer" +) + +type SubAccounts struct { + client *Client +} + +func (s *SubAccounts) GetSubaccounts() ([]*models.SubAccount, error) { + request, err := s.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, apiSubaccounts), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := s.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.SubAccount + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (s *SubAccounts) CreateSubaccount(nickname string) (*models.SubAccount, error) { + body, err := json.Marshal(struct { + Nickname string `json:"nickname"` + }{Nickname: nickname}) + if err != nil { + return nil, errors.WithStack(err) + } + + request, err := s.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiSubaccounts), + Body: body, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := s.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result models.SubAccount + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return &result, nil +} + +func (s *SubAccounts) ChangeSubaccount(nickname, newNickname string) error { + body, err := json.Marshal(struct { + Nickname string `json:"nickname"` + NewNickname string `json:"newNickname"` + }{Nickname: nickname, NewNickname: newNickname}) + if err != nil { + return errors.WithStack(err) + } + + request, err := s.client.prepareRequest(Request{ + Auth: true, + 
Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiChangeSubaccountName), + Body: body, + }) + if err != nil { + return errors.WithStack(err) + } + + _, err = s.client.do(request) + if err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (s *SubAccounts) DeleteSubaccount(nickname string) error { + body, err := json.Marshal(struct { + Nickname string `json:"nickname"` + }{Nickname: nickname}) + if err != nil { + return errors.WithStack(err) + } + + request, err := s.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodDelete, + URL: fmt.Sprintf("%s%s", apiUrl, apiSubaccounts), + Body: body, + }) + if err != nil { + return errors.WithStack(err) + } + + _, err = s.client.do(request) + if err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (s *SubAccounts) GetSubaccountBalances(nickname string) ([]*models.Balance, error) { + request, err := s.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", apiUrl, fmt.Sprintf(apiGetSubaccountBalances, nickname)), + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := s.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result []*models.Balance + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (s *SubAccounts) Transfer(payload *models.TransferPayload) (*models.TransferResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return nil, errors.WithStack(err) + } + + request, err := s.client.prepareRequest(Request{ + Auth: true, + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", apiUrl, apiTransfer), + Body: body, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + response, err := s.client.do(request) + if err != nil { + return nil, errors.WithStack(err) + } + + var result models.TransferResponse + err = json.Unmarshal(response, &result) + if err != nil { + return nil, errors.WithStack(err) + } + + return &result, nil +} diff --git a/vendor/github.com/grishinsana/goftx/utils.go b/vendor/github.com/grishinsana/goftx/utils.go new file mode 100644 index 0000000..4e51fcc --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/utils.go @@ -0,0 +1,38 @@ +package goftx + +import ( + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +func PrepareQueryParams(params interface{}) (map[string]string, error) { + result := make(map[string]string) + + val := reflect.ValueOf(params).Elem() + if val.Kind() != reflect.Struct { + return result, nil + } + + for i := 0; i < val.NumField(); i++ { + valueField := val.Field(i) + typeField := val.Type().Field(i) + tag := typeField.Tag.Get("json") + + switch valueField.Kind() { + case reflect.Ptr: + if valueField.IsNil() { + continue + } + result[tag] = fmt.Sprintf("%v", valueField.Elem().Interface()) + default: + if valueField.IsZero() { + return result, errors.Errorf("required field: %v", tag) + } + result[tag] = fmt.Sprintf("%v", valueField.Interface()) + } + } + + return result, nil +} diff --git a/vendor/github.com/grishinsana/goftx/websocket.go b/vendor/github.com/grishinsana/goftx/websocket.go new file mode 100644 index 0000000..df313a1 --- /dev/null +++ b/vendor/github.com/grishinsana/goftx/websocket.go @@ -0,0 +1,479 @@ +package goftx + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" + + 
"github.com/grishinsana/goftx/models" +) + +const ( + wsUrl = "wss://ftx.com/ws/" + + writeWait = time.Second * 10 + reconnectCount = int(10) + reconnectInterval = time.Second + streamTimeout = time.Second * 60 +) + +type Stream struct { + apiKey string + secret string + subAccount string + mu *sync.Mutex + url string + dialer *websocket.Dialer + wsReconnectionCount int + wsReconnectionInterval time.Duration + wsTimeout time.Duration + isDebugMode bool +} + +func (s *Stream) SetStreamTimeout(timeout time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + + s.wsTimeout = timeout +} + +func (s *Stream) SetReconnectionCount(count int) { + s.mu.Lock() + defer s.mu.Unlock() + + s.wsReconnectionCount = count +} + +func (s *Stream) SetDebugMode(isDebugMode bool) { + s.mu.Lock() + defer s.mu.Unlock() + + s.isDebugMode = isDebugMode +} + +func (s *Stream) SetReconnectionInterval(interval time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + + s.wsReconnectionInterval = interval +} + +func (s *Stream) printf(format string, v ...interface{}) { + if !s.isDebugMode { + return + } + log.Printf(format+"\n", v) +} + +func (s *Stream) connect(requests ...models.WSRequest) (*websocket.Conn, error) { + conn, _, err := s.dialer.Dial(s.url, nil) + if err != nil { + return nil, errors.WithStack(err) + } + + err = s.auth(conn) + + if err != nil { + return nil, errors.WithStack(err) + } + + s.printf("connected to %v", s.url) + + err = s.subscribe(conn, requests) + if err != nil { + return nil, errors.WithStack(err) + } + + conn.SetPongHandler(func(msg string) error { + s.printf("%s", "PONG") + conn.SetReadDeadline(time.Now().Add(s.wsTimeout)) + return nil + }) + + return conn, nil +} + +func (s *Stream) serve(ctx context.Context, requests ...models.WSRequest) (chan interface{}, error) { + conn, err := s.connect(requests...) 
+ if err != nil { + return nil, errors.WithStack(err) + } + + doneC := make(chan struct{}) + eventsC := make(chan interface{}, 1) + + go func() { + go func() { + defer close(doneC) + + for { + message := &models.WsResponse{} + err = conn.ReadJSON(&message) + if err != nil { + s.printf("read msg: %v", err) + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + return + } + conn, err = s.reconnect(ctx, requests) + if err != nil { + s.printf("reconnect: %+v", err) + return + } + continue + } + + switch message.Type { + case models.Subscribed, models.UnSubscribed: + continue + } + + var response interface{} + switch message.Channel { + case models.TickerChannel: + response, err = message.MapToTickerResponse() + case models.TradesChannel: + response, err = message.MapToTradesResponse() + case models.OrderBookChannel: + response, err = message.MapToOrderBookResponse() + case models.OrdersChannel: + response, err = message.MapToOrderResponse() + case models.FillsChannel: + response, err = message.MapToFillResponse() + case models.MarketsChannel: + response = message.Data + } + + eventsC <- response + } + }() + + for { + select { + case <-ctx.Done(): + err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + s.printf("write close msg: %v", err) + return + } + select { + case <-doneC: + return + case <-time.After(time.Second): + return + } + case <-doneC: + return + case <-time.After((s.wsTimeout * 9) / 10): + s.printf("%s", "PING") + conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + s.printf("write ping: %v", err) + } + } + } + }() + + return eventsC, nil +} + +// Credit to https://github.com/go-numb/go-ftx +func (s *Stream) auth(conn *websocket.Conn) error { + if s.apiKey == "" { + return nil + } + + s.printf("%s", "Authenticate websocket connection") + msec := time.Now().UTC().UnixNano() / int64(time.Millisecond) + + mac := hmac.New(sha256.New, []byte(s.secret)) + mac.Write([]byte(fmt.Sprintf("%dwebsocket_login", msec))) + args := map[string]interface{}{ + "key": s.apiKey, + "sign": hex.EncodeToString(mac.Sum(nil)), + "time": msec, + } + if s.subAccount != "" { + args["subaccount"] = s.subAccount + } + + return conn.WriteJSON(models.WSRequest{ + Op: models.Login, + Args: args, + }) +} + +func (s *Stream) reconnect(ctx context.Context, requests []models.WSRequest) (*websocket.Conn, error) { + for i := 1; i < s.wsReconnectionCount; i++ { + conn, err := s.connect(requests...) + if err == nil { + return conn, nil + } + + select { + case <-time.After(s.wsReconnectionInterval): + conn, err := s.connect(requests...) 
+ if err != nil { + continue + } + + return conn, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + return nil, errors.New("reconnection failed") +} + +func (s *Stream) subscribe(conn *websocket.Conn, requests []models.WSRequest) error { + for _, req := range requests { + err := conn.WriteJSON(req) + if err != nil { + return errors.WithStack(err) + } + } + return nil +} + +func (s *Stream) SubscribeToFills(ctx context.Context) (chan *models.FillResponse, error) { + eventsC, err := s.serve(ctx, models.WSRequest{ + Channel: models.FillsChannel, + Op: models.Subscribe, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + fillsC := make(chan *models.FillResponse, 1) + go func() { + defer close(fillsC) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventsC: + if !ok { + return + } + fill, ok := event.(*models.FillResponse) + if !ok { + return + } + fillsC <- fill + } + } + }() + + return fillsC, nil +} + +func (s *Stream) SubscribeToOrders(ctx context.Context) (chan *models.OrderResponse, error) { + eventsC, err := s.serve(ctx, models.WSRequest{ + Channel: models.OrdersChannel, + Op: models.Subscribe, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + ordersC := make(chan *models.OrderResponse, 1) + go func() { + defer close(ordersC) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventsC: + if !ok { + return + } + order, ok := event.(*models.OrderResponse) + if !ok { + return + } + ordersC <- order + } + } + }() + + return ordersC, nil +} + +func (s *Stream) SubscribeToTickers(ctx context.Context, symbols ...string) (chan *models.TickerResponse, error) { + if len(symbols) == 0 { + return nil, errors.New("symbols is missing") + } + + requests := make([]models.WSRequest, 0, len(symbols)) + for _, symbol := range symbols { + requests = append(requests, models.WSRequest{ + Channel: models.TickerChannel, + Market: symbol, + Op: models.Subscribe, + }) + } + + eventsC, err := s.serve(ctx, requests...) 
+ if err != nil { + return nil, errors.WithStack(err) + } + + tickersC := make(chan *models.TickerResponse, 1) + go func() { + defer close(tickersC) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventsC: + if !ok { + return + } + ticker, ok := event.(*models.TickerResponse) + if !ok { + return + } + tickersC <- ticker + } + } + }() + + return tickersC, nil +} + +func (s *Stream) SubscribeToMarkets(ctx context.Context) (chan *models.Market, error) { + eventsC, err := s.serve(ctx, models.WSRequest{ + Channel: models.MarketsChannel, + Op: models.Subscribe, + }) + if err != nil { + return nil, errors.WithStack(err) + } + + marketsC := make(chan *models.Market, 1) + go func() { + defer close(marketsC) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventsC: + if !ok { + return + } + data, ok := event.(json.RawMessage) + if !ok { + return + } + var markets struct { + Data map[string]*models.Market `json:"data"` + } + err = json.Unmarshal(data, &markets) + if err != nil { + s.printf("unmarshal markets: %+v", err) + return + } + for _, market := range markets.Data { + marketsC <- market + } + } + } + }() + + return marketsC, nil +} + +func (s *Stream) SubscribeToTrades(ctx context.Context, symbols ...string) (chan *models.TradeResponse, error) { + if len(symbols) == 0 { + return nil, errors.New("symbols is missing") + } + + requests := make([]models.WSRequest, 0, len(symbols)) + for _, symbol := range symbols { + requests = append(requests, models.WSRequest{ + Channel: models.TradesChannel, + Market: symbol, + Op: models.Subscribe, + }) + } + + eventsC, err := s.serve(ctx, requests...) + if err != nil { + return nil, errors.WithStack(err) + } + + tradesC := make(chan *models.TradeResponse, 1) + go func() { + defer close(tradesC) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventsC: + if !ok { + return + } + trades, ok := event.(*models.TradesResponse) + if !ok { + return + } + for _, trade := range trades.Trades { + tradesC <- &models.TradeResponse{ + Trade: trade, + BaseResponse: trades.BaseResponse, + } + } + } + } + }() + + return tradesC, nil +} + +func (s *Stream) SubscribeToOrderBooks(ctx context.Context, symbols ...string) (chan *models.OrderBookResponse, error) { + if len(symbols) == 0 { + return nil, errors.New("symbols is missing") + } + + requests := make([]models.WSRequest, 0, len(symbols)) + for _, symbol := range symbols { + requests = append(requests, models.WSRequest{ + Channel: models.OrderBookChannel, + Market: symbol, + Op: models.Subscribe, + }) + } + + eventsC, err := s.serve(ctx, requests...) 
+ if err != nil { + return nil, errors.WithStack(err) + } + + booksC := make(chan *models.OrderBookResponse, 1) + go func() { + defer close(booksC) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventsC: + book, ok := event.(*models.OrderBookResponse) + if !ok { + return + } + booksC <- book + } + } + }() + + return booksC, nil +} diff --git a/vendor/github.com/pkg/errors/.gitignore b/vendor/github.com/pkg/errors/.gitignore new file mode 100644 index 0000000..daf913b --- /dev/null +++ b/vendor/github.com/pkg/errors/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/pkg/errors/.travis.yml b/vendor/github.com/pkg/errors/.travis.yml new file mode 100644 index 0000000..9159de0 --- /dev/null +++ b/vendor/github.com/pkg/errors/.travis.yml @@ -0,0 +1,10 @@ +language: go +go_import_path: github.com/pkg/errors +go: + - 1.11.x + - 1.12.x + - 1.13.x + - tip + +script: + - make check diff --git a/vendor/github.com/pkg/errors/LICENSE b/vendor/github.com/pkg/errors/LICENSE new file mode 100644 index 0000000..835ba3e --- /dev/null +++ b/vendor/github.com/pkg/errors/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2015, Dave Cheney +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
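Before the pkg/errors support files that follow, the diff vendors the goftx websocket stream. As a rough usage sketch (not part of this commit), a consumer drains one of the subscription channels until the context is cancelled; obtaining the *goftx.Stream itself happens through the upstream client constructor, which is outside this excerpt, and "BTC/USD" is only an example market:

```go
package example

import (
	"context"
	"log"

	"github.com/grishinsana/goftx"
)

// watchTickers consumes ticker updates until ctx is cancelled; the channel
// returned by SubscribeToTickers is closed when the subscription ends.
func watchTickers(ctx context.Context, stream *goftx.Stream) error {
	tickersC, err := stream.SubscribeToTickers(ctx, "BTC/USD")
	if err != nil {
		return err
	}
	for ticker := range tickersC {
		log.Printf("%s bid=%s ask=%s", ticker.Symbol, ticker.Bid, ticker.Ask)
	}
	return nil
}
```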
diff --git a/vendor/github.com/pkg/errors/Makefile b/vendor/github.com/pkg/errors/Makefile new file mode 100644 index 0000000..ce9d7cd --- /dev/null +++ b/vendor/github.com/pkg/errors/Makefile @@ -0,0 +1,44 @@ +PKGS := github.com/pkg/errors +SRCDIRS := $(shell go list -f '{{.Dir}}' $(PKGS)) +GO := go + +check: test vet gofmt misspell unconvert staticcheck ineffassign unparam + +test: + $(GO) test $(PKGS) + +vet: | test + $(GO) vet $(PKGS) + +staticcheck: + $(GO) get honnef.co/go/tools/cmd/staticcheck + staticcheck -checks all $(PKGS) + +misspell: + $(GO) get github.com/client9/misspell/cmd/misspell + misspell \ + -locale GB \ + -error \ + *.md *.go + +unconvert: + $(GO) get github.com/mdempsky/unconvert + unconvert -v $(PKGS) + +ineffassign: + $(GO) get github.com/gordonklaus/ineffassign + find $(SRCDIRS) -name '*.go' | xargs ineffassign + +pedantic: check errcheck + +unparam: + $(GO) get mvdan.cc/unparam + unparam ./... + +errcheck: + $(GO) get github.com/kisielk/errcheck + errcheck $(PKGS) + +gofmt: + @echo Checking code is gofmted + @test -z "$(shell gofmt -s -l -d -e $(SRCDIRS) | tee /dev/stderr)" diff --git a/vendor/github.com/pkg/errors/README.md b/vendor/github.com/pkg/errors/README.md new file mode 100644 index 0000000..54dfdcb --- /dev/null +++ b/vendor/github.com/pkg/errors/README.md @@ -0,0 +1,59 @@ +# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors) [![Sourcegraph](https://sourcegraph.com/github.com/pkg/errors/-/badge.svg)](https://sourcegraph.com/github.com/pkg/errors?badge) + +Package errors provides simple error handling primitives. + +`go get github.com/pkg/errors` + +The traditional error handling idiom in Go is roughly akin to +```go +if err != nil { + return err +} +``` +which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error. + +## Adding context to an error + +The errors.Wrap function returns a new error that adds context to the original error. For example +```go +_, err := ioutil.ReadAll(r) +if err != nil { + return errors.Wrap(err, "read failed") +} +``` +## Retrieving the cause of an error + +Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`. +```go +type causer interface { + Cause() error +} +``` +`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example: +```go +switch err := errors.Cause(err).(type) { +case *MyError: + // handle specifically +default: + // unknown error +} +``` + +[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors). 
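The wrapping pattern described above is exactly how the vendored goftx code reports failures: every call site returns errors.WithStack(err). A small illustration of what that buys a caller, with readConfig as a made-up helper, is the %+v verb printing the recorded stack trace:

```go
package main

import (
	"fmt"

	"github.com/pkg/errors"
)

// readConfig is a hypothetical helper that fails and annotates the error
// with a stack trace, mirroring how the vendored goftx client reports errors.
func readConfig() error {
	return errors.WithStack(fmt.Errorf("config file missing"))
}

func main() {
	if err := readConfig(); err != nil {
		fmt.Printf("%+v\n", err) // message followed by the recorded stack trace
	}
}
```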
+ +## Roadmap + +With the upcoming [Go2 error proposals](https://go.googlesource.com/proposal/+/master/design/go2draft.md) this package is moving into maintenance mode. The roadmap for a 1.0 release is as follows: + +- 0.9. Remove pre Go 1.9 and Go 1.10 support, address outstanding pull requests (if possible) +- 1.0. Final release. + +## Contributing + +Because of the Go2 errors changes, this package is not accepting proposals for new functionality. With that said, we welcome pull requests, bug fixes and issue reports. + +Before sending a PR, please discuss your change by raising an issue. + +## License + +BSD-2-Clause diff --git a/vendor/github.com/pkg/errors/appveyor.yml b/vendor/github.com/pkg/errors/appveyor.yml new file mode 100644 index 0000000..a932ead --- /dev/null +++ b/vendor/github.com/pkg/errors/appveyor.yml @@ -0,0 +1,32 @@ +version: build-{build}.{branch} + +clone_folder: C:\gopath\src\github.com\pkg\errors +shallow_clone: true # for startup speed + +environment: + GOPATH: C:\gopath + +platform: + - x64 + +# http://www.appveyor.com/docs/installed-software +install: + # some helpful output for debugging builds + - go version + - go env + # pre-installed MinGW at C:\MinGW is 32bit only + # but MSYS2 at C:\msys64 has mingw64 + - set PATH=C:\msys64\mingw64\bin;%PATH% + - gcc --version + - g++ --version + +build_script: + - go install -v ./... + +test_script: + - set PATH=C:\gopath\bin;%PATH% + - go test -v ./... + +#artifacts: +# - path: '%GOPATH%\bin\*.exe' +deploy: off diff --git a/vendor/github.com/pkg/errors/errors.go b/vendor/github.com/pkg/errors/errors.go new file mode 100644 index 0000000..161aea2 --- /dev/null +++ b/vendor/github.com/pkg/errors/errors.go @@ -0,0 +1,288 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which when applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// together with the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required, the errors.WithStack and +// errors.WithMessage functions destructure errors.Wrap into its component +// operations: annotating an error with a stack trace and with a message, +// respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error that does not implement causer, which is assumed to be +// the original cause. 
For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// Although the causer interface is not exported by this package, it is +// considered a part of its stable public interface. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported: +// +// %s print the error. If the error has a Cause it will be +// printed recursively. +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. This information can be retrieved with the following interface: +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// The returned errors.StackTrace type is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d\n", f, f) +// } +// } +// +// Although the stackTracer interface is not exported by this package, it is +// considered a part of its stable public interface. +// +// See the documentation for Frame.Format for more details. +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +// Unwrap provides compatibility for Go 1.13 error chains. +func (w *withStack) Unwrap() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. 
+func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is called, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. +func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +// WithMessagef annotates err with the format specifier. +// If err is nil, WithMessagef returns nil. +func WithMessagef(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +// Unwrap provides compatibility for Go 1.13 error chains. +func (w *withMessage) Unwrap() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/vendor/github.com/pkg/errors/go113.go b/vendor/github.com/pkg/errors/go113.go new file mode 100644 index 0000000..be0d10d --- /dev/null +++ b/vendor/github.com/pkg/errors/go113.go @@ -0,0 +1,38 @@ +// +build go1.13 + +package errors + +import ( + stderrors "errors" +) + +// Is reports whether any error in err's chain matches target. +// +// The chain consists of err itself followed by the sequence of errors obtained by +// repeatedly calling Unwrap. +// +// An error is considered to match a target if it is equal to that target or if +// it implements a method Is(error) bool such that Is(target) returns true. +func Is(err, target error) bool { return stderrors.Is(err, target) } + +// As finds the first error in err's chain that matches target, and if so, sets +// target to that error value and returns true. +// +// The chain consists of err itself followed by the sequence of errors obtained by +// repeatedly calling Unwrap. +// +// An error matches target if the error's concrete value is assignable to the value +// pointed to by target, or if the error has a method As(interface{}) bool such that +// As(target) returns true. In the latter case, the As method is responsible for +// setting target. 
+// +// As will panic if target is not a non-nil pointer to either a type that implements +// error, or to any interface type. As returns false if err is nil. +func As(err error, target interface{}) bool { return stderrors.As(err, target) } + +// Unwrap returns the result of calling the Unwrap method on err, if err's +// type contains an Unwrap method returning error. +// Otherwise, Unwrap returns nil. +func Unwrap(err error) error { + return stderrors.Unwrap(err) +} diff --git a/vendor/github.com/pkg/errors/stack.go b/vendor/github.com/pkg/errors/stack.go new file mode 100644 index 0000000..779a834 --- /dev/null +++ b/vendor/github.com/pkg/errors/stack.go @@ -0,0 +1,177 @@ +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strconv" + "strings" +) + +// Frame represents a program counter inside a stack frame. +// For historical reasons if Frame is interpreted as a uintptr +// its value represents the program counter + 1. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// name returns the name of this function, if known. +func (f Frame) name() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + return fn.Name() +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s function name and path of source file relative to the compile time +// GOPATH separated by \n\t (\n\t) +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + io.WriteString(s, f.name()) + io.WriteString(s, "\n\t") + io.WriteString(s, f.file()) + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + io.WriteString(s, strconv.Itoa(f.line())) + case 'n': + io.WriteString(s, funcname(f.name())) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// MarshalText formats a stacktrace Frame as a text string. The output is the +// same as that of fmt.Sprintf("%+v", f), but without newlines or tabs. +func (f Frame) MarshalText() ([]byte, error) { + name := f.name() + if name == "unknown" { + return []byte(name), nil + } + return []byte(fmt.Sprintf("%s %s:%d", name, f.file(), f.line())), nil +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format formats the stack of Frames according to the fmt.Formatter interface. +// +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+v Prints filename, function, and line number for each Frame in the stack. 
+func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + io.WriteString(s, "\n") + f.Format(s, verb) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + st.formatSlice(s, verb) + } + case 's': + st.formatSlice(s, verb) + } +} + +// formatSlice will format this StackTrace into the given buffer as a slice of +// Frame, only valid when called with '%s' or '%v'. +func (st StackTrace) formatSlice(s fmt.State, verb rune) { + io.WriteString(s, "[") + for i, f := range st { + if i > 0 { + io.WriteString(s, " ") + } + f.Format(s, verb) + } + io.WriteString(s, "]") +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} diff --git a/vendor/github.com/robfig/cron/v3/.gitignore b/vendor/github.com/robfig/cron/v3/.gitignore new file mode 100644 index 0000000..0026861 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/vendor/github.com/robfig/cron/v3/.travis.yml b/vendor/github.com/robfig/cron/v3/.travis.yml new file mode 100644 index 0000000..4f2ee4d --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/.travis.yml @@ -0,0 +1 @@ +language: go diff --git a/vendor/github.com/robfig/cron/v3/LICENSE b/vendor/github.com/robfig/cron/v3/LICENSE new file mode 100644 index 0000000..3a0f627 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/LICENSE @@ -0,0 +1,21 @@ +Copyright (C) 2012 Rob Figueiredo +All Rights Reserved. + +MIT LICENSE + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/robfig/cron/v3/README.md b/vendor/github.com/robfig/cron/v3/README.md new file mode 100644 index 0000000..8db4f55 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/README.md @@ -0,0 +1,125 @@ +[![GoDoc](http://godoc.org/github.com/robfig/cron?status.png)](http://godoc.org/github.com/robfig/cron) +[![Build Status](https://travis-ci.org/robfig/cron.svg?branch=master)](https://travis-ci.org/robfig/cron) + +# cron + +Cron V3 has been released! + +To download the specific tagged release, run: + + go get github.com/robfig/cron/v3@v3.0.0 + +Import it in your program as: + + import "github.com/robfig/cron/v3" + +It requires Go 1.11 or later due to usage of Go Modules. + +Refer to the documentation here: +http://godoc.org/github.com/robfig/cron + +The rest of this document describes the the advances in v3 and a list of +breaking changes for users that wish to upgrade from an earlier version. + +## Upgrading to v3 (June 2019) + +cron v3 is a major upgrade to the library that addresses all outstanding bugs, +feature requests, and rough edges. It is based on a merge of master which +contains various fixes to issues found over the years and the v2 branch which +contains some backwards-incompatible features like the ability to remove cron +jobs. In addition, v3 adds support for Go Modules, cleans up rough edges like +the timezone support, and fixes a number of bugs. + +New features: + +- Support for Go modules. Callers must now import this library as + `github.com/robfig/cron/v3`, instead of `gopkg.in/...` + +- Fixed bugs: + - 0f01e6b parser: fix combining of Dow and Dom (#70) + - dbf3220 adjust times when rolling the clock forward to handle non-existent midnight (#157) + - eeecf15 spec_test.go: ensure an error is returned on 0 increment (#144) + - 70971dc cron.Entries(): update request for snapshot to include a reply channel (#97) + - 1cba5e6 cron: fix: removing a job causes the next scheduled job to run too late (#206) + +- Standard cron spec parsing by default (first field is "minute"), with an easy + way to opt into the seconds field (quartz-compatible). Although, note that the + year field (optional in Quartz) is not supported. + +- Extensible, key/value logging via an interface that complies with + the https://github.com/go-logr/logr project. + +- The new Chain & JobWrapper types allow you to install "interceptors" to add + cross-cutting behavior like the following: + - Recover any panics from jobs + - Delay a job's execution if the previous run hasn't completed yet + - Skip a job's execution if the previous run hasn't completed yet + - Log each job's invocations + - Notification when jobs are completed + +It is backwards incompatible with both v1 and v2. These updates are required: + +- The v1 branch accepted an optional seconds field at the beginning of the cron + spec. This is non-standard and has led to a lot of confusion. The new default + parser conforms to the standard as described by [the Cron wikipedia page]. 
+ + UPDATING: To retain the old behavior, construct your Cron with a custom + parser: + + // Seconds field, required + cron.New(cron.WithSeconds()) + + // Seconds field, optional + cron.New( + cron.WithParser( + cron.SecondOptional | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)) + +- The Cron type now accepts functional options on construction rather than the + previous ad-hoc behavior modification mechanisms (setting a field, calling a setter). + + UPDATING: Code that sets Cron.ErrorLogger or calls Cron.SetLocation must be + updated to provide those values on construction. + +- CRON_TZ is now the recommended way to specify the timezone of a single + schedule, which is sanctioned by the specification. The legacy "TZ=" prefix + will continue to be supported since it is unambiguous and easy to do so. + + UPDATING: No update is required. + +- By default, cron will no longer recover panics in jobs that it runs. + Recovering can be surprising (see issue #192) and seems to be at odds with + typical behavior of libraries. Relatedly, the `cron.WithPanicLogger` option + has been removed to accommodate the more general JobWrapper type. + + UPDATING: To opt into panic recovery and configure the panic logger: + + cron.New(cron.WithChain( + cron.Recover(logger), // or use cron.DefaultLogger + )) + +- In adding support for https://github.com/go-logr/logr, `cron.WithVerboseLogger` was + removed, since it is duplicative with the leveled logging. + + UPDATING: Callers should use `WithLogger` and specify a logger that does not + discard `Info` logs. For convenience, one is provided that wraps `*log.Logger`: + + cron.New( + cron.WithLogger(cron.VerbosePrintfLogger(logger))) + + +### Background - Cron spec format + +There are two cron spec formats in common usage: + +- The "standard" cron format, described on [the Cron wikipedia page] and used by + the cron Linux system utility. + +- The cron format used by [the Quartz Scheduler], commonly used for scheduled + jobs in Java software + +[the Cron wikipedia page]: https://en.wikipedia.org/wiki/Cron +[the Quartz Scheduler]: http://www.quartz-scheduler.org/documentation/quartz-2.x/tutorials/crontrigger.html + +The original version of this package included an optional "seconds" field, which +made it incompatible with both of these formats. Now, the "standard" format is +the default format accepted, and the Quartz format is opt-in. diff --git a/vendor/github.com/robfig/cron/v3/chain.go b/vendor/github.com/robfig/cron/v3/chain.go new file mode 100644 index 0000000..118e5bb --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/chain.go @@ -0,0 +1,92 @@ +package cron + +import ( + "fmt" + "runtime" + "sync" + "time" +) + +// JobWrapper decorates the given Job with some behavior. +type JobWrapper func(Job) Job + +// Chain is a sequence of JobWrappers that decorates submitted jobs with +// cross-cutting behaviors like logging or synchronization. +type Chain struct { + wrappers []JobWrapper +} + +// NewChain returns a Chain consisting of the given JobWrappers. +func NewChain(c ...JobWrapper) Chain { + return Chain{c} +} + +// Then decorates the given job with all JobWrappers in the chain. +// +// This: +// NewChain(m1, m2, m3).Then(job) +// is equivalent to: +// m1(m2(m3(job))) +func (c Chain) Then(j Job) Job { + for i := range c.wrappers { + j = c.wrappers[len(c.wrappers)-i-1](j) + } + return j +} + +// Recover panics in wrapped jobs and log them with the provided logger. 
+func Recover(logger Logger) JobWrapper { + return func(j Job) Job { + return FuncJob(func() { + defer func() { + if r := recover(); r != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + err, ok := r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + logger.Error(err, "panic", "stack", "...\n"+string(buf)) + } + }() + j.Run() + }) + } +} + +// DelayIfStillRunning serializes jobs, delaying subsequent runs until the +// previous one is complete. Jobs running after a delay of more than a minute +// have the delay logged at Info. +func DelayIfStillRunning(logger Logger) JobWrapper { + return func(j Job) Job { + var mu sync.Mutex + return FuncJob(func() { + start := time.Now() + mu.Lock() + defer mu.Unlock() + if dur := time.Since(start); dur > time.Minute { + logger.Info("delay", "duration", dur) + } + j.Run() + }) + } +} + +// SkipIfStillRunning skips an invocation of the Job if a previous invocation is +// still running. It logs skips to the given logger at Info level. +func SkipIfStillRunning(logger Logger) JobWrapper { + var ch = make(chan struct{}, 1) + ch <- struct{}{} + return func(j Job) Job { + return FuncJob(func() { + select { + case v := <-ch: + j.Run() + ch <- v + default: + logger.Info("skip") + } + }) + } +} diff --git a/vendor/github.com/robfig/cron/v3/constantdelay.go b/vendor/github.com/robfig/cron/v3/constantdelay.go new file mode 100644 index 0000000..cd6e7b1 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/constantdelay.go @@ -0,0 +1,27 @@ +package cron + +import "time" + +// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes". +// It does not support jobs more frequent than once a second. +type ConstantDelaySchedule struct { + Delay time.Duration +} + +// Every returns a crontab Schedule that activates once every duration. +// Delays of less than a second are not supported (will round up to 1 second). +// Any fields less than a Second are truncated. +func Every(duration time.Duration) ConstantDelaySchedule { + if duration < time.Second { + duration = time.Second + } + return ConstantDelaySchedule{ + Delay: duration - time.Duration(duration.Nanoseconds())%time.Second, + } +} + +// Next returns the next time this should be run. +// This rounds so that the next activation time will be on the second. +func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time { + return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond) +} diff --git a/vendor/github.com/robfig/cron/v3/cron.go b/vendor/github.com/robfig/cron/v3/cron.go new file mode 100644 index 0000000..f6e451d --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/cron.go @@ -0,0 +1,350 @@ +package cron + +import ( + "context" + "sort" + "sync" + "time" +) + +// Cron keeps track of any number of entries, invoking the associated func as +// specified by the schedule. It may be started, stopped, and the entries may +// be inspected while running. +type Cron struct { + entries []*Entry + chain Chain + stop chan struct{} + add chan *Entry + remove chan EntryID + snapshot chan chan []Entry + running bool + logger Logger + runningMu sync.Mutex + location *time.Location + parser Parser + nextID EntryID + jobWaiter sync.WaitGroup +} + +// Job is an interface for submitted cron jobs. +type Job interface { + Run() +} + +// Schedule describes a job's duty cycle. +type Schedule interface { + // Next returns the next activation time, later than the given time. 
+ // Next is invoked initially, and then each time the job is run. + Next(time.Time) time.Time +} + +// EntryID identifies an entry within a Cron instance +type EntryID int + +// Entry consists of a schedule and the func to execute on that schedule. +type Entry struct { + // ID is the cron-assigned ID of this entry, which may be used to look up a + // snapshot or remove it. + ID EntryID + + // Schedule on which this job should be run. + Schedule Schedule + + // Next time the job will run, or the zero time if Cron has not been + // started or this entry's schedule is unsatisfiable + Next time.Time + + // Prev is the last time this job was run, or the zero time if never. + Prev time.Time + + // WrappedJob is the thing to run when the Schedule is activated. + WrappedJob Job + + // Job is the thing that was submitted to cron. + // It is kept around so that user code that needs to get at the job later, + // e.g. via Entries() can do so. + Job Job +} + +// Valid returns true if this is not the zero entry. +func (e Entry) Valid() bool { return e.ID != 0 } + +// byTime is a wrapper for sorting the entry array by time +// (with zero time at the end). +type byTime []*Entry + +func (s byTime) Len() int { return len(s) } +func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byTime) Less(i, j int) bool { + // Two zero times should return false. + // Otherwise, zero is "greater" than any other time. + // (To sort it at the end of the list.) + if s[i].Next.IsZero() { + return false + } + if s[j].Next.IsZero() { + return true + } + return s[i].Next.Before(s[j].Next) +} + +// New returns a new Cron job runner, modified by the given options. +// +// Available Settings +// +// Time Zone +// Description: The time zone in which schedules are interpreted +// Default: time.Local +// +// Parser +// Description: Parser converts cron spec strings into cron.Schedules. +// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron +// +// Chain +// Description: Wrap submitted jobs to customize behavior. +// Default: A chain that recovers panics and logs them to stderr. +// +// See "cron.With*" to modify the default behavior. +func New(opts ...Option) *Cron { + c := &Cron{ + entries: nil, + chain: NewChain(), + add: make(chan *Entry), + stop: make(chan struct{}), + snapshot: make(chan chan []Entry), + remove: make(chan EntryID), + running: false, + runningMu: sync.Mutex{}, + logger: DefaultLogger, + location: time.Local, + parser: standardParser, + } + for _, opt := range opts { + opt(c) + } + return c +} + +// FuncJob is a wrapper that turns a func() into a cron.Job +type FuncJob func() + +func (f FuncJob) Run() { f() } + +// AddFunc adds a func to the Cron to be run on the given schedule. +// The spec is parsed using the time zone of this Cron instance as the default. +// An opaque ID is returned that can be used to later remove it. +func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) { + return c.AddJob(spec, FuncJob(cmd)) +} + +// AddJob adds a Job to the Cron to be run on the given schedule. +// The spec is parsed using the time zone of this Cron instance as the default. +// An opaque ID is returned that can be used to later remove it. +func (c *Cron) AddJob(spec string, cmd Job) (EntryID, error) { + schedule, err := c.parser.Parse(spec) + if err != nil { + return 0, err + } + return c.Schedule(schedule, cmd), nil +} + +// Schedule adds a Job to the Cron to be run on the given schedule. +// The job is wrapped with the configured Chain. 
+func (c *Cron) Schedule(schedule Schedule, cmd Job) EntryID { + c.runningMu.Lock() + defer c.runningMu.Unlock() + c.nextID++ + entry := &Entry{ + ID: c.nextID, + Schedule: schedule, + WrappedJob: c.chain.Then(cmd), + Job: cmd, + } + if !c.running { + c.entries = append(c.entries, entry) + } else { + c.add <- entry + } + return entry.ID +} + +// Entries returns a snapshot of the cron entries. +func (c *Cron) Entries() []Entry { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + replyChan := make(chan []Entry, 1) + c.snapshot <- replyChan + return <-replyChan + } + return c.entrySnapshot() +} + +// Location gets the time zone location +func (c *Cron) Location() *time.Location { + return c.location +} + +// Entry returns a snapshot of the given entry, or nil if it couldn't be found. +func (c *Cron) Entry(id EntryID) Entry { + for _, entry := range c.Entries() { + if id == entry.ID { + return entry + } + } + return Entry{} +} + +// Remove an entry from being run in the future. +func (c *Cron) Remove(id EntryID) { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + c.remove <- id + } else { + c.removeEntry(id) + } +} + +// Start the cron scheduler in its own goroutine, or no-op if already started. +func (c *Cron) Start() { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + return + } + c.running = true + go c.run() +} + +// Run the cron scheduler, or no-op if already running. +func (c *Cron) Run() { + c.runningMu.Lock() + if c.running { + c.runningMu.Unlock() + return + } + c.running = true + c.runningMu.Unlock() + c.run() +} + +// run the scheduler.. this is private just due to the need to synchronize +// access to the 'running' state variable. +func (c *Cron) run() { + c.logger.Info("start") + + // Figure out the next activation times for each entry. + now := c.now() + for _, entry := range c.entries { + entry.Next = entry.Schedule.Next(now) + c.logger.Info("schedule", "now", now, "entry", entry.ID, "next", entry.Next) + } + + for { + // Determine the next entry to run. + sort.Sort(byTime(c.entries)) + + var timer *time.Timer + if len(c.entries) == 0 || c.entries[0].Next.IsZero() { + // If there are no entries yet, just sleep - it still handles new entries + // and stop requests. + timer = time.NewTimer(100000 * time.Hour) + } else { + timer = time.NewTimer(c.entries[0].Next.Sub(now)) + } + + for { + select { + case now = <-timer.C: + now = now.In(c.location) + c.logger.Info("wake", "now", now) + + // Run every entry whose next time was less than now + for _, e := range c.entries { + if e.Next.After(now) || e.Next.IsZero() { + break + } + c.startJob(e.WrappedJob) + e.Prev = e.Next + e.Next = e.Schedule.Next(now) + c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next) + } + + case newEntry := <-c.add: + timer.Stop() + now = c.now() + newEntry.Next = newEntry.Schedule.Next(now) + c.entries = append(c.entries, newEntry) + c.logger.Info("added", "now", now, "entry", newEntry.ID, "next", newEntry.Next) + + case replyChan := <-c.snapshot: + replyChan <- c.entrySnapshot() + continue + + case <-c.stop: + timer.Stop() + c.logger.Info("stop") + return + + case id := <-c.remove: + timer.Stop() + now = c.now() + c.removeEntry(id) + c.logger.Info("removed", "entry", id) + } + + break + } + } +} + +// startJob runs the given job in a new goroutine. 
+func (c *Cron) startJob(j Job) { + c.jobWaiter.Add(1) + go func() { + defer c.jobWaiter.Done() + j.Run() + }() +} + +// now returns current time in c location +func (c *Cron) now() time.Time { + return time.Now().In(c.location) +} + +// Stop stops the cron scheduler if it is running; otherwise it does nothing. +// A context is returned so the caller can wait for running jobs to complete. +func (c *Cron) Stop() context.Context { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + c.stop <- struct{}{} + c.running = false + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + c.jobWaiter.Wait() + cancel() + }() + return ctx +} + +// entrySnapshot returns a copy of the current cron entry list. +func (c *Cron) entrySnapshot() []Entry { + var entries = make([]Entry, len(c.entries)) + for i, e := range c.entries { + entries[i] = *e + } + return entries +} + +func (c *Cron) removeEntry(id EntryID) { + var entries []*Entry + for _, e := range c.entries { + if e.ID != id { + entries = append(entries, e) + } + } + c.entries = entries +} diff --git a/vendor/github.com/robfig/cron/v3/doc.go b/vendor/github.com/robfig/cron/v3/doc.go new file mode 100644 index 0000000..ac6b4b0 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/doc.go @@ -0,0 +1,212 @@ +/* +Package cron implements a cron spec parser and job runner. + +Usage + +Callers may register Funcs to be invoked on a given schedule. Cron will run +them in their own goroutines. + + c := cron.New() + c.AddFunc("30 * * * *", func() { fmt.Println("Every hour on the half hour") }) + c.AddFunc("30 3-6,20-23 * * *", func() { fmt.Println(".. in the range 3-6am, 8-11pm") }) + c.AddFunc("CRON_TZ=Asia/Tokyo 30 04 * * * *", func() { fmt.Println("Runs at 04:30 Tokyo time every day") }) + c.AddFunc("@hourly", func() { fmt.Println("Every hour, starting an hour from now") }) + c.AddFunc("@every 1h30m", func() { fmt.Println("Every hour thirty, starting an hour thirty from now") }) + c.Start() + .. + // Funcs are invoked in their own goroutine, asynchronously. + ... + // Funcs may also be added to a running Cron + c.AddFunc("@daily", func() { fmt.Println("Every day") }) + .. + // Inspect the cron job entries' next and previous run times. + inspect(c.Entries()) + .. + c.Stop() // Stop the scheduler (does not stop any jobs already running). + +CRON Expression Format + +A cron expression represents a set of times, using 5 space-separated fields. + + Field name | Mandatory? | Allowed values | Allowed special characters + ---------- | ---------- | -------------- | -------------------------- + Minutes | Yes | 0-59 | * / , - + Hours | Yes | 0-23 | * / , - + Day of month | Yes | 1-31 | * / , - ? + Month | Yes | 1-12 or JAN-DEC | * / , - + Day of week | Yes | 0-6 or SUN-SAT | * / , - ? + +Month and Day-of-week field values are case insensitive. "SUN", "Sun", and +"sun" are equally accepted. + +The specific interpretation of the format is based on the Cron Wikipedia page: +https://en.wikipedia.org/wiki/Cron + +Alternative Formats + +Alternative Cron expression formats support other fields like seconds. You can +implement that by creating a custom Parser as follows. 
+ + cron.New( + cron.WithParser( + cron.SecondOptional | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)) + +The most popular alternative Cron expression format is Quartz: +http://www.quartz-scheduler.org/documentation/quartz-2.x/tutorials/crontrigger.html + +Special Characters + +Asterisk ( * ) + +The asterisk indicates that the cron expression will match for all values of the +field; e.g., using an asterisk in the 5th field (month) would indicate every +month. + +Slash ( / ) + +Slashes are used to describe increments of ranges. For example 3-59/15 in the +1st field (minutes) would indicate the 3rd minute of the hour and every 15 +minutes thereafter. The form "*\/..." is equivalent to the form "first-last/...", +that is, an increment over the largest possible range of the field. The form +"N/..." is accepted as meaning "N-MAX/...", that is, starting at N, use the +increment until the end of that specific range. It does not wrap around. + +Comma ( , ) + +Commas are used to separate items of a list. For example, using "MON,WED,FRI" in +the 5th field (day of week) would mean Mondays, Wednesdays and Fridays. + +Hyphen ( - ) + +Hyphens are used to define ranges. For example, 9-17 would indicate every +hour between 9am and 5pm inclusive. + +Question mark ( ? ) + +Question mark may be used instead of '*' for leaving either day-of-month or +day-of-week blank. + +Predefined schedules + +You may use one of several pre-defined schedules in place of a cron expression. + + Entry | Description | Equivalent To + ----- | ----------- | ------------- + @yearly (or @annually) | Run once a year, midnight, Jan. 1st | 0 0 1 1 * + @monthly | Run once a month, midnight, first of month | 0 0 1 * * + @weekly | Run once a week, midnight between Sat/Sun | 0 0 * * 0 + @daily (or @midnight) | Run once a day, midnight | 0 0 * * * + @hourly | Run once an hour, beginning of hour | 0 * * * * + +Intervals + +You may also schedule a job to execute at fixed intervals, starting at the time it's added +or cron is run. This is supported by formatting the cron spec like this: + + @every <duration> + +where "duration" is a string accepted by time.ParseDuration +(http://golang.org/pkg/time/#ParseDuration). + +For example, "@every 1h30m10s" would indicate a schedule that activates after +1 hour, 30 minutes, 10 seconds, and then every interval after that. + +Note: The interval does not take the job runtime into account. For example, +if a job takes 3 minutes to run, and it is scheduled to run every 5 minutes, +it will have only 2 minutes of idle time between each run. + +Time zones + +By default, all interpretation and scheduling is done in the machine's local +time zone (time.Local). You can specify a different time zone on construction: + + cron.New( + cron.WithLocation(time.UTC)) + +Individual cron schedules may also override the time zone they are to be +interpreted in by providing an additional space-separated field at the beginning +of the cron spec, of the form "CRON_TZ=Asia/Tokyo". + +For example: + + # Runs at 6am in time.Local + cron.New().AddFunc("0 6 * * ?", ...) + + # Runs at 6am in America/New_York + nyc, _ := time.LoadLocation("America/New_York") + c := cron.New(cron.WithLocation(nyc)) + c.AddFunc("0 6 * * ?", ...) + + # Runs at 6am in Asia/Tokyo + cron.New().AddFunc("CRON_TZ=Asia/Tokyo 0 6 * * ?", ...) + + # Runs at 6am in Asia/Tokyo + c := cron.New(cron.WithLocation(nyc)) + c.SetLocation("America/New_York") + c.AddFunc("CRON_TZ=Asia/Tokyo 0 6 * * ?", ...)
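Putting these pieces together, a minimal end-to-end sketch (an illustrative addition, not part of the upstream godoc; the spec and time zone are arbitrary) might look like:

    tokyo, err := time.LoadLocation("Asia/Tokyo")
    if err != nil {
        log.Fatal(err)
    }
    c := cron.New(
        cron.WithLocation(tokyo),                         // interpret specs in Tokyo time
        cron.WithChain(cron.Recover(cron.DefaultLogger)), // recover panics in jobs
    )
    c.AddFunc("5 * * * *", func() { fmt.Println("five past every hour") })
    c.Start()
    // ... later: stop scheduling and wait for any running jobs to finish.
    ctx := c.Stop()
    <-ctx.Done()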
+ +The prefix "TZ=(TIME ZONE)" is also supported for legacy compatibility. + +Be aware that jobs scheduled during daylight-savings leap-ahead transitions will +not be run! + +Job Wrappers / Chain + +A Cron runner may be configured with a chain of job wrappers to add +cross-cutting functionality to all submitted jobs. For example, they may be used +to achieve the following effects: + + - Recover any panics from jobs (activated by default) + - Delay a job's execution if the previous run hasn't completed yet + - Skip a job's execution if the previous run hasn't completed yet + - Log each job's invocations + +Install wrappers for all jobs added to a cron using the `cron.WithChain` option: + + cron.New(cron.WithChain( + cron.SkipIfStillRunning(logger), + )) + +Install wrappers for individual jobs by explicitly wrapping them: + + job = cron.NewChain( + cron.SkipIfStillRunning(logger), + ).Then(job) + +Thread safety + +Since the Cron service runs concurrently with the calling code, some amount of +care must be taken to ensure proper synchronization. + +All cron methods are designed to be correctly synchronized as long as the caller +ensures that invocations have a clear happens-before ordering between them. + +Logging + +Cron defines a Logger interface that is a subset of the one defined in +github.com/go-logr/logr. It has two logging levels (Info and Error), and +parameters are key/value pairs. This makes it possible for cron logging to plug +into structured logging systems. An adapter, [Verbose]PrintfLogger, is provided +to wrap the standard library *log.Logger. + +For additional insight into Cron operations, verbose logging may be activated +which will record job runs, scheduling decisions, and added or removed jobs. +Activate it with a one-off logger as follows: + + cron.New( + cron.WithLogger( + cron.VerbosePrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags)))) + + +Implementation + +Cron entries are stored in an array, sorted by their next activation time. Cron +sleeps until the next job is due to be run. + +Upon waking: + - it runs each entry that is active on that second + - it calculates the next run times for the jobs that were run + - it re-sorts the array of entries by next activation time. + - it goes to sleep until the soonest job. +*/ +package cron diff --git a/vendor/github.com/robfig/cron/v3/go.mod b/vendor/github.com/robfig/cron/v3/go.mod new file mode 100644 index 0000000..8c95bf4 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/go.mod @@ -0,0 +1,3 @@ +module github.com/robfig/cron/v3 + +go 1.12 diff --git a/vendor/github.com/robfig/cron/v3/logger.go b/vendor/github.com/robfig/cron/v3/logger.go new file mode 100644 index 0000000..b4efcc0 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/logger.go @@ -0,0 +1,86 @@ +package cron + +import ( + "io/ioutil" + "log" + "os" + "strings" + "time" +) + +// DefaultLogger is used by Cron if none is specified. +var DefaultLogger Logger = PrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags)) + +// DiscardLogger can be used by callers to discard all log messages. +var DiscardLogger Logger = PrintfLogger(log.New(ioutil.Discard, "", 0)) + +// Logger is the interface used in this package for logging, so that any backend +// can be plugged in. It is a subset of the github.com/go-logr/logr interface. +type Logger interface { + // Info logs routine messages about cron's operation. + Info(msg string, keysAndValues ...interface{}) + // Error logs an error condition. 
+ Error(err error, msg string, keysAndValues ...interface{}) +} + +// PrintfLogger wraps a Printf-based logger (such as the standard library "log") +// into an implementation of the Logger interface which logs errors only. +func PrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger { + return printfLogger{l, false} +} + +// VerbosePrintfLogger wraps a Printf-based logger (such as the standard library +// "log") into an implementation of the Logger interface which logs everything. +func VerbosePrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger { + return printfLogger{l, true} +} + +type printfLogger struct { + logger interface{ Printf(string, ...interface{}) } + logInfo bool +} + +func (pl printfLogger) Info(msg string, keysAndValues ...interface{}) { + if pl.logInfo { + keysAndValues = formatTimes(keysAndValues) + pl.logger.Printf( + formatString(len(keysAndValues)), + append([]interface{}{msg}, keysAndValues...)...) + } +} + +func (pl printfLogger) Error(err error, msg string, keysAndValues ...interface{}) { + keysAndValues = formatTimes(keysAndValues) + pl.logger.Printf( + formatString(len(keysAndValues)+2), + append([]interface{}{msg, "error", err}, keysAndValues...)...) +} + +// formatString returns a logfmt-like format string for the number of +// key/values. +func formatString(numKeysAndValues int) string { + var sb strings.Builder + sb.WriteString("%s") + if numKeysAndValues > 0 { + sb.WriteString(", ") + } + for i := 0; i < numKeysAndValues/2; i++ { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString("%v=%v") + } + return sb.String() +} + +// formatTimes formats any time.Time values as RFC3339. +func formatTimes(keysAndValues []interface{}) []interface{} { + var formattedArgs []interface{} + for _, arg := range keysAndValues { + if t, ok := arg.(time.Time); ok { + arg = t.Format(time.RFC3339) + } + formattedArgs = append(formattedArgs, arg) + } + return formattedArgs +} diff --git a/vendor/github.com/robfig/cron/v3/option.go b/vendor/github.com/robfig/cron/v3/option.go new file mode 100644 index 0000000..0763820 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/option.go @@ -0,0 +1,45 @@ +package cron + +import ( + "time" +) + +// Option represents a modification to the default behavior of a Cron. +type Option func(*Cron) + +// WithLocation overrides the timezone of the cron instance. +func WithLocation(loc *time.Location) Option { + return func(c *Cron) { + c.location = loc + } +} + +// WithSeconds overrides the parser used for interpreting job schedules to +// include a seconds field as the first one. +func WithSeconds() Option { + return WithParser(NewParser( + Second | Minute | Hour | Dom | Month | Dow | Descriptor, + )) +} + +// WithParser overrides the parser used for interpreting job schedules. +func WithParser(p Parser) Option { + return func(c *Cron) { + c.parser = p + } +} + +// WithChain specifies Job wrappers to apply to all jobs added to this cron. +// Refer to the Chain* functions in this package for provided wrappers. +func WithChain(wrappers ...JobWrapper) Option { + return func(c *Cron) { + c.chain = NewChain(wrappers...) + } +} + +// WithLogger uses the provided logger. 
+func WithLogger(logger Logger) Option { + return func(c *Cron) { + c.logger = logger + } +} diff --git a/vendor/github.com/robfig/cron/v3/parser.go b/vendor/github.com/robfig/cron/v3/parser.go new file mode 100644 index 0000000..3cf8879 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/parser.go @@ -0,0 +1,434 @@ +package cron + +import ( + "fmt" + "math" + "strconv" + "strings" + "time" +) + +// Configuration options for creating a parser. Most options specify which +// fields should be included, while others enable features. If a field is not +// included the parser will assume a default value. These options do not change +// the order fields are parse in. +type ParseOption int + +const ( + Second ParseOption = 1 << iota // Seconds field, default 0 + SecondOptional // Optional seconds field, default 0 + Minute // Minutes field, default 0 + Hour // Hours field, default 0 + Dom // Day of month field, default * + Month // Month field, default * + Dow // Day of week field, default * + DowOptional // Optional day of week field, default * + Descriptor // Allow descriptors such as @monthly, @weekly, etc. +) + +var places = []ParseOption{ + Second, + Minute, + Hour, + Dom, + Month, + Dow, +} + +var defaults = []string{ + "0", + "0", + "0", + "*", + "*", + "*", +} + +// A custom Parser that can be configured. +type Parser struct { + options ParseOption +} + +// NewParser creates a Parser with custom options. +// +// It panics if more than one Optional is given, since it would be impossible to +// correctly infer which optional is provided or missing in general. +// +// Examples +// +// // Standard parser without descriptors +// specParser := NewParser(Minute | Hour | Dom | Month | Dow) +// sched, err := specParser.Parse("0 0 15 */3 *") +// +// // Same as above, just excludes time fields +// subsParser := NewParser(Dom | Month | Dow) +// sched, err := specParser.Parse("15 */3 *") +// +// // Same as above, just makes Dow optional +// subsParser := NewParser(Dom | Month | DowOptional) +// sched, err := specParser.Parse("15 */3") +// +func NewParser(options ParseOption) Parser { + optionals := 0 + if options&DowOptional > 0 { + optionals++ + } + if options&SecondOptional > 0 { + optionals++ + } + if optionals > 1 { + panic("multiple optionals may not be configured") + } + return Parser{options} +} + +// Parse returns a new crontab schedule representing the given spec. +// It returns a descriptive error if the spec is not valid. +// It accepts crontab specs and features configured by NewParser. +func (p Parser) Parse(spec string) (Schedule, error) { + if len(spec) == 0 { + return nil, fmt.Errorf("empty spec string") + } + + // Extract timezone if present + var loc = time.Local + if strings.HasPrefix(spec, "TZ=") || strings.HasPrefix(spec, "CRON_TZ=") { + var err error + i := strings.Index(spec, " ") + eq := strings.Index(spec, "=") + if loc, err = time.LoadLocation(spec[eq+1 : i]); err != nil { + return nil, fmt.Errorf("provided bad location %s: %v", spec[eq+1:i], err) + } + spec = strings.TrimSpace(spec[i:]) + } + + // Handle named schedules (descriptors), if configured + if strings.HasPrefix(spec, "@") { + if p.options&Descriptor == 0 { + return nil, fmt.Errorf("parser does not accept descriptors: %v", spec) + } + return parseDescriptor(spec, loc) + } + + // Split on whitespace. 
+ fields := strings.Fields(spec) + + // Validate & fill in any omitted or optional fields + var err error + fields, err = normalizeFields(fields, p.options) + if err != nil { + return nil, err + } + + field := func(field string, r bounds) uint64 { + if err != nil { + return 0 + } + var bits uint64 + bits, err = getField(field, r) + return bits + } + + var ( + second = field(fields[0], seconds) + minute = field(fields[1], minutes) + hour = field(fields[2], hours) + dayofmonth = field(fields[3], dom) + month = field(fields[4], months) + dayofweek = field(fields[5], dow) + ) + if err != nil { + return nil, err + } + + return &SpecSchedule{ + Second: second, + Minute: minute, + Hour: hour, + Dom: dayofmonth, + Month: month, + Dow: dayofweek, + Location: loc, + }, nil +} + +// normalizeFields takes a subset set of the time fields and returns the full set +// with defaults (zeroes) populated for unset fields. +// +// As part of performing this function, it also validates that the provided +// fields are compatible with the configured options. +func normalizeFields(fields []string, options ParseOption) ([]string, error) { + // Validate optionals & add their field to options + optionals := 0 + if options&SecondOptional > 0 { + options |= Second + optionals++ + } + if options&DowOptional > 0 { + options |= Dow + optionals++ + } + if optionals > 1 { + return nil, fmt.Errorf("multiple optionals may not be configured") + } + + // Figure out how many fields we need + max := 0 + for _, place := range places { + if options&place > 0 { + max++ + } + } + min := max - optionals + + // Validate number of fields + if count := len(fields); count < min || count > max { + if min == max { + return nil, fmt.Errorf("expected exactly %d fields, found %d: %s", min, count, fields) + } + return nil, fmt.Errorf("expected %d to %d fields, found %d: %s", min, max, count, fields) + } + + // Populate the optional field if not provided + if min < max && len(fields) == min { + switch { + case options&DowOptional > 0: + fields = append(fields, defaults[5]) // TODO: improve access to default + case options&SecondOptional > 0: + fields = append([]string{defaults[0]}, fields...) + default: + return nil, fmt.Errorf("unknown optional field") + } + } + + // Populate all fields not part of options with their defaults + n := 0 + expandedFields := make([]string, len(places)) + copy(expandedFields, defaults) + for i, place := range places { + if options&place > 0 { + expandedFields[i] = fields[n] + n++ + } + } + return expandedFields, nil +} + +var standardParser = NewParser( + Minute | Hour | Dom | Month | Dow | Descriptor, +) + +// ParseStandard returns a new crontab schedule representing the given +// standardSpec (https://en.wikipedia.org/wiki/Cron). It requires 5 entries +// representing: minute, hour, day of month, month and day of week, in that +// order. It returns a descriptive error if the spec is not valid. +// +// It accepts +// - Standard crontab specs, e.g. "* * * * ?" +// - Descriptors, e.g. "@midnight", "@every 1h30m" +func ParseStandard(standardSpec string) (Schedule, error) { + return standardParser.Parse(standardSpec) +} + +// getField returns an Int with the bits set representing all of the times that +// the field represents or error parsing field value. A "field" is a comma-separated +// list of "ranges". 
+func getField(field string, r bounds) (uint64, error) { + var bits uint64 + ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) + for _, expr := range ranges { + bit, err := getRange(expr, r) + if err != nil { + return bits, err + } + bits |= bit + } + return bits, nil +} + +// getRange returns the bits indicated by the given expression: +// number | number "-" number [ "/" number ] +// or error parsing range. +func getRange(expr string, r bounds) (uint64, error) { + var ( + start, end, step uint + rangeAndStep = strings.Split(expr, "/") + lowAndHigh = strings.Split(rangeAndStep[0], "-") + singleDigit = len(lowAndHigh) == 1 + err error + ) + + var extra uint64 + if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" { + start = r.min + end = r.max + extra = starBit + } else { + start, err = parseIntOrName(lowAndHigh[0], r.names) + if err != nil { + return 0, err + } + switch len(lowAndHigh) { + case 1: + end = start + case 2: + end, err = parseIntOrName(lowAndHigh[1], r.names) + if err != nil { + return 0, err + } + default: + return 0, fmt.Errorf("too many hyphens: %s", expr) + } + } + + switch len(rangeAndStep) { + case 1: + step = 1 + case 2: + step, err = mustParseInt(rangeAndStep[1]) + if err != nil { + return 0, err + } + + // Special handling: "N/step" means "N-max/step". + if singleDigit { + end = r.max + } + if step > 1 { + extra = 0 + } + default: + return 0, fmt.Errorf("too many slashes: %s", expr) + } + + if start < r.min { + return 0, fmt.Errorf("beginning of range (%d) below minimum (%d): %s", start, r.min, expr) + } + if end > r.max { + return 0, fmt.Errorf("end of range (%d) above maximum (%d): %s", end, r.max, expr) + } + if start > end { + return 0, fmt.Errorf("beginning of range (%d) beyond end of range (%d): %s", start, end, expr) + } + if step == 0 { + return 0, fmt.Errorf("step of range should be a positive number: %s", expr) + } + + return getBits(start, end, step) | extra, nil +} + +// parseIntOrName returns the (possibly-named) integer contained in expr. +func parseIntOrName(expr string, names map[string]uint) (uint, error) { + if names != nil { + if namedInt, ok := names[strings.ToLower(expr)]; ok { + return namedInt, nil + } + } + return mustParseInt(expr) +} + +// mustParseInt parses the given expression as an int or returns an error. +func mustParseInt(expr string) (uint, error) { + num, err := strconv.Atoi(expr) + if err != nil { + return 0, fmt.Errorf("failed to parse int from %s: %s", expr, err) + } + if num < 0 { + return 0, fmt.Errorf("negative number (%d) not allowed: %s", num, expr) + } + + return uint(num), nil +} + +// getBits sets all bits in the range [min, max], modulo the given step size. +func getBits(min, max, step uint) uint64 { + var bits uint64 + + // If step is 1, use shifts. + if step == 1 { + return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) + } + + // Else, use a simple loop. + for i := min; i <= max; i += step { + bits |= 1 << i + } + return bits +} + +// all returns all bits within the given bounds. (plus the star bit) +func all(r bounds) uint64 { + return getBits(r.min, r.max, 1) | starBit +} + +// parseDescriptor returns a predefined schedule for the expression, or error if none matches. 
+func parseDescriptor(descriptor string, loc *time.Location) (Schedule, error) { + switch descriptor { + case "@yearly", "@annually": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: 1 << dom.min, + Month: 1 << months.min, + Dow: all(dow), + Location: loc, + }, nil + + case "@monthly": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: 1 << dom.min, + Month: all(months), + Dow: all(dow), + Location: loc, + }, nil + + case "@weekly": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: all(dom), + Month: all(months), + Dow: 1 << dow.min, + Location: loc, + }, nil + + case "@daily", "@midnight": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: all(dom), + Month: all(months), + Dow: all(dow), + Location: loc, + }, nil + + case "@hourly": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: all(hours), + Dom: all(dom), + Month: all(months), + Dow: all(dow), + Location: loc, + }, nil + + } + + const every = "@every " + if strings.HasPrefix(descriptor, every) { + duration, err := time.ParseDuration(descriptor[len(every):]) + if err != nil { + return nil, fmt.Errorf("failed to parse duration %s: %s", descriptor, err) + } + return Every(duration), nil + } + + return nil, fmt.Errorf("unrecognized descriptor: %s", descriptor) +} diff --git a/vendor/github.com/robfig/cron/v3/spec.go b/vendor/github.com/robfig/cron/v3/spec.go new file mode 100644 index 0000000..fa1e241 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/spec.go @@ -0,0 +1,188 @@ +package cron + +import "time" + +// SpecSchedule specifies a duty cycle (to the second granularity), based on a +// traditional crontab specification. It is computed initially and stored as bit sets. +type SpecSchedule struct { + Second, Minute, Hour, Dom, Month, Dow uint64 + + // Override location for this schedule. + Location *time.Location +} + +// bounds provides a range of acceptable values (plus a map of name to value). +type bounds struct { + min, max uint + names map[string]uint +} + +// The bounds for each field. +var ( + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + dom = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ + "jan": 1, + "feb": 2, + "mar": 3, + "apr": 4, + "may": 5, + "jun": 6, + "jul": 7, + "aug": 8, + "sep": 9, + "oct": 10, + "nov": 11, + "dec": 12, + }} + dow = bounds{0, 6, map[string]uint{ + "sun": 0, + "mon": 1, + "tue": 2, + "wed": 3, + "thu": 4, + "fri": 5, + "sat": 6, + }} +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Next returns the next time this schedule is activated, greater than the given +// time. If no time can be found to satisfy the schedule, return the zero time. +func (s *SpecSchedule) Next(t time.Time) time.Time { + // General approach + // + // For Month, Day, Hour, Minute, Second: + // Check if the time value matches. If yes, continue to the next field. + // If the field doesn't match the schedule, then increment the field until it matches. + // While incrementing the field, a wrap-around brings it back to the beginning + // of the field list (since it is necessary to re-verify previous field + // values) + + // Convert the given time into the schedule's timezone, if one is specified. 
+	// Save the original timezone so we can convert back after we find a time.
+	// Note that schedules without a time zone specified (time.Local) are treated
+	// as local to the time provided.
+	origLocation := t.Location()
+	loc := s.Location
+	if loc == time.Local {
+		loc = t.Location()
+	}
+	if s.Location != time.Local {
+		t = t.In(s.Location)
+	}
+
+	// Start at the earliest possible time (the upcoming second).
+	t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)
+
+	// This flag indicates whether a field has been incremented.
+	added := false
+
+	// If no time is found within five years, return zero.
+	yearLimit := t.Year() + 5
+
+WRAP:
+	if t.Year() > yearLimit {
+		return time.Time{}
+	}
+
+	// Find the first applicable month.
+	// If it's this month, then do nothing.
+	for 1<<uint(t.Month())&s.Month == 0 {
+		// If we have to add a month, reset the other parts to 0.
+		if !added {
+			added = true
+			// Otherwise, set the date at the beginning (since the current time is irrelevant).
+			t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
+		}
+		t = t.AddDate(0, 1, 0)
+
+		// Wrapped around.
+		if t.Month() == time.January {
+			goto WRAP
+		}
+	}
+
+	// Now get a day in that month.
+	//
+	// NOTE: This causes issues for daylight savings regimes where midnight does
+	// not exist. For example: Sao Paulo has DST that transforms midnight on
+	// 11/3 into 1am. Handle that by noticing when the Hour ends up != 0.
+	for !dayMatches(s, t) {
+		if !added {
+			added = true
+			t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
+		}
+		t = t.AddDate(0, 0, 1)
+		// Notice if the hour is no longer midnight due to DST.
+		// Add an hour if it's 23, subtract an hour if it's 1.
+		if t.Hour() != 0 {
+			if t.Hour() > 12 {
+				t = t.Add(time.Duration(24-t.Hour()) * time.Hour)
+			} else {
+				t = t.Add(time.Duration(-t.Hour()) * time.Hour)
+			}
+		}
+
+		if t.Day() == 1 {
+			goto WRAP
+		}
+	}
+
+	for 1<<uint(t.Hour())&s.Hour == 0 {
+		if !added {
+			added = true
+			t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, loc)
+		}
+		t = t.Add(1 * time.Hour)
+
+		if t.Hour() == 0 {
+			goto WRAP
+		}
+	}
+
+	for 1<<uint(t.Minute())&s.Minute == 0 {
+		if !added {
+			added = true
+			t = t.Truncate(time.Minute)
+		}
+		t = t.Add(1 * time.Minute)
+
+		if t.Minute() == 0 {
+			goto WRAP
+		}
+	}
+
+	for 1<<uint(t.Second())&s.Second == 0 {
+		if !added {
+			added = true
+			t = t.Truncate(time.Second)
+		}
+		t = t.Add(1 * time.Second)
+
+		if t.Second() == 0 {
+			goto WRAP
+		}
+	}
+
+	return t.In(origLocation)
+}
+
+// dayMatches returns true if the schedule's day-of-week and day-of-month
+// restrictions are satisfied by the given time.
+func dayMatches(s *SpecSchedule, t time.Time) bool {
+	var (
+		domMatch bool = 1<<uint(t.Day())&s.Dom > 0
+		dowMatch bool = 1<<uint(t.Weekday())&s.Dow > 0
+	)
+	if s.Dom&starBit > 0 || s.Dow&starBit > 0 {
+		return domMatch && dowMatch
+	}
+	return domMatch || dowMatch
+}
diff --git a/vendor/github.com/shopspring/decimal/.gitignore b/vendor/github.com/shopspring/decimal/.gitignore
new file mode 100644
index 0000000..8a43ce9
--- /dev/null
+++ b/vendor/github.com/shopspring/decimal/.gitignore
@@ -0,0 +1,6 @@
+.git
+*.swp
+
+# IntelliJ
+.idea/
+*.iml
diff --git a/vendor/github.com/shopspring/decimal/.travis.yml b/vendor/github.com/shopspring/decimal/.travis.yml
new file mode 100644
index 0000000..55d42b2
--- /dev/null
+++ b/vendor/github.com/shopspring/decimal/.travis.yml
@@ -0,0 +1,13 @@
+language: go
+
+go:
+  - 1.7.x
+  - 1.12.x
+  - 1.13.x
+  - tip
+
+install:
+  - go build .
+
+script:
+  - go test -v
diff --git a/vendor/github.com/shopspring/decimal/CHANGELOG.md b/vendor/github.com/shopspring/decimal/CHANGELOG.md
new file mode 100644
index 0000000..01ba02f
--- /dev/null
+++ b/vendor/github.com/shopspring/decimal/CHANGELOG.md
@@ -0,0 +1,19 @@
+## Decimal v1.2.0
+
+#### BREAKING
+- Drop support for Go version older than 1.7 [#172](https://github.com/shopspring/decimal/pull/172)
+
+#### FEATURES
+- Add NewFromInt and NewFromInt32 initializers [#72](https://github.com/shopspring/decimal/pull/72)
+- Add support for Go modules [#157](https://github.com/shopspring/decimal/pull/157)
+- Add BigInt, BigFloat helper methods [#171](https://github.com/shopspring/decimal/pull/171)
+
+#### ENHANCEMENTS
+- Memory usage optimization [#160](https://github.com/shopspring/decimal/pull/160)
+- Updated travis CI golang versions [#156](https://github.com/shopspring/decimal/pull/156)
+- Update documentation [#173](https://github.com/shopspring/decimal/pull/173)
+- Improve code quality [#174](https://github.com/shopspring/decimal/pull/174)
+
+#### BUGFIXES
+- Revert remove insignificant digits [#159](https://github.com/shopspring/decimal/pull/159)
+- Remove 15 interval for RoundCash [#166](https://github.com/shopspring/decimal/pull/166)
diff --git a/vendor/github.com/shopspring/decimal/LICENSE b/vendor/github.com/shopspring/decimal/LICENSE
new file mode 100644
index 0000000..ad2148a
--- /dev/null
+++ b/vendor/github.com/shopspring/decimal/LICENSE
@@ -0,0 +1,45 @@
+The MIT License (MIT)
+
+Copyright (c) 2015 Spring, Inc.
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +- Based on https://github.com/oguzbilgic/fpd, which has the following license: +""" +The MIT License (MIT) + +Copyright (c) 2013 Oguz Bilgic + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" diff --git a/vendor/github.com/shopspring/decimal/README.md b/vendor/github.com/shopspring/decimal/README.md new file mode 100644 index 0000000..b70f901 --- /dev/null +++ b/vendor/github.com/shopspring/decimal/README.md @@ -0,0 +1,130 @@ +# decimal + +[![Build Status](https://travis-ci.org/shopspring/decimal.png?branch=master)](https://travis-ci.org/shopspring/decimal) [![GoDoc](https://godoc.org/github.com/shopspring/decimal?status.svg)](https://godoc.org/github.com/shopspring/decimal) [![Go Report Card](https://goreportcard.com/badge/github.com/shopspring/decimal)](https://goreportcard.com/report/github.com/shopspring/decimal) + +Arbitrary-precision fixed-point decimal numbers in go. + +_Note:_ Decimal library can "only" represent numbers with a maximum of 2^31 digits after the decimal point. 
+ +## Features + + * The zero-value is 0, and is safe to use without initialization + * Addition, subtraction, multiplication with no loss of precision + * Division with specified precision + * Database/sql serialization/deserialization + * JSON and XML serialization/deserialization + +## Install + +Run `go get github.com/shopspring/decimal` + +## Requirements + +Decimal library requires Go version `>=1.7` + +## Usage + +```go +package main + +import ( + "fmt" + "github.com/shopspring/decimal" +) + +func main() { + price, err := decimal.NewFromString("136.02") + if err != nil { + panic(err) + } + + quantity := decimal.NewFromInt(3) + + fee, _ := decimal.NewFromString(".035") + taxRate, _ := decimal.NewFromString(".08875") + + subtotal := price.Mul(quantity) + + preTax := subtotal.Mul(fee.Add(decimal.NewFromFloat(1))) + + total := preTax.Mul(taxRate.Add(decimal.NewFromFloat(1))) + + fmt.Println("Subtotal:", subtotal) // Subtotal: 408.06 + fmt.Println("Pre-tax:", preTax) // Pre-tax: 422.3421 + fmt.Println("Taxes:", total.Sub(preTax)) // Taxes: 37.482861375 + fmt.Println("Total:", total) // Total: 459.824961375 + fmt.Println("Tax rate:", total.Sub(preTax).Div(preTax)) // Tax rate: 0.08875 +} +``` + +## Documentation + +http://godoc.org/github.com/shopspring/decimal + +## Production Usage + +* [Spring](https://shopspring.com/), since August 14, 2014. +* If you are using this in production, please let us know! + +## FAQ + +#### Why don't you just use float64? + +Because float64 (or any binary floating point type, actually) can't represent +numbers such as `0.1` exactly. + +Consider this code: http://play.golang.org/p/TQBd4yJe6B You might expect that +it prints out `10`, but it actually prints `9.999999999999831`. Over time, +these small errors can really add up! + +#### Why don't you just use big.Rat? + +big.Rat is fine for representing rational numbers, but Decimal is better for +representing money. Why? Here's a (contrived) example: + +Let's say you use big.Rat, and you have two numbers, x and y, both +representing 1/3, and you have `z = 1 - x - y = 1/3`. If you print each one +out, the string output has to stop somewhere (let's say it stops at 3 decimal +digits, for simplicity), so you'll get 0.333, 0.333, and 0.333. But where did +the other 0.001 go? + +Here's the above example as code: http://play.golang.org/p/lCZZs0w9KE + +With Decimal, the strings being printed out represent the number exactly. So, +if you have `x = y = 1/3` (with precision 3), they will actually be equal to +0.333, and when you do `z = 1 - x - y`, `z` will be equal to .334. No money is +unaccounted for! + +You still have to be careful. If you want to split a number `N` 3 ways, you +can't just send `N/3` to three different people. You have to pick one to send +`N - (2/3*N)` to. That person will receive the fraction of a penny remainder. + +But, it is much easier to be careful with Decimal than with big.Rat. + +#### Why isn't the API similar to big.Int's? + +big.Int's API is built to reduce the number of memory allocations for maximal +performance. This makes sense for its use-case, but the trade-off is that the +API is awkward and easy to misuse. + +For example, to add two big.Ints, you do: `z := new(big.Int).Add(x, y)`. A +developer unfamiliar with this API might try to do `z := a.Add(a, b)`. This +modifies `a` and sets `z` as an alias for `a`, which they might not expect. It +also modifies any other aliases to `a`. 
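+
+For instance, a minimal sketch of that aliasing surprise:
+
+```go
+package main
+
+import (
+	"fmt"
+	"math/big"
+)
+
+func main() {
+	a := big.NewInt(1)
+	b := big.NewInt(2)
+	z := a.Add(a, b) // Add stores a+b into the receiver, so a itself becomes 3
+	fmt.Println(a, z) // prints "3 3"; z is just another name for a
+}
+```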
+ +Here's an example of the subtle bugs you can introduce with big.Int's API: +https://play.golang.org/p/x2R_78pa8r + +In contrast, it's difficult to make such mistakes with decimal. Decimals +behave like other go numbers types: even though `a = b` will not deep copy +`b` into `a`, it is impossible to modify a Decimal, since all Decimal methods +return new Decimals and do not modify the originals. The downside is that +this causes extra allocations, so Decimal is less performant. My assumption +is that if you're using Decimals, you probably care more about correctness +than performance. + +## License + +The MIT License (MIT) + +This is a heavily modified fork of [fpd.Decimal](https://github.com/oguzbilgic/fpd), which was also released under the MIT License. diff --git a/vendor/github.com/shopspring/decimal/decimal-go.go b/vendor/github.com/shopspring/decimal/decimal-go.go new file mode 100644 index 0000000..9958d69 --- /dev/null +++ b/vendor/github.com/shopspring/decimal/decimal-go.go @@ -0,0 +1,415 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Multiprecision decimal numbers. +// For floating-point formatting only; not general purpose. +// Only operations are assign and (binary) left/right shift. +// Can do binary floating point in multiprecision decimal precisely +// because 2 divides 10; cannot do decimal floating point +// in multiprecision binary precisely. + +package decimal + +type decimal struct { + d [800]byte // digits, big-endian representation + nd int // number of digits used + dp int // decimal point + neg bool // negative flag + trunc bool // discarded nonzero digits beyond d[:nd] +} + +func (a *decimal) String() string { + n := 10 + a.nd + if a.dp > 0 { + n += a.dp + } + if a.dp < 0 { + n += -a.dp + } + + buf := make([]byte, n) + w := 0 + switch { + case a.nd == 0: + return "0" + + case a.dp <= 0: + // zeros fill space between decimal point and digits + buf[w] = '0' + w++ + buf[w] = '.' + w++ + w += digitZero(buf[w : w+-a.dp]) + w += copy(buf[w:], a.d[0:a.nd]) + + case a.dp < a.nd: + // decimal point in middle of digits + w += copy(buf[w:], a.d[0:a.dp]) + buf[w] = '.' + w++ + w += copy(buf[w:], a.d[a.dp:a.nd]) + + default: + // zeros fill space between digits and decimal point + w += copy(buf[w:], a.d[0:a.nd]) + w += digitZero(buf[w : w+a.dp-a.nd]) + } + return string(buf[0:w]) +} + +func digitZero(dst []byte) int { + for i := range dst { + dst[i] = '0' + } + return len(dst) +} + +// trim trailing zeros from number. +// (They are meaningless; the decimal point is tracked +// independent of the number of digits.) +func trim(a *decimal) { + for a.nd > 0 && a.d[a.nd-1] == '0' { + a.nd-- + } + if a.nd == 0 { + a.dp = 0 + } +} + +// Assign v to a. +func (a *decimal) Assign(v uint64) { + var buf [24]byte + + // Write reversed decimal in buf. + n := 0 + for v > 0 { + v1 := v / 10 + v -= 10 * v1 + buf[n] = byte(v + '0') + n++ + v = v1 + } + + // Reverse again to produce forward decimal in a.d. + a.nd = 0 + for n--; n >= 0; n-- { + a.d[a.nd] = buf[n] + a.nd++ + } + a.dp = a.nd + trim(a) +} + +// Maximum shift that we can do in one pass without overflow. +// A uint has 32 or 64 bits, and we have to be able to accommodate 9<> 63) +const maxShift = uintSize - 4 + +// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow. 
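+// Digits are consumed at a read pointer and written back at a write pointer,
+// so the shift is performed in place on a.d.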
+func rightShift(a *decimal, k uint) { + r := 0 // read pointer + w := 0 // write pointer + + // Pick up enough leading digits to cover first shift. + var n uint + for ; n>>k == 0; r++ { + if r >= a.nd { + if n == 0 { + // a == 0; shouldn't get here, but handle anyway. + a.nd = 0 + return + } + for n>>k == 0 { + n = n * 10 + r++ + } + break + } + c := uint(a.d[r]) + n = n*10 + c - '0' + } + a.dp -= r - 1 + + var mask uint = (1 << k) - 1 + + // Pick up a digit, put down a digit. + for ; r < a.nd; r++ { + c := uint(a.d[r]) + dig := n >> k + n &= mask + a.d[w] = byte(dig + '0') + w++ + n = n*10 + c - '0' + } + + // Put down extra digits. + for n > 0 { + dig := n >> k + n &= mask + if w < len(a.d) { + a.d[w] = byte(dig + '0') + w++ + } else if dig > 0 { + a.trunc = true + } + n = n * 10 + } + + a.nd = w + trim(a) +} + +// Cheat sheet for left shift: table indexed by shift count giving +// number of new digits that will be introduced by that shift. +// +// For example, leftcheats[4] = {2, "625"}. That means that +// if we are shifting by 4 (multiplying by 16), it will add 2 digits +// when the string prefix is "625" through "999", and one fewer digit +// if the string prefix is "000" through "624". +// +// Credit for this trick goes to Ken. + +type leftCheat struct { + delta int // number of new digits + cutoff string // minus one digit if original < a. +} + +var leftcheats = []leftCheat{ + // Leading digits of 1/2^i = 5^i. + // 5^23 is not an exact 64-bit floating point number, + // so have to use bc for the math. + // Go up to 60 to be large enough for 32bit and 64bit platforms. + /* + seq 60 | sed 's/^/5^/' | bc | + awk 'BEGIN{ print "\t{ 0, \"\" }," } + { + log2 = log(2)/log(10) + printf("\t{ %d, \"%s\" },\t// * %d\n", + int(log2*NR+1), $0, 2**NR) + }' + */ + {0, ""}, + {1, "5"}, // * 2 + {1, "25"}, // * 4 + {1, "125"}, // * 8 + {2, "625"}, // * 16 + {2, "3125"}, // * 32 + {2, "15625"}, // * 64 + {3, "78125"}, // * 128 + {3, "390625"}, // * 256 + {3, "1953125"}, // * 512 + {4, "9765625"}, // * 1024 + {4, "48828125"}, // * 2048 + {4, "244140625"}, // * 4096 + {4, "1220703125"}, // * 8192 + {5, "6103515625"}, // * 16384 + {5, "30517578125"}, // * 32768 + {5, "152587890625"}, // * 65536 + {6, "762939453125"}, // * 131072 + {6, "3814697265625"}, // * 262144 + {6, "19073486328125"}, // * 524288 + {7, "95367431640625"}, // * 1048576 + {7, "476837158203125"}, // * 2097152 + {7, "2384185791015625"}, // * 4194304 + {7, "11920928955078125"}, // * 8388608 + {8, "59604644775390625"}, // * 16777216 + {8, "298023223876953125"}, // * 33554432 + {8, "1490116119384765625"}, // * 67108864 + {9, "7450580596923828125"}, // * 134217728 + {9, "37252902984619140625"}, // * 268435456 + {9, "186264514923095703125"}, // * 536870912 + {10, "931322574615478515625"}, // * 1073741824 + {10, "4656612873077392578125"}, // * 2147483648 + {10, "23283064365386962890625"}, // * 4294967296 + {10, "116415321826934814453125"}, // * 8589934592 + {11, "582076609134674072265625"}, // * 17179869184 + {11, "2910383045673370361328125"}, // * 34359738368 + {11, "14551915228366851806640625"}, // * 68719476736 + {12, "72759576141834259033203125"}, // * 137438953472 + {12, "363797880709171295166015625"}, // * 274877906944 + {12, "1818989403545856475830078125"}, // * 549755813888 + {13, "9094947017729282379150390625"}, // * 1099511627776 + {13, "45474735088646411895751953125"}, // * 2199023255552 + {13, "227373675443232059478759765625"}, // * 4398046511104 + {13, "1136868377216160297393798828125"}, // * 8796093022208 + {14, 
"5684341886080801486968994140625"}, // * 17592186044416 + {14, "28421709430404007434844970703125"}, // * 35184372088832 + {14, "142108547152020037174224853515625"}, // * 70368744177664 + {15, "710542735760100185871124267578125"}, // * 140737488355328 + {15, "3552713678800500929355621337890625"}, // * 281474976710656 + {15, "17763568394002504646778106689453125"}, // * 562949953421312 + {16, "88817841970012523233890533447265625"}, // * 1125899906842624 + {16, "444089209850062616169452667236328125"}, // * 2251799813685248 + {16, "2220446049250313080847263336181640625"}, // * 4503599627370496 + {16, "11102230246251565404236316680908203125"}, // * 9007199254740992 + {17, "55511151231257827021181583404541015625"}, // * 18014398509481984 + {17, "277555756156289135105907917022705078125"}, // * 36028797018963968 + {17, "1387778780781445675529539585113525390625"}, // * 72057594037927936 + {18, "6938893903907228377647697925567626953125"}, // * 144115188075855872 + {18, "34694469519536141888238489627838134765625"}, // * 288230376151711744 + {18, "173472347597680709441192448139190673828125"}, // * 576460752303423488 + {19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976 +} + +// Is the leading prefix of b lexicographically less than s? +func prefixIsLessThan(b []byte, s string) bool { + for i := 0; i < len(s); i++ { + if i >= len(b) { + return true + } + if b[i] != s[i] { + return b[i] < s[i] + } + } + return false +} + +// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow. +func leftShift(a *decimal, k uint) { + delta := leftcheats[k].delta + if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) { + delta-- + } + + r := a.nd // read index + w := a.nd + delta // write index + + // Pick up a digit, put down a digit. + var n uint + for r--; r >= 0; r-- { + n += (uint(a.d[r]) - '0') << k + quo := n / 10 + rem := n - 10*quo + w-- + if w < len(a.d) { + a.d[w] = byte(rem + '0') + } else if rem != 0 { + a.trunc = true + } + n = quo + } + + // Put down extra digits. + for n > 0 { + quo := n / 10 + rem := n - 10*quo + w-- + if w < len(a.d) { + a.d[w] = byte(rem + '0') + } else if rem != 0 { + a.trunc = true + } + n = quo + } + + a.nd += delta + if a.nd >= len(a.d) { + a.nd = len(a.d) + } + a.dp += delta + trim(a) +} + +// Binary shift left (k > 0) or right (k < 0). +func (a *decimal) Shift(k int) { + switch { + case a.nd == 0: + // nothing to do: a == 0 + case k > 0: + for k > maxShift { + leftShift(a, maxShift) + k -= maxShift + } + leftShift(a, uint(k)) + case k < 0: + for k < -maxShift { + rightShift(a, maxShift) + k += maxShift + } + rightShift(a, uint(-k)) + } +} + +// If we chop a at nd digits, should we round up? +func shouldRoundUp(a *decimal, nd int) bool { + if nd < 0 || nd >= a.nd { + return false + } + if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even + // if we truncated, a little higher than what's recorded - always round up + if a.trunc { + return true + } + return nd > 0 && (a.d[nd-1]-'0')%2 != 0 + } + // not halfway - digit tells all + return a.d[nd] >= '5' +} + +// Round a to nd digits (or fewer). +// If nd is zero, it means we're rounding +// just to the left of the digits, as in +// 0.09 -> 0.1. +func (a *decimal) Round(nd int) { + if nd < 0 || nd >= a.nd { + return + } + if shouldRoundUp(a, nd) { + a.RoundUp(nd) + } else { + a.RoundDown(nd) + } +} + +// Round a down to nd digits (or fewer). 
+func (a *decimal) RoundDown(nd int) { + if nd < 0 || nd >= a.nd { + return + } + a.nd = nd + trim(a) +} + +// Round a up to nd digits (or fewer). +func (a *decimal) RoundUp(nd int) { + if nd < 0 || nd >= a.nd { + return + } + + // round up + for i := nd - 1; i >= 0; i-- { + c := a.d[i] + if c < '9' { // can stop after this digit + a.d[i]++ + a.nd = i + 1 + return + } + } + + // Number is all 9s. + // Change to single 1 with adjusted decimal point. + a.d[0] = '1' + a.nd = 1 + a.dp++ +} + +// Extract integer part, rounded appropriately. +// No guarantees about overflow. +func (a *decimal) RoundedInteger() uint64 { + if a.dp > 20 { + return 0xFFFFFFFFFFFFFFFF + } + var i int + n := uint64(0) + for i = 0; i < a.dp && i < a.nd; i++ { + n = n*10 + uint64(a.d[i]-'0') + } + for ; i < a.dp; i++ { + n *= 10 + } + if shouldRoundUp(a, a.dp) { + n++ + } + return n +} diff --git a/vendor/github.com/shopspring/decimal/decimal.go b/vendor/github.com/shopspring/decimal/decimal.go new file mode 100644 index 0000000..801c1a0 --- /dev/null +++ b/vendor/github.com/shopspring/decimal/decimal.go @@ -0,0 +1,1477 @@ +// Package decimal implements an arbitrary precision fixed-point decimal. +// +// The zero-value of a Decimal is 0, as you would expect. +// +// The best way to create a new Decimal is to use decimal.NewFromString, ex: +// +// n, err := decimal.NewFromString("-123.4567") +// n.String() // output: "-123.4567" +// +// To use Decimal as part of a struct: +// +// type Struct struct { +// Number Decimal +// } +// +// Note: This can "only" represent numbers with a maximum of 2^31 digits after the decimal point. +package decimal + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "math/big" + "strconv" + "strings" +) + +// DivisionPrecision is the number of decimal places in the result when it +// doesn't divide exactly. +// +// Example: +// +// d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3)) +// d1.String() // output: "0.6666666666666667" +// d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000)) +// d2.String() // output: "0.0000666666666667" +// d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3)) +// d3.String() // output: "6666.6666666666666667" +// decimal.DivisionPrecision = 3 +// d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3)) +// d4.String() // output: "0.667" +// +var DivisionPrecision = 16 + +// MarshalJSONWithoutQuotes should be set to true if you want the decimal to +// be JSON marshaled as a number, instead of as a string. +// WARNING: this is dangerous for decimals with many digits, since many JSON +// unmarshallers (ex: Javascript's) will unmarshal JSON numbers to IEEE 754 +// double-precision floating point numbers, which means you can potentially +// silently lose precision. +var MarshalJSONWithoutQuotes = false + +// Zero constant, to make computations faster. +// Zero should never be compared with == or != directly, please use decimal.Equal or decimal.Cmp instead. +var Zero = New(0, 1) + +var zeroInt = big.NewInt(0) +var oneInt = big.NewInt(1) +var twoInt = big.NewInt(2) +var fourInt = big.NewInt(4) +var fiveInt = big.NewInt(5) +var tenInt = big.NewInt(10) +var twentyInt = big.NewInt(20) + +// Decimal represents a fixed-point decimal. It is immutable. +// number = value * 10 ^ exp +type Decimal struct { + value *big.Int + + // NOTE(vadim): this must be an int32, because we cast it to float64 during + // calculations. If exp is 64 bit, we might lose precision. 
+ // If we cared about being able to represent every possible decimal, we + // could make exp a *big.Int but it would hurt performance and numbers + // like that are unrealistic. + exp int32 +} + +// New returns a new fixed-point decimal, value * 10 ^ exp. +func New(value int64, exp int32) Decimal { + return Decimal{ + value: big.NewInt(value), + exp: exp, + } +} + +// NewFromInt converts a int64 to Decimal. +// +// Example: +// +// NewFromInt(123).String() // output: "123" +// NewFromInt(-10).String() // output: "-10" +func NewFromInt(value int64) Decimal { + return Decimal{ + value: big.NewInt(value), + exp: 0, + } +} + +// NewFromInt32 converts a int32 to Decimal. +// +// Example: +// +// NewFromInt(123).String() // output: "123" +// NewFromInt(-10).String() // output: "-10" +func NewFromInt32(value int32) Decimal { + return Decimal{ + value: big.NewInt(int64(value)), + exp: 0, + } +} + +// NewFromBigInt returns a new Decimal from a big.Int, value * 10 ^ exp +func NewFromBigInt(value *big.Int, exp int32) Decimal { + return Decimal{ + value: big.NewInt(0).Set(value), + exp: exp, + } +} + +// NewFromString returns a new Decimal from a string representation. +// Trailing zeroes are not trimmed. +// +// Example: +// +// d, err := NewFromString("-123.45") +// d2, err := NewFromString(".0001") +// d3, err := NewFromString("1.47000") +// +func NewFromString(value string) (Decimal, error) { + originalInput := value + var intString string + var exp int64 + + // Check if number is using scientific notation + eIndex := strings.IndexAny(value, "Ee") + if eIndex != -1 { + expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32) + if err != nil { + if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value) + } + return Decimal{}, fmt.Errorf("can't convert %s to decimal: exponent is not numeric", value) + } + value = value[:eIndex] + exp = expInt + } + + parts := strings.Split(value, ".") + if len(parts) == 1 { + // There is no decimal point, we can just parse the original string as + // an int + intString = value + } else if len(parts) == 2 { + intString = parts[0] + parts[1] + expInt := -len(parts[1]) + exp += int64(expInt) + } else { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) + } + + dValue := new(big.Int) + _, ok := dValue.SetString(intString, 10) + if !ok { + return Decimal{}, fmt.Errorf("can't convert %s to decimal", value) + } + + if exp < math.MinInt32 || exp > math.MaxInt32 { + // NOTE(vadim): I doubt a string could realistically be this long + return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput) + } + + return Decimal{ + value: dValue, + exp: int32(exp), + }, nil +} + +// RequireFromString returns a new Decimal from a string representation +// or panics if NewFromString would have returned an error. +// +// Example: +// +// d := RequireFromString("-123.45") +// d2 := RequireFromString(".0001") +// +func RequireFromString(value string) Decimal { + dec, err := NewFromString(value) + if err != nil { + panic(err) + } + return dec +} + +// NewFromFloat converts a float64 to Decimal. +// +// The converted number will contain the number of significant digits that can be +// represented in a float with reliable roundtrip. +// This is typically 15 digits, but may be more in some cases. +// See https://www.exploringbinary.com/decimal-precision-of-binary-floating-point-numbers/ for more information. 
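+// For example, NewFromFloat(5.45).String() is "5.45", not a longer binary expansion.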
+//
+// For slightly faster conversion, use NewFromFloatWithExponent where you can specify the precision in absolute terms.
+//
+// NOTE: this will panic on NaN, +/-inf
+func NewFromFloat(value float64) Decimal {
+	if value == 0 {
+		return New(0, 0)
+	}
+	return newFromFloat(value, math.Float64bits(value), &float64info)
+}
+
+// NewFromFloat32 converts a float32 to Decimal.
+//
+// The converted number will contain the number of significant digits that can be
+// represented in a float with reliable roundtrip.
+// This is typically 6-8 digits depending on the input.
+// See https://www.exploringbinary.com/decimal-precision-of-binary-floating-point-numbers/ for more information.
+//
+// For slightly faster conversion, use NewFromFloatWithExponent where you can specify the precision in absolute terms.
+//
+// NOTE: this will panic on NaN, +/-inf
+func NewFromFloat32(value float32) Decimal {
+	if value == 0 {
+		return New(0, 0)
+	}
+	// XOR is workaround for https://github.com/golang/go/issues/26285
+	a := math.Float32bits(value) ^ 0x80808080
+	return newFromFloat(float64(value), uint64(a)^0x80808080, &float32info)
+}
+
+func newFromFloat(val float64, bits uint64, flt *floatInfo) Decimal {
+	if math.IsNaN(val) || math.IsInf(val, 0) {
+		panic(fmt.Sprintf("Cannot create a Decimal from %v", val))
+	}
+	exp := int(bits>>flt.mantbits) & (1<<flt.expbits - 1)
+	mant := bits & (uint64(1)<<flt.mantbits - 1)
+
+	switch exp {
+	case 0:
+		// denormalized
+		exp++
+	default:
+		// add implicit top bit
+		mant |= uint64(1) << flt.mantbits
+	}
+	exp += flt.bias
+
+	var d decimal
+	d.Assign(mant)
+	d.Shift(exp - int(flt.mantbits))
+	d.neg = bits>>(flt.expbits+flt.mantbits) != 0
+
+	roundShortest(&d, mant, exp, flt)
+	// If less than 19 digits, we can do calculation in an int64.
+	if d.nd < 19 {
+		tmp := int64(0)
+		m := int64(1)
+		for i := d.nd - 1; i >= 0; i-- {
+			tmp += m * int64(d.d[i]-'0')
+			m *= 10
+		}
+		if d.neg {
+			tmp *= -1
+		}
+		return Decimal{value: big.NewInt(tmp), exp: int32(d.dp) - int32(d.nd)}
+	}
+	dValue := new(big.Int)
+	dValue, ok := dValue.SetString(string(d.d[:d.nd]), 10)
+	if ok {
+		return Decimal{value: dValue, exp: int32(d.dp) - int32(d.nd)}
+	}
+
+	return NewFromFloatWithExponent(val, int32(d.dp)-int32(d.nd))
+}
+
+// NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary
+// number of fractional digits.
+// +// Example: +// +// NewFromFloatWithExponent(123.456, -2).String() // output: "123.46" +// +func NewFromFloatWithExponent(value float64, exp int32) Decimal { + if math.IsNaN(value) || math.IsInf(value, 0) { + panic(fmt.Sprintf("Cannot create a Decimal from %v", value)) + } + + bits := math.Float64bits(value) + mant := bits & (1<<52 - 1) + exp2 := int32((bits >> 52) & (1<<11 - 1)) + sign := bits >> 63 + + if exp2 == 0 { + // specials + if mant == 0 { + return Decimal{} + } + // subnormal + exp2++ + } else { + // normal + mant |= 1 << 52 + } + + exp2 -= 1023 + 52 + + // normalizing base-2 values + for mant&1 == 0 { + mant = mant >> 1 + exp2++ + } + + // maximum number of fractional base-10 digits to represent 2^N exactly cannot be more than -N if N<0 + if exp < 0 && exp < exp2 { + if exp2 < 0 { + exp = exp2 + } else { + exp = 0 + } + } + + // representing 10^M * 2^N as 5^M * 2^(M+N) + exp2 -= exp + + temp := big.NewInt(1) + dMant := big.NewInt(int64(mant)) + + // applying 5^M + if exp > 0 { + temp = temp.SetInt64(int64(exp)) + temp = temp.Exp(fiveInt, temp, nil) + } else if exp < 0 { + temp = temp.SetInt64(-int64(exp)) + temp = temp.Exp(fiveInt, temp, nil) + dMant = dMant.Mul(dMant, temp) + temp = temp.SetUint64(1) + } + + // applying 2^(M+N) + if exp2 > 0 { + dMant = dMant.Lsh(dMant, uint(exp2)) + } else if exp2 < 0 { + temp = temp.Lsh(temp, uint(-exp2)) + } + + // rounding and downscaling + if exp > 0 || exp2 < 0 { + halfDown := new(big.Int).Rsh(temp, 1) + dMant = dMant.Add(dMant, halfDown) + dMant = dMant.Quo(dMant, temp) + } + + if sign == 1 { + dMant = dMant.Neg(dMant) + } + + return Decimal{ + value: dMant, + exp: exp, + } +} + +// rescale returns a rescaled version of the decimal. Returned +// decimal may be less precise if the given exponent is bigger +// than the initial exponent of the Decimal. +// NOTE: this will truncate, NOT round +// +// Example: +// +// d := New(12345, -4) +// d2 := d.rescale(-1) +// d3 := d2.rescale(-4) +// println(d1) +// println(d2) +// println(d3) +// +// Output: +// +// 1.2345 +// 1.2 +// 1.2000 +// +func (d Decimal) rescale(exp int32) Decimal { + d.ensureInitialized() + + if d.exp == exp { + return Decimal{ + new(big.Int).Set(d.value), + d.exp, + } + } + + // NOTE(vadim): must convert exps to float64 before - to prevent overflow + diff := math.Abs(float64(exp) - float64(d.exp)) + value := new(big.Int).Set(d.value) + + expScale := new(big.Int).Exp(tenInt, big.NewInt(int64(diff)), nil) + if exp > d.exp { + value = value.Quo(value, expScale) + } else if exp < d.exp { + value = value.Mul(value, expScale) + } + + return Decimal{ + value: value, + exp: exp, + } +} + +// Abs returns the absolute value of the decimal. +func (d Decimal) Abs() Decimal { + d.ensureInitialized() + d2Value := new(big.Int).Abs(d.value) + return Decimal{ + value: d2Value, + exp: d.exp, + } +} + +// Add returns d + d2. +func (d Decimal) Add(d2 Decimal) Decimal { + rd, rd2 := RescalePair(d, d2) + + d3Value := new(big.Int).Add(rd.value, rd2.value) + return Decimal{ + value: d3Value, + exp: rd.exp, + } +} + +// Sub returns d - d2. +func (d Decimal) Sub(d2 Decimal) Decimal { + rd, rd2 := RescalePair(d, d2) + + d3Value := new(big.Int).Sub(rd.value, rd2.value) + return Decimal{ + value: d3Value, + exp: rd.exp, + } +} + +// Neg returns -d. +func (d Decimal) Neg() Decimal { + d.ensureInitialized() + val := new(big.Int).Neg(d.value) + return Decimal{ + value: val, + exp: d.exp, + } +} + +// Mul returns d * d2. 
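+// For example, New(12, -1).Mul(New(3, 0)) is 3.6.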
+func (d Decimal) Mul(d2 Decimal) Decimal { + d.ensureInitialized() + d2.ensureInitialized() + + expInt64 := int64(d.exp) + int64(d2.exp) + if expInt64 > math.MaxInt32 || expInt64 < math.MinInt32 { + // NOTE(vadim): better to panic than give incorrect results, as + // Decimals are usually used for money + panic(fmt.Sprintf("exponent %v overflows an int32!", expInt64)) + } + + d3Value := new(big.Int).Mul(d.value, d2.value) + return Decimal{ + value: d3Value, + exp: int32(expInt64), + } +} + +// Shift shifts the decimal in base 10. +// It shifts left when shift is positive and right if shift is negative. +// In simpler terms, the given value for shift is added to the exponent +// of the decimal. +func (d Decimal) Shift(shift int32) Decimal { + d.ensureInitialized() + return Decimal{ + value: new(big.Int).Set(d.value), + exp: d.exp + shift, + } +} + +// Div returns d / d2. If it doesn't divide exactly, the result will have +// DivisionPrecision digits after the decimal point. +func (d Decimal) Div(d2 Decimal) Decimal { + return d.DivRound(d2, int32(DivisionPrecision)) +} + +// QuoRem does divsion with remainder +// d.QuoRem(d2,precision) returns quotient q and remainder r such that +// d = d2 * q + r, q an integer multiple of 10^(-precision) +// 0 <= r < abs(d2) * 10 ^(-precision) if d>=0 +// 0 >= r > -abs(d2) * 10 ^(-precision) if d<0 +// Note that precision<0 is allowed as input. +func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) { + d.ensureInitialized() + d2.ensureInitialized() + if d2.value.Sign() == 0 { + panic("decimal division by 0") + } + scale := -precision + e := int64(d.exp - d2.exp - scale) + if e > math.MaxInt32 || e < math.MinInt32 { + panic("overflow in decimal QuoRem") + } + var aa, bb, expo big.Int + var scalerest int32 + // d = a 10^ea + // d2 = b 10^eb + if e < 0 { + aa = *d.value + expo.SetInt64(-e) + bb.Exp(tenInt, &expo, nil) + bb.Mul(d2.value, &bb) + scalerest = d.exp + // now aa = a + // bb = b 10^(scale + eb - ea) + } else { + expo.SetInt64(e) + aa.Exp(tenInt, &expo, nil) + aa.Mul(d.value, &aa) + bb = *d2.value + scalerest = scale + d2.exp + // now aa = a ^ (ea - eb - scale) + // bb = b + } + var q, r big.Int + q.QuoRem(&aa, &bb, &r) + dq := Decimal{value: &q, exp: scale} + dr := Decimal{value: &r, exp: scalerest} + return dq, dr +} + +// DivRound divides and rounds to a given precision +// i.e. to an integer multiple of 10^(-precision) +// for a positive quotient digit 5 is rounded up, away from 0 +// if the quotient is negative then digit 5 is rounded down, away from 0 +// Note that precision<0 is allowed as input. +func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal { + // QuoRem already checks initialization + q, r := d.QuoRem(d2, precision) + // the actual rounding decision is based on comparing r*10^precision and d2/2 + // instead compare 2 r 10 ^precision and d2 + var rv2 big.Int + rv2.Abs(r.value) + rv2.Lsh(&rv2, 1) + // now rv2 = abs(r.value) * 2 + r2 := Decimal{value: &rv2, exp: r.exp + precision} + // r2 is now 2 * r * 10 ^ precision + var c = r2.Cmp(d2.Abs()) + + if c < 0 { + return q + } + + if d.value.Sign()*d2.value.Sign() < 0 { + return q.Sub(New(1, -precision)) + } + + return q.Add(New(1, -precision)) +} + +// Mod returns d % d2. 
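+// For example, New(7, 0).Mod(New(3, 0)) is 1.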
+func (d Decimal) Mod(d2 Decimal) Decimal { + quo := d.Div(d2).Truncate(0) + return d.Sub(d2.Mul(quo)) +} + +// Pow returns d to the power d2 +func (d Decimal) Pow(d2 Decimal) Decimal { + var temp Decimal + if d2.IntPart() == 0 { + return NewFromFloat(1) + } + temp = d.Pow(d2.Div(NewFromFloat(2))) + if d2.IntPart()%2 == 0 { + return temp.Mul(temp) + } + if d2.IntPart() > 0 { + return temp.Mul(temp).Mul(d) + } + return temp.Mul(temp).Div(d) +} + +// Cmp compares the numbers represented by d and d2 and returns: +// +// -1 if d < d2 +// 0 if d == d2 +// +1 if d > d2 +// +func (d Decimal) Cmp(d2 Decimal) int { + d.ensureInitialized() + d2.ensureInitialized() + + if d.exp == d2.exp { + return d.value.Cmp(d2.value) + } + + rd, rd2 := RescalePair(d, d2) + + return rd.value.Cmp(rd2.value) +} + +// Equal returns whether the numbers represented by d and d2 are equal. +func (d Decimal) Equal(d2 Decimal) bool { + return d.Cmp(d2) == 0 +} + +// Equals is deprecated, please use Equal method instead +func (d Decimal) Equals(d2 Decimal) bool { + return d.Equal(d2) +} + +// GreaterThan (GT) returns true when d is greater than d2. +func (d Decimal) GreaterThan(d2 Decimal) bool { + return d.Cmp(d2) == 1 +} + +// GreaterThanOrEqual (GTE) returns true when d is greater than or equal to d2. +func (d Decimal) GreaterThanOrEqual(d2 Decimal) bool { + cmp := d.Cmp(d2) + return cmp == 1 || cmp == 0 +} + +// LessThan (LT) returns true when d is less than d2. +func (d Decimal) LessThan(d2 Decimal) bool { + return d.Cmp(d2) == -1 +} + +// LessThanOrEqual (LTE) returns true when d is less than or equal to d2. +func (d Decimal) LessThanOrEqual(d2 Decimal) bool { + cmp := d.Cmp(d2) + return cmp == -1 || cmp == 0 +} + +// Sign returns: +// +// -1 if d < 0 +// 0 if d == 0 +// +1 if d > 0 +// +func (d Decimal) Sign() int { + if d.value == nil { + return 0 + } + return d.value.Sign() +} + +// IsPositive return +// +// true if d > 0 +// false if d == 0 +// false if d < 0 +func (d Decimal) IsPositive() bool { + return d.Sign() == 1 +} + +// IsNegative return +// +// true if d < 0 +// false if d == 0 +// false if d > 0 +func (d Decimal) IsNegative() bool { + return d.Sign() == -1 +} + +// IsZero return +// +// true if d == 0 +// false if d > 0 +// false if d < 0 +func (d Decimal) IsZero() bool { + return d.Sign() == 0 +} + +// Exponent returns the exponent, or scale component of the decimal. +func (d Decimal) Exponent() int32 { + return d.exp +} + +// Coefficient returns the coefficient of the decimal. It is scaled by 10^Exponent() +func (d Decimal) Coefficient() *big.Int { + d.ensureInitialized() + // we copy the coefficient so that mutating the result does not mutate the + // Decimal. + return big.NewInt(0).Set(d.value) +} + +// IntPart returns the integer component of the decimal. +func (d Decimal) IntPart() int64 { + scaledD := d.rescale(0) + return scaledD.value.Int64() +} + +// BigInt returns integer component of the decimal as a BigInt. +func (d Decimal) BigInt() *big.Int { + scaledD := d.rescale(0) + i := &big.Int{} + i.SetString(scaledD.String(), 10) + return i +} + +// BigFloat returns decimal as BigFloat. +// Be aware that casting decimal to BigFloat might cause a loss of precision. +func (d Decimal) BigFloat() *big.Float { + f := &big.Float{} + f.SetString(d.String()) + return f +} + +// Rat returns a rational number representation of the decimal. 
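+// For example, New(125, -2).Rat() is the ratio 5/4.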
+func (d Decimal) Rat() *big.Rat { + d.ensureInitialized() + if d.exp <= 0 { + // NOTE(vadim): must negate after casting to prevent int32 overflow + denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil) + return new(big.Rat).SetFrac(d.value, denom) + } + + mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil) + num := new(big.Int).Mul(d.value, mul) + return new(big.Rat).SetFrac(num, oneInt) +} + +// Float64 returns the nearest float64 value for d and a bool indicating +// whether f represents d exactly. +// For more details, see the documentation for big.Rat.Float64 +func (d Decimal) Float64() (f float64, exact bool) { + return d.Rat().Float64() +} + +// String returns the string representation of the decimal +// with the fixed point. +// +// Example: +// +// d := New(-12345, -3) +// println(d.String()) +// +// Output: +// +// -12.345 +// +func (d Decimal) String() string { + return d.string(true) +} + +// StringFixed returns a rounded fixed-point string with places digits after +// the decimal point. +// +// Example: +// +// NewFromFloat(0).StringFixed(2) // output: "0.00" +// NewFromFloat(0).StringFixed(0) // output: "0" +// NewFromFloat(5.45).StringFixed(0) // output: "5" +// NewFromFloat(5.45).StringFixed(1) // output: "5.5" +// NewFromFloat(5.45).StringFixed(2) // output: "5.45" +// NewFromFloat(5.45).StringFixed(3) // output: "5.450" +// NewFromFloat(545).StringFixed(-1) // output: "550" +// +func (d Decimal) StringFixed(places int32) string { + rounded := d.Round(places) + return rounded.string(false) +} + +// StringFixedBank returns a banker rounded fixed-point string with places digits +// after the decimal point. +// +// Example: +// +// NewFromFloat(0).StringFixedBank(2) // output: "0.00" +// NewFromFloat(0).StringFixedBank(0) // output: "0" +// NewFromFloat(5.45).StringFixedBank(0) // output: "5" +// NewFromFloat(5.45).StringFixedBank(1) // output: "5.4" +// NewFromFloat(5.45).StringFixedBank(2) // output: "5.45" +// NewFromFloat(5.45).StringFixedBank(3) // output: "5.450" +// NewFromFloat(545).StringFixedBank(-1) // output: "540" +// +func (d Decimal) StringFixedBank(places int32) string { + rounded := d.RoundBank(places) + return rounded.string(false) +} + +// StringFixedCash returns a Swedish/Cash rounded fixed-point string. For +// more details see the documentation at function RoundCash. +func (d Decimal) StringFixedCash(interval uint8) string { + rounded := d.RoundCash(interval) + return rounded.string(false) +} + +// Round rounds the decimal to places decimal places. +// If places < 0, it will round the integer part to the nearest 10^(-places). +// +// Example: +// +// NewFromFloat(5.45).Round(1).String() // output: "5.5" +// NewFromFloat(545).Round(-1).String() // output: "550" +// +func (d Decimal) Round(places int32) Decimal { + // truncate to places + 1 + ret := d.rescale(-places - 1) + + // add sign(d) * 0.5 + if ret.value.Sign() < 0 { + ret.value.Sub(ret.value, fiveInt) + } else { + ret.value.Add(ret.value, fiveInt) + } + + // floor for positive numbers, ceil for negative numbers + _, m := ret.value.DivMod(ret.value, tenInt, new(big.Int)) + ret.exp++ + if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 { + ret.value.Add(ret.value, oneInt) + } + + return ret +} + +// RoundBank rounds the decimal to places decimal places. +// If the final digit to round is equidistant from the nearest two integers the +// rounded value is taken as the even number +// +// If places < 0, it will round the integer part to the nearest 10^(-places). 
+// +// Examples: +// +// NewFromFloat(5.45).Round(1).String() // output: "5.4" +// NewFromFloat(545).Round(-1).String() // output: "540" +// NewFromFloat(5.46).Round(1).String() // output: "5.5" +// NewFromFloat(546).Round(-1).String() // output: "550" +// NewFromFloat(5.55).Round(1).String() // output: "5.6" +// NewFromFloat(555).Round(-1).String() // output: "560" +// +func (d Decimal) RoundBank(places int32) Decimal { + + round := d.Round(places) + remainder := d.Sub(round).Abs() + + half := New(5, -places-1) + if remainder.Cmp(half) == 0 && round.value.Bit(0) != 0 { + if round.value.Sign() < 0 { + round.value.Add(round.value, oneInt) + } else { + round.value.Sub(round.value, oneInt) + } + } + + return round +} + +// RoundCash aka Cash/Penny/öre rounding rounds decimal to a specific +// interval. The amount payable for a cash transaction is rounded to the nearest +// multiple of the minimum currency unit available. The following intervals are +// available: 5, 10, 25, 50 and 100; any other number throws a panic. +// 5: 5 cent rounding 3.43 => 3.45 +// 10: 10 cent rounding 3.45 => 3.50 (5 gets rounded up) +// 25: 25 cent rounding 3.41 => 3.50 +// 50: 50 cent rounding 3.75 => 4.00 +// 100: 100 cent rounding 3.50 => 4.00 +// For more details: https://en.wikipedia.org/wiki/Cash_rounding +func (d Decimal) RoundCash(interval uint8) Decimal { + var iVal *big.Int + switch interval { + case 5: + iVal = twentyInt + case 10: + iVal = tenInt + case 25: + iVal = fourInt + case 50: + iVal = twoInt + case 100: + iVal = oneInt + default: + panic(fmt.Sprintf("Decimal does not support this Cash rounding interval `%d`. Supported: 5, 10, 25, 50, 100", interval)) + } + dVal := Decimal{ + value: iVal, + } + + // TODO: optimize those calculations to reduce the high allocations (~29 allocs). + return d.Mul(dVal).Round(0).Div(dVal).Truncate(2) +} + +// Floor returns the nearest integer value less than or equal to d. +func (d Decimal) Floor() Decimal { + d.ensureInitialized() + + if d.exp >= 0 { + return d + } + + exp := big.NewInt(10) + + // NOTE(vadim): must negate after casting to prevent int32 overflow + exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) + + z := new(big.Int).Div(d.value, exp) + return Decimal{value: z, exp: 0} +} + +// Ceil returns the nearest integer value greater than or equal to d. +func (d Decimal) Ceil() Decimal { + d.ensureInitialized() + + if d.exp >= 0 { + return d + } + + exp := big.NewInt(10) + + // NOTE(vadim): must negate after casting to prevent int32 overflow + exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) + + z, m := new(big.Int).DivMod(d.value, exp, new(big.Int)) + if m.Cmp(zeroInt) != 0 { + z.Add(z, oneInt) + } + return Decimal{value: z, exp: 0} +} + +// Truncate truncates off digits from the number, without rounding. +// +// NOTE: precision is the last digit that will not be truncated (must be >= 0). +// +// Example: +// +// decimal.NewFromString("123.456").Truncate(2).String() // "123.45" +// +func (d Decimal) Truncate(precision int32) Decimal { + d.ensureInitialized() + if precision >= 0 && -precision > d.exp { + return d.rescale(-precision) + } + return d +} + +// UnmarshalJSON implements the json.Unmarshaler interface. 
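+// Both quoted ("1.23") and bare (1.23) JSON numbers are accepted; a JSON null leaves d unchanged.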
+func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { + if string(decimalBytes) == "null" { + return nil + } + + str, err := unquoteIfQuoted(decimalBytes) + if err != nil { + return fmt.Errorf("error decoding string '%s': %s", decimalBytes, err) + } + + decimal, err := NewFromString(str) + *d = decimal + if err != nil { + return fmt.Errorf("error decoding string '%s': %s", str, err) + } + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (d Decimal) MarshalJSON() ([]byte, error) { + var str string + if MarshalJSONWithoutQuotes { + str = d.String() + } else { + str = "\"" + d.String() + "\"" + } + return []byte(str), nil +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation +// is already used when encoding to text, this method stores that string as []byte +func (d *Decimal) UnmarshalBinary(data []byte) error { + // Extract the exponent + d.exp = int32(binary.BigEndian.Uint32(data[:4])) + + // Extract the value + d.value = new(big.Int) + return d.value.GobDecode(data[4:]) +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (d Decimal) MarshalBinary() (data []byte, err error) { + // Write the exponent first since it's a fixed size + v1 := make([]byte, 4) + binary.BigEndian.PutUint32(v1, uint32(d.exp)) + + // Add the value + var v2 []byte + if v2, err = d.value.GobEncode(); err != nil { + return + } + + // Return the byte array + data = append(v1, v2...) + return +} + +// Scan implements the sql.Scanner interface for database deserialization. +func (d *Decimal) Scan(value interface{}) error { + // first try to see if the data is stored in database as a Numeric datatype + switch v := value.(type) { + + case float32: + *d = NewFromFloat(float64(v)) + return nil + + case float64: + // numeric in sqlite3 sends us float64 + *d = NewFromFloat(v) + return nil + + case int64: + // at least in sqlite3 when the value is 0 in db, the data is sent + // to us as an int64 instead of a float64 ... + *d = New(v, 0) + return nil + + default: + // default is trying to interpret value stored as string + str, err := unquoteIfQuoted(v) + if err != nil { + return err + } + *d, err = NewFromString(str) + return err + } +} + +// Value implements the driver.Valuer interface for database serialization. +func (d Decimal) Value() (driver.Value, error) { + return d.String(), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for XML +// deserialization. +func (d *Decimal) UnmarshalText(text []byte) error { + str := string(text) + + dec, err := NewFromString(str) + *d = dec + if err != nil { + return fmt.Errorf("error decoding string '%s': %s", str, err) + } + + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface for XML +// serialization. +func (d Decimal) MarshalText() (text []byte, err error) { + return []byte(d.String()), nil +} + +// GobEncode implements the gob.GobEncoder interface for gob serialization. +func (d Decimal) GobEncode() ([]byte, error) { + return d.MarshalBinary() +} + +// GobDecode implements the gob.GobDecoder interface for gob serialization. +func (d *Decimal) GobDecode(data []byte) error { + return d.UnmarshalBinary(data) +} + +// StringScaled first scales the decimal then calls .String() on it. +// NOTE: buggy, unintuitive, and DEPRECATED! Use StringFixed instead. 
+func (d Decimal) StringScaled(exp int32) string { + return d.rescale(exp).String() +} + +func (d Decimal) string(trimTrailingZeros bool) string { + if d.exp >= 0 { + return d.rescale(0).value.String() + } + + abs := new(big.Int).Abs(d.value) + str := abs.String() + + var intPart, fractionalPart string + + // NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN + // and you are on a 32-bit machine. Won't fix this super-edge case. + dExpInt := int(d.exp) + if len(str) > -dExpInt { + intPart = str[:len(str)+dExpInt] + fractionalPart = str[len(str)+dExpInt:] + } else { + intPart = "0" + + num0s := -dExpInt - len(str) + fractionalPart = strings.Repeat("0", num0s) + str + } + + if trimTrailingZeros { + i := len(fractionalPart) - 1 + for ; i >= 0; i-- { + if fractionalPart[i] != '0' { + break + } + } + fractionalPart = fractionalPart[:i+1] + } + + number := intPart + if len(fractionalPart) > 0 { + number += "." + fractionalPart + } + + if d.value.Sign() < 0 { + return "-" + number + } + + return number +} + +func (d *Decimal) ensureInitialized() { + if d.value == nil { + d.value = new(big.Int) + } +} + +// Min returns the smallest Decimal that was passed in the arguments. +// +// To call this function with an array, you must do: +// +// Min(arr[0], arr[1:]...) +// +// This makes it harder to accidentally call Min with 0 arguments. +func Min(first Decimal, rest ...Decimal) Decimal { + ans := first + for _, item := range rest { + if item.Cmp(ans) < 0 { + ans = item + } + } + return ans +} + +// Max returns the largest Decimal that was passed in the arguments. +// +// To call this function with an array, you must do: +// +// Max(arr[0], arr[1:]...) +// +// This makes it harder to accidentally call Max with 0 arguments. +func Max(first Decimal, rest ...Decimal) Decimal { + ans := first + for _, item := range rest { + if item.Cmp(ans) > 0 { + ans = item + } + } + return ans +} + +// Sum returns the combined total of the provided first and rest Decimals +func Sum(first Decimal, rest ...Decimal) Decimal { + total := first + for _, item := range rest { + total = total.Add(item) + } + + return total +} + +// Avg returns the average value of the provided first and rest Decimals +func Avg(first Decimal, rest ...Decimal) Decimal { + count := New(int64(len(rest)+1), 0) + sum := Sum(first, rest...) + return sum.Div(count) +} + +// RescalePair rescales two decimals to common exponential value (minimal exp of both decimals) +func RescalePair(d1 Decimal, d2 Decimal) (Decimal, Decimal) { + d1.ensureInitialized() + d2.ensureInitialized() + + if d1.exp == d2.exp { + return d1, d2 + } + + baseScale := min(d1.exp, d2.exp) + if baseScale != d1.exp { + return d1.rescale(baseScale), d2 + } + return d1, d2.rescale(baseScale) +} + +func min(x, y int32) int32 { + if x >= y { + return y + } + return x +} + +func unquoteIfQuoted(value interface{}) (string, error) { + var bytes []byte + + switch v := value.(type) { + case string: + bytes = []byte(v) + case []byte: + bytes = v + default: + return "", fmt.Errorf("could not convert value '%+v' to byte array of type '%T'", + value, value) + } + + // If the amount is quoted, strip the quotes + if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' { + bytes = bytes[1 : len(bytes)-1] + } + return string(bytes), nil +} + +// NullDecimal represents a nullable decimal with compatibility for +// scanning null values from the database. 
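+// The zero value is not Valid and is stored and scanned as SQL NULL.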
+type NullDecimal struct { + Decimal Decimal + Valid bool +} + +// Scan implements the sql.Scanner interface for database deserialization. +func (d *NullDecimal) Scan(value interface{}) error { + if value == nil { + d.Valid = false + return nil + } + d.Valid = true + return d.Decimal.Scan(value) +} + +// Value implements the driver.Valuer interface for database serialization. +func (d NullDecimal) Value() (driver.Value, error) { + if !d.Valid { + return nil, nil + } + return d.Decimal.Value() +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error { + if string(decimalBytes) == "null" { + d.Valid = false + return nil + } + d.Valid = true + return d.Decimal.UnmarshalJSON(decimalBytes) +} + +// MarshalJSON implements the json.Marshaler interface. +func (d NullDecimal) MarshalJSON() ([]byte, error) { + if !d.Valid { + return []byte("null"), nil + } + return d.Decimal.MarshalJSON() +} + +// Trig functions + +// Atan returns the arctangent, in radians, of x. +func (d Decimal) Atan() Decimal { + if d.Equal(NewFromFloat(0.0)) { + return d + } + if d.GreaterThan(NewFromFloat(0.0)) { + return d.satan() + } + return d.Neg().satan().Neg() +} + +func (d Decimal) xatan() Decimal { + P0 := NewFromFloat(-8.750608600031904122785e-01) + P1 := NewFromFloat(-1.615753718733365076637e+01) + P2 := NewFromFloat(-7.500855792314704667340e+01) + P3 := NewFromFloat(-1.228866684490136173410e+02) + P4 := NewFromFloat(-6.485021904942025371773e+01) + Q0 := NewFromFloat(2.485846490142306297962e+01) + Q1 := NewFromFloat(1.650270098316988542046e+02) + Q2 := NewFromFloat(4.328810604912902668951e+02) + Q3 := NewFromFloat(4.853903996359136964868e+02) + Q4 := NewFromFloat(1.945506571482613964425e+02) + z := d.Mul(d) + b1 := P0.Mul(z).Add(P1).Mul(z).Add(P2).Mul(z).Add(P3).Mul(z).Add(P4).Mul(z) + b2 := z.Add(Q0).Mul(z).Add(Q1).Mul(z).Add(Q2).Mul(z).Add(Q3).Mul(z).Add(Q4) + z = b1.Div(b2) + z = d.Mul(z).Add(d) + return z +} + +// satan reduces its argument (known to be positive) +// to the range [0, 0.66] and calls xatan. +func (d Decimal) satan() Decimal { + Morebits := NewFromFloat(6.123233995736765886130e-17) // pi/2 = PIO2 + Morebits + Tan3pio8 := NewFromFloat(2.41421356237309504880) // tan(3*pi/8) + pi := NewFromFloat(3.14159265358979323846264338327950288419716939937510582097494459) + + if d.LessThanOrEqual(NewFromFloat(0.66)) { + return d.xatan() + } + if d.GreaterThan(Tan3pio8) { + return pi.Div(NewFromFloat(2.0)).Sub(NewFromFloat(1.0).Div(d).xatan()).Add(Morebits) + } + return pi.Div(NewFromFloat(4.0)).Add((d.Sub(NewFromFloat(1.0)).Div(d.Add(NewFromFloat(1.0)))).xatan()).Add(NewFromFloat(0.5).Mul(Morebits)) +} + +// sin coefficients +var _sin = [...]Decimal{ + NewFromFloat(1.58962301576546568060e-10), // 0x3de5d8fd1fd19ccd + NewFromFloat(-2.50507477628578072866e-8), // 0xbe5ae5e5a9291f5d + NewFromFloat(2.75573136213857245213e-6), // 0x3ec71de3567d48a1 + NewFromFloat(-1.98412698295895385996e-4), // 0xbf2a01a019bfdf03 + NewFromFloat(8.33333333332211858878e-3), // 0x3f8111111110f7d0 + NewFromFloat(-1.66666666666666307295e-1), // 0xbfc5555555555548 +} + +// Sin returns the sine of the radian argument x. 
+func (d Decimal) Sin() Decimal { + PI4A := NewFromFloat(7.85398125648498535156e-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668e-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645e-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + if d.Equal(NewFromFloat(0.0)) { + return d + } + // make argument positive but save the sign + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + sign = true + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + j &= 7 // octant modulo 2Pi radians (360 degrees) + // reflect in x axis + if j > 3 { + sign = !sign + j -= 4 + } + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if j == 1 || j == 2 { + w := zz.Mul(zz).Mul(_cos[0].Mul(zz).Add(_cos[1]).Mul(zz).Add(_cos[2]).Mul(zz).Add(_cos[3]).Mul(zz).Add(_cos[4]).Mul(zz).Add(_cos[5])) + y = NewFromFloat(1.0).Sub(NewFromFloat(0.5).Mul(zz)).Add(w) + } else { + y = z.Add(z.Mul(zz).Mul(_sin[0].Mul(zz).Add(_sin[1]).Mul(zz).Add(_sin[2]).Mul(zz).Add(_sin[3]).Mul(zz).Add(_sin[4]).Mul(zz).Add(_sin[5]))) + } + if sign { + y = y.Neg() + } + return y +} + +// cos coefficients +var _cos = [...]Decimal{ + NewFromFloat(-1.13585365213876817300e-11), // 0xbda8fa49a0861a9b + NewFromFloat(2.08757008419747316778e-9), // 0x3e21ee9d7b4e3f05 + NewFromFloat(-2.75573141792967388112e-7), // 0xbe927e4f7eac4bc6 + NewFromFloat(2.48015872888517045348e-5), // 0x3efa01a019c844f5 + NewFromFloat(-1.38888888888730564116e-3), // 0xbf56c16c16c14f91 + NewFromFloat(4.16666666666665929218e-2), // 0x3fa555555555554b +} + +// Cos returns the cosine of the radian argument x. 
+func (d Decimal) Cos() Decimal { + + PI4A := NewFromFloat(7.85398125648498535156e-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668e-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645e-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + // make argument positive + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + j &= 7 // octant modulo 2Pi radians (360 degrees) + // reflect in x axis + if j > 3 { + sign = !sign + j -= 4 + } + if j > 1 { + sign = !sign + } + + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if j == 1 || j == 2 { + y = z.Add(z.Mul(zz).Mul(_sin[0].Mul(zz).Add(_sin[1]).Mul(zz).Add(_sin[2]).Mul(zz).Add(_sin[3]).Mul(zz).Add(_sin[4]).Mul(zz).Add(_sin[5]))) + } else { + w := zz.Mul(zz).Mul(_cos[0].Mul(zz).Add(_cos[1]).Mul(zz).Add(_cos[2]).Mul(zz).Add(_cos[3]).Mul(zz).Add(_cos[4]).Mul(zz).Add(_cos[5])) + y = NewFromFloat(1.0).Sub(NewFromFloat(0.5).Mul(zz)).Add(w) + } + if sign { + y = y.Neg() + } + return y +} + +var _tanP = [...]Decimal{ + NewFromFloat(-1.30936939181383777646e+4), // 0xc0c992d8d24f3f38 + NewFromFloat(1.15351664838587416140e+6), // 0x413199eca5fc9ddd + NewFromFloat(-1.79565251976484877988e+7), // 0xc1711fead3299176 +} +var _tanQ = [...]Decimal{ + NewFromFloat(1.00000000000000000000e+0), + NewFromFloat(1.36812963470692954678e+4), //0x40cab8a5eeb36572 + NewFromFloat(-1.32089234440210967447e+6), //0xc13427bc582abc96 + NewFromFloat(2.50083801823357915839e+7), //0x4177d98fc2ead8ef + NewFromFloat(-5.38695755929454629881e+7), //0xc189afe03cbe5a31 +} + +// Tan returns the tangent of the radian argument x. 
+func (d Decimal) Tan() Decimal { + + PI4A := NewFromFloat(7.85398125648498535156e-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668e-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645e-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + if d.Equal(NewFromFloat(0.0)) { + return d + } + + // make argument positive but save the sign + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + sign = true + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if zz.GreaterThan(NewFromFloat(1e-14)) { + w := zz.Mul(_tanP[0].Mul(zz).Add(_tanP[1]).Mul(zz).Add(_tanP[2])) + x := zz.Add(_tanQ[1]).Mul(zz).Add(_tanQ[2]).Mul(zz).Add(_tanQ[3]).Mul(zz).Add(_tanQ[4]) + y = z.Add(z.Mul(w.Div(x))) + } else { + y = z + } + if j&2 == 2 { + y = NewFromFloat(-1.0).Div(y) + } + if sign { + y = y.Neg() + } + return y +} diff --git a/vendor/github.com/shopspring/decimal/go.mod b/vendor/github.com/shopspring/decimal/go.mod new file mode 100644 index 0000000..ae1b7aa --- /dev/null +++ b/vendor/github.com/shopspring/decimal/go.mod @@ -0,0 +1,3 @@ +module github.com/shopspring/decimal + +go 1.13 diff --git a/vendor/github.com/shopspring/decimal/rounding.go b/vendor/github.com/shopspring/decimal/rounding.go new file mode 100644 index 0000000..8008f55 --- /dev/null +++ b/vendor/github.com/shopspring/decimal/rounding.go @@ -0,0 +1,119 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Multiprecision decimal numbers. +// For floating-point formatting only; not general purpose. +// Only operations are assign and (binary) left/right shift. +// Can do binary floating point in multiprecision decimal precisely +// because 2 divides 10; cannot do decimal floating point +// in multiprecision binary precisely. + +package decimal + +type floatInfo struct { + mantbits uint + expbits uint + bias int +} + +var float32info = floatInfo{23, 8, -127} +var float64info = floatInfo{52, 11, -1023} + +// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits +// that will let the original floating point value be precisely reconstructed. +func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) { + // If mantissa is zero, the number is zero; stop now. + if mant == 0 { + d.nd = 0 + return + } + + // Compute upper and lower such that any decimal number + // between upper and lower (possibly inclusive) + // will round to the original floating point number. + + // We may see at once that the number is already shortest. + // + // Suppose d is not denormal, so that 2^exp <= d < 10^dp. + // The closest shorter number is at least 10^(dp-nd) away. + // The lower/upper bounds computed below are at distance + // at most 2^(exp-mantbits). + // + // So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits), + // or equivalently log2(10)*(dp-nd) > exp-mantbits. + // It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32). 
+ minexp := flt.bias + 1 // minimum possible exponent + if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) { + // The number is already shortest. + return + } + + // d = mant << (exp - mantbits) + // Next highest floating point number is mant+1 << exp-mantbits. + // Our upper bound is halfway between, mant*2+1 << exp-mantbits-1. + upper := new(decimal) + upper.Assign(mant*2 + 1) + upper.Shift(exp - int(flt.mantbits) - 1) + + // d = mant << (exp - mantbits) + // Next lowest floating point number is mant-1 << exp-mantbits, + // unless mant-1 drops the significant bit and exp is not the minimum exp, + // in which case the next lowest is mant*2-1 << exp-mantbits-1. + // Either way, call it mantlo << explo-mantbits. + // Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1. + var mantlo uint64 + var explo int + if mant > 1< $(FMT_LOG) || true + @[ ! -s "$(FMT_LOG)" ] || (echo "gofmt failed:" | cat - $(FMT_LOG) && false) + +.PHONY: golint +golint: bin/golint + @$(GOBIN)/golint -set_exit_status ./... + +.PHONY: lint +lint: gofmt golint staticcheck + +.PHONY: staticcheck +staticcheck: bin/staticcheck + @$(GOBIN)/staticcheck ./... + +.PHONY: test +test: + go test -race ./... diff --git a/vendor/go.uber.org/ratelimit/README.md b/vendor/go.uber.org/ratelimit/README.md new file mode 100644 index 0000000..a05a2a8 --- /dev/null +++ b/vendor/go.uber.org/ratelimit/README.md @@ -0,0 +1,46 @@ +# Go rate limiter [![GoDoc][doc-img]][doc] [![Coverage Status][cov-img]][cov] ![test][test-img] + +This package provides a Golang implementation of the leaky-bucket rate limit algorithm. +This implementation refills the bucket based on the time elapsed between +requests instead of requiring an interval clock to fill the bucket discretely. + +Create a rate limiter with a maximum number of operations to perform per second. +Call Take() before each operation. Take will sleep until you can continue. 
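As a quick aside on where that spacing comes from (an editorial sketch, not part of the vendored README): the limiter's per-call budget is simply the configured window divided by the rate (see `perRequest := config.per / time.Duration(rate)` in `newAtomicBased`), so `New(100)` with the default one-second window paces calls 10ms apart, which is exactly the gap printed by the example that follows.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Per-call budget = window / rate. With the default one-second window,
	// New(100) spaces Take calls 10ms apart, matching the README output below.
	rate := 100
	perRequest := time.Second / time.Duration(rate)
	fmt.Println(perRequest) // prints "10ms"
}
```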
+ +```go +import ( + "fmt" + "time" + + "go.uber.org/ratelimit" +) + +func main() { + rl := ratelimit.New(100) // per second + + prev := time.Now() + for i := 0; i < 10; i++ { + now := rl.Take() + fmt.Println(i, now.Sub(prev)) + prev = now + } + + // Output: + // 0 0 + // 1 10ms + // 2 10ms + // 3 10ms + // 4 10ms + // 5 10ms + // 6 10ms + // 7 10ms + // 8 10ms + // 9 10ms +} +``` + +[cov-img]: https://codecov.io/gh/uber-go/ratelimit/branch/master/graph/badge.svg?token=zhLeUjjrm2 +[cov]: https://codecov.io/gh/uber-go/ratelimit +[doc-img]: https://pkg.go.dev/badge/go.uber.org/ratelimit +[doc]: https://pkg.go.dev/go.uber.org/ratelimit +[test-img]: https://github.com/uber-go/ratelimit/workflows/test/badge.svg diff --git a/vendor/go.uber.org/ratelimit/go.mod b/vendor/go.uber.org/ratelimit/go.mod new file mode 100644 index 0000000..7487c42 --- /dev/null +++ b/vendor/go.uber.org/ratelimit/go.mod @@ -0,0 +1,9 @@ +module go.uber.org/ratelimit + +go 1.14 + +require ( + github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 + github.com/stretchr/testify v1.6.1 + go.uber.org/atomic v1.7.0 +) diff --git a/vendor/go.uber.org/ratelimit/go.sum b/vendor/go.uber.org/ratelimit/go.sum new file mode 100644 index 0000000..15512ea --- /dev/null +++ b/vendor/go.uber.org/ratelimit/go.sum @@ -0,0 +1,17 @@ +github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 h1:MzBOUgng9orim59UnfUTLRjMpd09C5uEVQ6RPGeCaVI= +github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129/go.mod h1:rFgpPQZYZ8vdbc+48xibu8ALc3yeyd64IhHS+PU6Yyg= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/go.uber.org/ratelimit/limiter_atomic.go b/vendor/go.uber.org/ratelimit/limiter_atomic.go new file mode 100644 index 0000000..745aa4c --- /dev/null +++ b/vendor/go.uber.org/ratelimit/limiter_atomic.go @@ -0,0 +1,110 @@ +// Copyright (c) 2016,2020 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package ratelimit // import "go.uber.org/ratelimit" + +import ( + "time" + + "sync/atomic" + "unsafe" +) + +type state struct { + last time.Time + sleepFor time.Duration +} + +type atomicLimiter struct { + state unsafe.Pointer + //lint:ignore U1000 Padding is unused but it is crucial to maintain performance + // of this rate limiter in case of collocation with other frequently accessed memory. + padding [56]byte // cache line size - state pointer size = 64 - 8; created to avoid false sharing. + + perRequest time.Duration + maxSlack time.Duration + clock Clock +} + +// newAtomicBased returns a new atomic based limiter. +func newAtomicBased(rate int, opts ...Option) *atomicLimiter { + // TODO consider moving config building to the implementation + // independent code. + config := buildConfig(opts) + perRequest := config.per / time.Duration(rate) + l := &atomicLimiter{ + perRequest: perRequest, + maxSlack: -1 * time.Duration(config.slack) * perRequest, + clock: config.clock, + } + + initialState := state{ + last: time.Time{}, + sleepFor: 0, + } + atomic.StorePointer(&l.state, unsafe.Pointer(&initialState)) + return l +} + +// Take blocks to ensure that the time spent between multiple +// Take calls is on average time.Second/rate. +func (t *atomicLimiter) Take() time.Time { + var ( + newState state + taken bool + interval time.Duration + ) + for !taken { + now := t.clock.Now() + + previousStatePointer := atomic.LoadPointer(&t.state) + oldState := (*state)(previousStatePointer) + + newState = state{ + last: now, + sleepFor: oldState.sleepFor, + } + + // If this is our first request, then we allow it. + if oldState.last.IsZero() { + taken = atomic.CompareAndSwapPointer(&t.state, previousStatePointer, unsafe.Pointer(&newState)) + continue + } + + // sleepFor calculates how much time we should sleep based on + // the perRequest budget and how long the last request took. + // Since the request may take longer than the budget, this number + // can get negative, and is summed across requests. + newState.sleepFor += t.perRequest - now.Sub(oldState.last) + // We shouldn't allow sleepFor to get too negative, since it would mean that + // a service that slowed down a lot for a short period of time would get + // a much higher RPS following that. 
+ if newState.sleepFor < t.maxSlack { + newState.sleepFor = t.maxSlack + } + if newState.sleepFor > 0 { + newState.last = newState.last.Add(newState.sleepFor) + interval, newState.sleepFor = newState.sleepFor, 0 + } + taken = atomic.CompareAndSwapPointer(&t.state, previousStatePointer, unsafe.Pointer(&newState)) + } + t.clock.Sleep(interval) + return newState.last +} diff --git a/vendor/go.uber.org/ratelimit/limiter_mutexbased.go b/vendor/go.uber.org/ratelimit/limiter_mutexbased.go new file mode 100644 index 0000000..1408f1c --- /dev/null +++ b/vendor/go.uber.org/ratelimit/limiter_mutexbased.go @@ -0,0 +1,88 @@ +// Copyright (c) 2016,2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package ratelimit // import "go.uber.org/ratelimit" + +import ( + "sync" + "time" +) + +type mutexLimiter struct { + sync.Mutex + last time.Time + sleepFor time.Duration + perRequest time.Duration + maxSlack time.Duration + clock Clock +} + +// newMutexBased returns a new atomic based limiter. +func newMutexBased(rate int, opts ...Option) *mutexLimiter { + // TODO consider moving config building to the implementation + // independent code. + config := buildConfig(opts) + perRequest := config.per / time.Duration(rate) + l := &mutexLimiter{ + perRequest: perRequest, + maxSlack: -1 * time.Duration(config.slack) * perRequest, + clock: config.clock, + } + return l +} + +// Take blocks to ensure that the time spent between multiple +// Take calls is on average time.Second/rate. +func (t *mutexLimiter) Take() time.Time { + t.Lock() + defer t.Unlock() + + now := t.clock.Now() + + // If this is our first request, then we allow it. + if t.last.IsZero() { + t.last = now + return t.last + } + + // sleepFor calculates how much time we should sleep based on + // the perRequest budget and how long the last request took. + // Since the request may take longer than the budget, this number + // can get negative, and is summed across requests. + t.sleepFor += t.perRequest - now.Sub(t.last) + + // We shouldn't allow sleepFor to get too negative, since it would mean that + // a service that slowed down a lot for a short period of time would get + // a much higher RPS following that. + if t.sleepFor < t.maxSlack { + t.sleepFor = t.maxSlack + } + + // If sleepFor is positive, then we should sleep now. 
+ if t.sleepFor > 0 { + t.clock.Sleep(t.sleepFor) + t.last = now.Add(t.sleepFor) + t.sleepFor = 0 + } else { + t.last = now + } + + return t.last +} diff --git a/vendor/go.uber.org/ratelimit/ratelimit.go b/vendor/go.uber.org/ratelimit/ratelimit.go new file mode 100644 index 0000000..b5b16e5 --- /dev/null +++ b/vendor/go.uber.org/ratelimit/ratelimit.go @@ -0,0 +1,135 @@ +// Copyright (c) 2016,2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package ratelimit // import "go.uber.org/ratelimit" + +import ( + "time" + + "github.com/andres-erbsen/clock" +) + +// Note: This file is inspired by: +// https://github.com/prashantv/go-bench/blob/master/ratelimit + +// Limiter is used to rate-limit some process, possibly across goroutines. +// The process is expected to call Take() before every iteration, which +// may block to throttle the goroutine. +type Limiter interface { + // Take should block to make sure that the RPS is met. + Take() time.Time +} + +// Clock is the minimum necessary interface to instantiate a rate limiter with +// a clock or mock clock, compatible with clocks created using +// github.com/andres-erbsen/clock. +type Clock interface { + Now() time.Time + Sleep(time.Duration) +} + +// config configures a limiter. +type config struct { + clock Clock + slack int + per time.Duration +} + +// New returns a Limiter that will limit to the given RPS. +func New(rate int, opts ...Option) Limiter { + return newAtomicBased(rate, opts...) +} + +// buildConfig combines defaults with options. +func buildConfig(opts []Option) config { + c := config{ + clock: clock.New(), + slack: 10, + per: time.Second, + } + + for _, opt := range opts { + opt.apply(&c) + } + return c +} + +// Option configures a Limiter. +type Option interface { + apply(*config) +} + +type clockOption struct { + clock Clock +} + +func (o clockOption) apply(c *config) { + c.clock = o.clock +} + +// WithClock returns an option for ratelimit.New that provides an alternate +// Clock implementation, typically a mock Clock for testing. +func WithClock(clock Clock) Option { + return clockOption{clock: clock} +} + +type slackOption int + +func (o slackOption) apply(c *config) { + c.slack = int(o) +} + +// WithoutSlack configures the limiter to be strict and not to accumulate +// previously "unspent" requests for future bursts of traffic. +var WithoutSlack Option = slackOption(0) + +// WithSlack configures custom slack. 
+// Slack allows the limiter to accumulate "unspent" requests +// for future bursts of traffic. +func WithSlack(slack int) Option { + return slackOption(slack) +} + +type perOption time.Duration + +func (p perOption) apply(c *config) { + c.per = time.Duration(p) +} + +// Per allows configuring limits for different time windows. +// +// The default window is one second, so New(100) produces a one hundred per +// second (100 Hz) rate limiter. +// +// New(2, Per(60*time.Second)) creates a 2 per minute rate limiter. +func Per(per time.Duration) Option { + return perOption(per) +} + +type unlimited struct{} + +// NewUnlimited returns a RateLimiter that is not limited. +func NewUnlimited() Limiter { + return unlimited{} +} + +func (unlimited) Take() time.Time { + return time.Now() +} diff --git a/vendor/modules.txt b/vendor/modules.txt new file mode 100644 index 0000000..be09ce5 --- /dev/null +++ b/vendor/modules.txt @@ -0,0 +1,26 @@ +# github.com/akamensky/argparse v1.2.2 +## explicit +github.com/akamensky/argparse +# github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 +github.com/andres-erbsen/clock +# github.com/avast/retry-go v3.0.0+incompatible +## explicit +github.com/avast/retry-go +# github.com/gorilla/websocket v1.4.2 +github.com/gorilla/websocket +# github.com/grishinsana/goftx v1.2.0 => ../goftx +## explicit +github.com/grishinsana/goftx +github.com/grishinsana/goftx/models +# github.com/pkg/errors v0.9.1 +github.com/pkg/errors +# github.com/robfig/cron/v3 v3.0.0 +## explicit +github.com/robfig/cron/v3 +# github.com/shopspring/decimal v1.2.0 +## explicit +github.com/shopspring/decimal +# go.uber.org/ratelimit v0.2.0 +## explicit +go.uber.org/ratelimit +# github.com/grishinsana/goftx => ../goftx
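As a closing illustration (a sketch only, not code from this repository or from the vendored package), the `Option` values documented in `ratelimit.go` above compose directly with `New`; the rate and window used here are arbitrary example values mirroring the `Per` doc comment.

```go
package main

import (
	"fmt"
	"time"

	"go.uber.org/ratelimit"
)

func main() {
	// Two permits per minute, with no accumulated slack between bursts.
	rl := ratelimit.New(2, ratelimit.Per(time.Minute), ratelimit.WithoutSlack)

	prev := time.Now()
	for i := 0; i < 3; i++ {
		now := rl.Take() // blocks until the next permit is due
		fmt.Println(i, now.Sub(prev))
		prev = now
	}

	// NewUnlimited returns a Limiter whose Take never blocks,
	// handy for switching throttling off, e.g. in tests.
	_ = ratelimit.NewUnlimited()
}
```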