diff --git a/config/outline.go b/config/outline.go index 5639b10..b314924 100644 --- a/config/outline.go +++ b/config/outline.go @@ -6,7 +6,6 @@ import ( "fmt" "golang.org/x/crypto/ssh" "io/ioutil" - "log" "net" "net/http" "strconv" @@ -23,7 +22,46 @@ type Outline struct { SSHPassword string `json:"sshPassword"` } +func (outline Outline) getConfig() ([]byte, error) { + if outline.Server == "" { + return nil, fmt.Errorf("server field cannot be empty") + } + tryList := []func() ([]byte, error){ + outline.getConfigFromLink, + outline.getConfigFromSSH, + } + var ( + err error + errs []error + b []byte + ) + for _, f := range tryList { + b, err = f() + if err != nil { + errs = append(errs, err) + continue + } + if b != nil { + break + } + } + if b != nil { + return b, nil + } + if len(errs) > 0 { + err = errs[0] + for i := 1; i < len(errs); i++ { + err = fmt.Errorf("%v; %v", err, errs[i]) + } + return nil, err + } + return nil, InvalidUpstreamErr +} + func (outline Outline) getConfigFromLink() ([]byte, error) { + if outline.Link == "" { + return nil, nil + } client := http.Client{ Timeout: 10 * time.Second, } @@ -36,6 +74,9 @@ func (outline Outline) getConfigFromLink() ([]byte, error) { } func (outline Outline) getConfigFromSSH() ([]byte, error) { + if outline.SSHUsername == "" || (outline.SSHPrivateKey == "" && outline.SSHPassword == "") { + return nil, nil + } var ( conf *ssh.ClientConfig authMethods []ssh.AuthMethod @@ -86,14 +127,7 @@ func (outline Outline) GetServers() (servers []Server, err error) { err = fmt.Errorf("outline.GetGroups: %v", err) } }() - var b []byte - if outline.Link != "" { - b, err = outline.getConfigFromLink() - } - if err != nil { - log.Printf("[warning] %v\n", err) - b, err = outline.getConfigFromSSH() - } + b, err := outline.getConfig() if err != nil { return } diff --git a/config/upstream.go b/config/upstream.go index 70658be..2915e7b 100644 --- a/config/upstream.go +++ b/config/upstream.go @@ -9,6 +9,8 @@ type Upstream interface { GetServers() (servers []Server, err error) } +var InvalidUpstreamErr = fmt.Errorf("invalid upstream") + func Map2upstream(m map[string]string, upstream interface{}) error { v := reflect.ValueOf(upstream) if !v.IsValid() {