From 670553dd0c377c8b9c6ceb3604b413ccb806478d Mon Sep 17 00:00:00 2001 From: mzz Date: Fri, 1 Jan 2021 18:21:41 +0800 Subject: [PATCH] feat: reload strategy (#9) --- .gitignore | 2 +- config/config.go | 54 ++++++++++++++++++++++++++------------- dispatcher/dispatcher.go | 1 + dispatcher/tcp/tcp.go | 34 +++++++++++++++++++------ dispatcher/udp/udp.go | 38 +++++++++++++++++++-------- go.mod | 1 + go.sum | 2 ++ main.go | 38 +++++++++++++++++++-------- reload.go | 55 ++++++++++++++++++++++++++++++++++++++++ signal_other.go | 17 +++++++++++++ signal_windows.go | 8 ++++++ systemd/mmp-go.service | 5 ++-- 12 files changed, 206 insertions(+), 49 deletions(-) create mode 100644 reload.go create mode 100644 signal_other.go create mode 100644 signal_windows.go diff --git a/.gitignore b/.gitignore index 2254e3e..7812799 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ *.so *.dylib /shadomplexer-go -/mmp-go +/mmp-go* # Test binary, built with `go test -c` *.test diff --git a/config/config.go b/config/config.go index cb4ca8d..d62025f 100644 --- a/config/config.go +++ b/config/config.go @@ -13,6 +13,7 @@ import ( ) type Config struct { + ConfPath string `json:"-"` Groups []Group `json:"groups"` ClientCapacity int `json:"clientCapacity"` } @@ -30,15 +31,17 @@ type Group struct { UserContextPool *UserContextPool `json:"-"` } -var config *Config -var once sync.Once -var Version = "debug" - const ( // around 30kB per client if there are 300 servers to forward DefaultClientCapacity = 100 ) +var ( + config *Config + once sync.Once + Version = "debug" +) + func (g *Group) BuildMasterKeys() { servers := g.Servers for j := range servers { @@ -135,10 +138,37 @@ func build(config *Config) { g.BuildMasterKeys() } } + +func BuildConfig(confPath string) (conf *Config, err error) { + conf = new(Config) + conf.ConfPath = confPath + b, err := ioutil.ReadFile(confPath) + if err != nil { + return nil, err + } + if err = json.Unmarshal(b, conf); err != nil { + return nil, err + } + if err = parseUpstreams(conf); err != nil { + return nil, err + } + if err = check(conf); err != nil { + return nil, err + } + build(conf) + return +} + +func SetConfig(conf *Config) { + config = conf +} + func GetConfig() *Config { once.Do(func() { + var err error + version := flag.Bool("v", false, "version") - filename := flag.String("conf", "example.json", "config file path") + confPath := flag.String("conf", "example.json", "config file path") flag.Parse() if *version { @@ -146,21 +176,9 @@ func GetConfig() *Config { os.Exit(0) } - config = new(Config) - b, err := ioutil.ReadFile(*filename) - if err != nil { - log.Fatalln(err) - } - if err = json.Unmarshal(b, config); err != nil { - log.Fatalln(err) - } - if err = parseUpstreams(config); err != nil { - log.Fatalln(err) - } - if err = check(config); err != nil { + if config, err = BuildConfig(*confPath); err != nil { log.Fatalln(err) } - build(config) }) return config } diff --git a/dispatcher/dispatcher.go b/dispatcher/dispatcher.go index 354cd90..d728b07 100644 --- a/dispatcher/dispatcher.go +++ b/dispatcher/dispatcher.go @@ -9,6 +9,7 @@ type Dispatcher interface { Listen() (err error) // buf is a buffer to store decrypted text Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) + UpdateGroup(group *config.Group) Close() (err error) } diff --git a/dispatcher/tcp/tcp.go b/dispatcher/tcp/tcp.go index f3c7d9f..67284d6 100644 --- a/dispatcher/tcp/tcp.go +++ b/dispatcher/tcp/tcp.go @@ -6,9 +6,11 @@ import ( "github.com/Qv2ray/mmp-go/common/pool" "github.com/Qv2ray/mmp-go/config" "github.com/Qv2ray/mmp-go/dispatcher" + "github.com/pkg/errors" "io" "log" "net" + "sync" "time" ) @@ -22,16 +24,23 @@ func init() { dispatcher.Register("tcp", New) } -type Dispatcher struct { - group *config.Group - l net.Listener +type TCP struct { + gMutex sync.RWMutex + group *config.Group + l net.Listener } func New(g *config.Group) (d dispatcher.Dispatcher) { - return &Dispatcher{group: g} + return &TCP{group: g} } -func (d *Dispatcher) Listen() (err error) { +func (d *TCP) UpdateGroup(group *config.Group) { + d.gMutex.Lock() + defer d.gMutex.Unlock() + d.group = group +} + +func (d *TCP) Listen() (err error) { d.l, err = net.Listen("tcp", fmt.Sprintf(":%d", d.group.Port)) if err != nil { return @@ -41,6 +50,12 @@ func (d *Dispatcher) Listen() (err error) { for { conn, err := d.l.Accept() if err != nil { + switch err := err.(type) { + case *net.OpError: + if errors.Is(err.Unwrap(), net.ErrClosed) { + return nil + } + } log.Printf("[error] ReadFrom: %v", err) continue } @@ -53,11 +68,12 @@ func (d *Dispatcher) Listen() (err error) { } } -func (d *Dispatcher) Close() (err error) { +func (d *TCP) Close() (err error) { + log.Printf("[tcp] closed :%v\n", d.group.Port) return d.l.Close() } -func (d *Dispatcher) handleConn(conn net.Conn) error { +func (d *TCP) handleConn(conn net.Conn) error { /* https://github.com/shadowsocks/shadowsocks-org/blob/master/whitepaper/whitepaper.md */ @@ -77,7 +93,9 @@ func (d *Dispatcher) handleConn(conn net.Conn) error { } // get user's context (preference) + d.gMutex.RLock() // avoid insert old servers to the new userContextPool userContext = d.group.UserContextPool.GetOrInsert(conn.RemoteAddr(), d.group.Servers) + d.gMutex.RUnlock() // auth every server server, _ = d.Auth(buf, data, userContext) @@ -130,7 +148,7 @@ func relay(lc, rc net.Conn) error { return <-ch } -func (d *Dispatcher) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) { +func (d *TCP) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) { if len(data) < BasicLen { return nil, nil } diff --git a/dispatcher/udp/udp.go b/dispatcher/udp/udp.go index 02a379a..52fb2bd 100644 --- a/dispatcher/udp/udp.go +++ b/dispatcher/udp/udp.go @@ -6,9 +6,11 @@ import ( "github.com/Qv2ray/mmp-go/common/pool" "github.com/Qv2ray/mmp-go/config" "github.com/Qv2ray/mmp-go/dispatcher" + "github.com/pkg/errors" "golang.org/x/net/dns/dnsmessage" "log" "net" + "sync" "time" ) @@ -25,17 +27,24 @@ func init() { dispatcher.Register("udp", New) } -type Dispatcher struct { - group *config.Group - c *net.UDPConn - nm *UDPConnMapping +type UDP struct { + gMutex sync.RWMutex + group *config.Group + c *net.UDPConn + nm *UDPConnMapping } func New(g *config.Group) (d dispatcher.Dispatcher) { - return &Dispatcher{group: g, nm: NewUDPConnMapping()} + return &UDP{group: g, nm: NewUDPConnMapping()} } -func (d *Dispatcher) Listen() (err error) { +func (d *UDP) UpdateGroup(group *config.Group) { + d.gMutex.Lock() + defer d.gMutex.Unlock() + d.group = group +} + +func (d *UDP) Listen() (err error) { d.c, err = net.ListenUDP("udp", &net.UDPAddr{Port: d.group.Port}) if err != nil { return @@ -46,6 +55,12 @@ func (d *Dispatcher) Listen() (err error) { for { n, laddr, err := d.c.ReadFrom(buf[:]) if err != nil { + switch err := err.(type) { + case *net.OpError: + if errors.Is(err.Unwrap(), net.ErrClosed) { + return nil + } + } log.Printf("[error] ReadFrom: %v", err) continue } @@ -78,7 +93,7 @@ func addrLen(packet []byte) int { return l } -func (d *Dispatcher) handleConn(laddr net.Addr, data []byte, n int) (err error) { +func (d *UDP) handleConn(laddr net.Addr, data []byte, n int) (err error) { // get conn or dial and relay rc, err := d.GetOrBuildUCPConn(laddr, data[:n]) if err != nil { @@ -111,7 +126,7 @@ func selectTimeout(packet []byte) time.Duration { } // connTimeout is the timeout of connection to build if not exists -func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDPConn, err error) { +func (d *UDP) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDPConn, err error) { socketIdent := laddr.String() d.nm.Lock() var conn *UDPConn @@ -122,7 +137,9 @@ func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDP d.nm.Unlock() // get user's context (preference) + d.gMutex.RLock() // avoid insert old servers to the new userContextPool userContext := d.group.UserContextPool.GetOrInsert(laddr, d.group.Servers) + d.gMutex.RUnlock() buf := pool.Get(len(data)) defer pool.Put(buf) @@ -186,7 +203,7 @@ func relay(dst *net.UDPConn, laddr net.Addr, src *net.UDPConn, timeout time.Dura } } -func (d *Dispatcher) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) { +func (d *UDP) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) { if len(data) < BasicLen { return nil, nil } @@ -195,7 +212,8 @@ func (d *Dispatcher) Auth(buf []byte, data []byte, userContext *config.UserConte }) } -func (d *Dispatcher) Close() (err error) { +func (d *UDP) Close() (err error) { + log.Printf("[udp] closed :%v\n", d.group.Port) return d.c.Close() } diff --git a/go.mod b/go.mod index e3e350d..4d317f1 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/Qv2ray/mmp-go go 1.15 require ( + github.com/pkg/errors v0.9.1 github.com/studentmain/smaead v0.0.0-20201230222852-75aa2464875d golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 diff --git a/go.sum b/go.sum index be187c6..e4c161a 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/studentmain/smaead v0.0.0-20201230222852-75aa2464875d h1:8cJZoaJdg0EFv+7ryIWRTnviorsmmHT5H06jx6621CI= github.com/studentmain/smaead v0.0.0-20201230222852-75aa2464875d/go.mod h1:1jZXK8G4HFsNzvB6tVf8eQ4lKPb9ccsXoZxDah0Yjb0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/main.go b/main.go index 06ca3a0..36dbd45 100644 --- a/main.go +++ b/main.go @@ -10,27 +10,45 @@ import ( ) var protocols = [...]string{"tcp", "udp"} +var wg sync.WaitGroup func main() { + // handle reload + go signalHandler() + + mMutex.Lock() conf := config.GetConfig() - var wg sync.WaitGroup for i := range conf.Groups { wg.Add(1) - go func(group *config.Group) { - err := listen(group, protocols[:]) - if err != nil { - log.Fatalln(err) - } - wg.Done() - }(&conf.Groups[i]) + go listen(&conf.Groups[i]) } + mMutex.Unlock() wg.Wait() } -func listen(group *config.Group, protocols []string) error { +func listen(group *config.Group) { + mMutex.Lock() + if _, ok := mPortDispatcher[group.Port]; !ok { + mPortDispatcher[group.Port] = new([2]dispatcher.Dispatcher) + } + mMutex.Unlock() + err := listenWithProtocols(group, protocols[:]) + if err != nil { + mMutex.Lock() + // error but listening + if _, ok := mPortDispatcher[group.Port]; ok { + log.Fatalln(err) + } + mMutex.Unlock() + } + wg.Done() +} + +func listenWithProtocols(group *config.Group, protocols []string) error { ch := make(chan error, len(protocols)) - for _, protocol := range protocols { + for i, protocol := range protocols { d, _ := dispatcher.New(protocol, group) + (*mPortDispatcher[group.Port])[i] = d go func() { var err error err = d.Listen() diff --git a/reload.go b/reload.go new file mode 100644 index 0000000..74628fe --- /dev/null +++ b/reload.go @@ -0,0 +1,55 @@ +package main + +import ( + "github.com/Qv2ray/mmp-go/config" + "github.com/Qv2ray/mmp-go/dispatcher" + "log" + "sync" +) + +var mMutex sync.Mutex +var mPortDispatcher = make(map[int]*[len(protocols)]dispatcher.Dispatcher) + +func ReloadConfig() { + log.Println("Reloading configuration") + mMutex.Lock() + defer mMutex.Unlock() + + // rebuild config + confPath := config.GetConfig().ConfPath + newConf, err := config.BuildConfig(confPath) + if err != nil { + log.Printf("failed to reload configuration: %v", err) + return + } + config.SetConfig(newConf) + c := newConf + + // update dispatchers + newConfPortSet := make(map[int]struct{}) + for i := range c.Groups { + newConfPortSet[c.Groups[i].Port] = struct{}{} + + if t, ok := mPortDispatcher[c.Groups[i].Port]; ok { + // update the existing dispatcher + for j := range protocols { + t[j].UpdateGroup(&c.Groups[i]) + } + } else { + // add a new port dispatcher + wg.Add(1) + go listen(&c.Groups[i]) + } + } + // close all removed port dispatcher + for port := range mPortDispatcher { + if _, ok := newConfPortSet[port]; !ok { + t := mPortDispatcher[port] + delete(mPortDispatcher, port) + for j := range protocols { + _ = (*t)[j].Close() + } + } + } + log.Println("Reloaded configuration") +} diff --git a/signal_other.go b/signal_other.go new file mode 100644 index 0000000..41e2363 --- /dev/null +++ b/signal_other.go @@ -0,0 +1,17 @@ +// +build !windows + +package main + +import ( + "os" + "os/signal" + "syscall" +) + +func signalHandler() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGUSR1) + for range ch { + ReloadConfig() + } +} diff --git a/signal_windows.go b/signal_windows.go new file mode 100644 index 0000000..e7eda44 --- /dev/null +++ b/signal_windows.go @@ -0,0 +1,8 @@ +package main + +import "log" + +// not support windows +func signalHandler() { + log.Println(`Signal-triggered configuration reloading is not supported on Windows`) +} diff --git a/systemd/mmp-go.service b/systemd/mmp-go.service index 4571a88..f19772a 100644 --- a/systemd/mmp-go.service +++ b/systemd/mmp-go.service @@ -9,10 +9,11 @@ Type=simple User=nobody Restart=always LimitNOFILE=102400 -CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_BIND_SERVICE CAP_NET_RAW -AmbientCapabilities=CAP_NET_ADMIN CAP_NET_BIND_SERVICE CAP_NET_RAW +CapabilityBoundingSet=CAP_NET_BIND_SERVICE +AmbientCapabilities=CAP_NET_BIND_SERVICE NoNewPrivileges=true ExecStart=/usr/bin/mmp-go -conf /etc/mmp-go/config.json +ExecReload=/bin/kill -USR1 $MAINPID [Install] WantedBy=multi-user.target