Skip to content

Commit

Permalink
feat: when reloading, remain those servers whose upstream pull conf w…
Browse files Browse the repository at this point in the history
…ith net error (#16)

* feat: remain those servers whose upstream pull conf with net error when reloading

* fix: potential nil pointer error

* refine code

* use brute-force-match instead of hash

* test: test and fix

* clean code
  • Loading branch information
mzz2017 authored Feb 18, 2021
1 parent 2c71fd9 commit 48ffcbc
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 35 deletions.
65 changes: 51 additions & 14 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package config

import (
"encoding/json"
"errors"
"flag"
"fmt"
"github.com/Qv2ray/mmp-go/cipher"
"github.com/Qv2ray/mmp-go/infra/lru"
"log"
"net"
"os"
"sync"
"time"
Expand All @@ -17,17 +19,46 @@ type Config struct {
Groups []Group `json:"groups"`
}
type Server struct {
Target string `json:"target"`
Method string `json:"method"`
Password string `json:"password"`
MasterKey []byte `json:"-"`
Target string `json:"target"`
Method string `json:"method"`
Password string `json:"password"`
MasterKey []byte `json:"-"`
UpstreamConf *UpstreamConf `json:"-"`
}
type Group struct {
Port int `json:"port"`
Servers []Server `json:"servers"`
Upstreams []map[string]string `json:"upstreams"`
LRUSize int `json:"lruSize"`
UserContextPool *UserContextPool `json:"-"`
Port int `json:"port"`
Servers []Server `json:"servers"`
Upstreams []UpstreamConf `json:"upstreams"`
UserContextPool *UserContextPool `json:"-"`
}
type UpstreamConf map[string]string

const (
PullingErrorKey = "__pulling_error__"
PullingErrorNetError = "net_error"
)

func (uc UpstreamConf) InitPullingError() {
if _, ok := uc[PullingErrorKey]; !ok {
uc[PullingErrorKey] = ""
}
}

func (uc UpstreamConf) Equal(that UpstreamConf) bool {
uc.InitPullingError()
that.InitPullingError()
if len(uc) != len(that) {
return false
}
for k, v := range uc {
if k == PullingErrorKey {
continue
}
if vv, ok := that[k]; !ok || vv != v {
return false
}
}
return true
}

const (
Expand Down Expand Up @@ -85,28 +116,34 @@ func parseUpstreams(config *Config) (err error) {
logged := false
for i := range config.Groups {
g := &config.Groups[i]
for j, u := range g.Upstreams {
for j, upstreamConf := range g.Upstreams {
var upstream Upstream
switch u["type"] {
switch upstreamConf["type"] {
case "outline":
var outline Outline
err = Map2upstream(u, &outline)
err = Map2Upstream(upstreamConf, &outline)
if err != nil {
return
}
upstream = outline
default:
return fmt.Errorf("unknown upstream type: %v", u["type"])
return fmt.Errorf("unknown upstream type: %v", upstreamConf["type"])
}
if !logged {
log.Println("pulling configures from upstreams...")
logged = true
}
servers, err := upstream.GetServers()
if err != nil {
log.Printf("[warning] Failed to retrieve configure from groups[%d].upstreams[%d]: %v\n", i, j, err)
if netError := new(net.Error); errors.As(err, netError) {
upstreamConf[PullingErrorKey] = PullingErrorNetError
}
log.Printf("[warning] Failed to retrieve configure from groups[%d].upstreams[%d]: %v: %v\n", i, j, err, upstreamConf[PullingErrorKey])
continue
}
for i := range servers {
servers[i].UpstreamConf = &upstreamConf
}
g.Servers = append(g.Servers, servers...)
}
}
Expand Down
16 changes: 8 additions & 8 deletions config/outline.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (outline Outline) getConfig() ([]byte, error) {
// concatenate errors
err = errs[0]
for i := 1; i < len(errs); i++ {
err = fmt.Errorf("%v; %v", err, errs[i])
err = fmt.Errorf("%w; %s", err, errs[i].Error())
}
return nil, err
}
Expand All @@ -83,7 +83,7 @@ func (outline Outline) getConfigFromLink() ([]byte, error) {
}
resp, err := client.Get(outline.Link)
if err != nil {
return nil, fmt.Errorf("getConfigFromLink failed: %v", err)
return nil, fmt.Errorf("getConfigFromLink failed: %w", err)
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
Expand Down Expand Up @@ -113,7 +113,7 @@ func (outline Outline) getConfigFromApi() ([]byte, error) {
outline.ApiUrl = strings.TrimSuffix(outline.ApiUrl, "/")
resp, err := client.Get(fmt.Sprintf("%v/access-keys", outline.ApiUrl))
if err != nil {
return nil, fmt.Errorf("getConfigFromLink failed: %v", err)
return nil, fmt.Errorf("getConfigFromApi failed: %w", err)
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
Expand All @@ -130,7 +130,7 @@ func (outline Outline) getConfigFromSSH() ([]byte, error) {
if outline.SSHPrivateKey != "" {
signer, err := ssh.ParsePrivateKey([]byte(outline.SSHPrivateKey))
if err != nil {
return nil, fmt.Errorf("parse privateKey error: %v", err)
return nil, fmt.Errorf("parse privateKey error: %w", err)
}
authMethods = append(authMethods, ssh.PublicKeys(signer))
}
Expand All @@ -151,18 +151,18 @@ func (outline Outline) getConfigFromSSH() ([]byte, error) {
}
client, err := ssh.Dial("tcp", net.JoinHostPort(outline.Server, port), conf)
if err != nil {
return nil, fmt.Errorf("failed to dial: %v", err)
return nil, fmt.Errorf("failed to dial: %w", err)
}
defer client.Close()

session, err := client.NewSession()
if err != nil {
return nil, fmt.Errorf("failed to create session: %v", err)
return nil, fmt.Errorf("failed to create session: %w", err)
}
defer session.Close()
out, err := session.CombinedOutput("cat /opt/outline/persisted-state/shadowbox_config.json")
if err != nil {
err = fmt.Errorf("%v: %v", string(bytes.TrimSpace(out)), err)
err = fmt.Errorf("%v: %w", string(bytes.TrimSpace(out)), err)
return nil, err
}
return out, nil
Expand All @@ -171,7 +171,7 @@ func (outline Outline) getConfigFromSSH() ([]byte, error) {
func (outline Outline) GetServers() (servers []Server, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("outline.GetGroups: %v", err)
err = fmt.Errorf("outline.GetGroups: %w", err)
}
}()
b, err := outline.getConfig()
Expand Down
2 changes: 1 addition & 1 deletion config/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type Upstream interface {

var InvalidUpstreamErr = fmt.Errorf("invalid upstream")

func Map2upstream(m map[string]string, upstream interface{}) error {
func Map2Upstream(m map[string]string, upstream interface{}) error {
v := reflect.ValueOf(upstream)
if !v.IsValid() {
return fmt.Errorf("upstream should not be nil")
Expand Down
15 changes: 6 additions & 9 deletions dispatcher/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,8 @@ func (d *TCP) Listen() (err error) {
for {
conn, err := d.l.Accept()
if err != nil {
switch err := err.(type) {
case *net.OpError:
if errors.Is(err.Unwrap(), net.ErrClosed) {
return nil
}
if errors.Is(err, net.ErrClosed) {
return nil
}
log.Printf("[error] ReadFrom: %v", err)
continue
Expand Down Expand Up @@ -89,7 +86,7 @@ func (d *TCP) handleConn(conn net.Conn) error {
defer pool.Put(buf)
n, err := io.ReadFull(conn, data)
if err != nil {
return fmt.Errorf("[tcp] handleConn readfull error: %v", err)
return fmt.Errorf("[tcp] handleConn readfull error: %w", err)
}

// get user's context (preference)
Expand All @@ -110,13 +107,13 @@ func (d *TCP) handleConn(conn net.Conn) error {
// dial and relay
rc, err := net.Dial("tcp", server.Target)
if err != nil {
return fmt.Errorf("[tcp] handleConn dial error: %v", err)
return fmt.Errorf("[tcp] handleConn dial error: %w", err)
}

_ = rc.SetDeadline(time.Now().Add(DefaultTimeout))
_, err = rc.Write(data[:n])
if err != nil {
return fmt.Errorf("[tcp] handleConn write error: %v", err)
return fmt.Errorf("[tcp] handleConn write error: %w", err)
}

log.Printf("[tcp] %s <-> %s <-> %s", conn.RemoteAddr(), conn.LocalAddr(), rc.RemoteAddr())
Expand All @@ -125,7 +122,7 @@ func (d *TCP) handleConn(conn net.Conn) error {
if err, ok := err.(net.Error); ok && err.Timeout() {
return nil // ignore i/o timeout
}
return fmt.Errorf("[tcp] handleConn relay error: %v", err)
return fmt.Errorf("[tcp] handleConn relay error: %w", err)
}
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions dispatcher/udp/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ func (d *UDP) handleConn(laddr net.Addr, data []byte, n int) (err error) {
if err == AuthFailedErr {
return nil
}
return fmt.Errorf("[udp] handleConn dial target error: %v", err)
return fmt.Errorf("[udp] handleConn dial target error: %w", err)
}

// send packet
if _, err = rc.Write(data[:n]); err != nil {
return fmt.Errorf("[udp] handleConn write error: %v", err)
return fmt.Errorf("[udp] handleConn write error: %w", err)
}
return nil
}
Expand Down Expand Up @@ -136,7 +136,7 @@ func (d *UDP) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDPConn, e
d.nm.Lock()
d.nm.Remove(socketIdent) // close channel to inform that establishment ends
d.nm.Unlock()
return nil, fmt.Errorf("GetOrBuildUCPConn dial error: %v", err)
return nil, fmt.Errorf("GetOrBuildUCPConn dial error: %w", err)
}
rc = rconn.(*net.UDPConn)
d.nm.Lock()
Expand Down
35 changes: 35 additions & 0 deletions reload.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,46 @@ func ReloadConfig() {

// rebuild config
confPath := config.GetConfig().ConfPath
oldConf := config.GetConfig()
newConf, err := config.BuildConfig(confPath)
if err != nil {
log.Printf("failed to reload configuration: %v", err)
return
}
// check if there is any net error when pulling the upstream configurations
for i := range newConf.Groups {
newGroup := &newConf.Groups[i]
for j := range newGroup.Upstreams {
newUpstream := newGroup.Upstreams[j]
if newUpstream[config.PullingErrorKey] != config.PullingErrorNetError {
continue
}
// net error, remain those servers

// find the group in the oldConf
var oldGroup *config.Group
for k := range oldConf.Groups {
// they should have the same port
if oldConf.Groups[k].Port != newGroup.Port {
continue
}
oldGroup = &oldConf.Groups[k]
break
}
if oldGroup == nil {
// cannot find the corresponding old group
continue
}
// check if upstreamConf can match
for k := range oldGroup.Servers {
oldServer := oldGroup.Servers[k]
if oldServer.UpstreamConf != nil && newUpstream.Equal(*oldServer.UpstreamConf) {
// remain the server
newGroup.Servers = append(newGroup.Servers, oldServer)
}
}
}
}
config.SetConfig(newConf)
c := newConf

Expand Down

0 comments on commit 48ffcbc

Please sign in to comment.