Skip to content

Commit

Permalink
feat: reload strategy (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzz2017 authored Jan 1, 2021
1 parent 5c3e8c9 commit 670553d
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*.so
*.dylib
/shadomplexer-go
/mmp-go
/mmp-go*

# Test binary, built with `go test -c`
*.test
Expand Down
54 changes: 36 additions & 18 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

type Config struct {
ConfPath string `json:"-"`
Groups []Group `json:"groups"`
ClientCapacity int `json:"clientCapacity"`
}
Expand All @@ -30,15 +31,17 @@ type Group struct {
UserContextPool *UserContextPool `json:"-"`
}

var config *Config
var once sync.Once
var Version = "debug"

const (
// around 30kB per client if there are 300 servers to forward
DefaultClientCapacity = 100
)

var (
config *Config
once sync.Once
Version = "debug"
)

func (g *Group) BuildMasterKeys() {
servers := g.Servers
for j := range servers {
Expand Down Expand Up @@ -135,32 +138,47 @@ func build(config *Config) {
g.BuildMasterKeys()
}
}

func BuildConfig(confPath string) (conf *Config, err error) {
conf = new(Config)
conf.ConfPath = confPath
b, err := ioutil.ReadFile(confPath)
if err != nil {
return nil, err
}
if err = json.Unmarshal(b, conf); err != nil {
return nil, err
}
if err = parseUpstreams(conf); err != nil {
return nil, err
}
if err = check(conf); err != nil {
return nil, err
}
build(conf)
return
}

func SetConfig(conf *Config) {
config = conf
}

func GetConfig() *Config {
once.Do(func() {
var err error

version := flag.Bool("v", false, "version")
filename := flag.String("conf", "example.json", "config file path")
confPath := flag.String("conf", "example.json", "config file path")
flag.Parse()

if *version {
fmt.Println(Version)
os.Exit(0)
}

config = new(Config)
b, err := ioutil.ReadFile(*filename)
if err != nil {
log.Fatalln(err)
}
if err = json.Unmarshal(b, config); err != nil {
log.Fatalln(err)
}
if err = parseUpstreams(config); err != nil {
log.Fatalln(err)
}
if err = check(config); err != nil {
if config, err = BuildConfig(*confPath); err != nil {
log.Fatalln(err)
}
build(config)
})
return config
}
1 change: 1 addition & 0 deletions dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Dispatcher interface {
Listen() (err error)
// buf is a buffer to store decrypted text
Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte)
UpdateGroup(group *config.Group)
Close() (err error)
}

Expand Down
34 changes: 26 additions & 8 deletions dispatcher/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"github.com/Qv2ray/mmp-go/common/pool"
"github.com/Qv2ray/mmp-go/config"
"github.com/Qv2ray/mmp-go/dispatcher"
"github.com/pkg/errors"
"io"
"log"
"net"
"sync"
"time"
)

Expand All @@ -22,16 +24,23 @@ func init() {
dispatcher.Register("tcp", New)
}

type Dispatcher struct {
group *config.Group
l net.Listener
type TCP struct {
gMutex sync.RWMutex
group *config.Group
l net.Listener
}

func New(g *config.Group) (d dispatcher.Dispatcher) {
return &Dispatcher{group: g}
return &TCP{group: g}
}

func (d *Dispatcher) Listen() (err error) {
func (d *TCP) UpdateGroup(group *config.Group) {
d.gMutex.Lock()
defer d.gMutex.Unlock()
d.group = group
}

func (d *TCP) Listen() (err error) {
d.l, err = net.Listen("tcp", fmt.Sprintf(":%d", d.group.Port))
if err != nil {
return
Expand All @@ -41,6 +50,12 @@ func (d *Dispatcher) Listen() (err error) {
for {
conn, err := d.l.Accept()
if err != nil {
switch err := err.(type) {
case *net.OpError:
if errors.Is(err.Unwrap(), net.ErrClosed) {
return nil
}
}
log.Printf("[error] ReadFrom: %v", err)
continue
}
Expand All @@ -53,11 +68,12 @@ func (d *Dispatcher) Listen() (err error) {
}
}

func (d *Dispatcher) Close() (err error) {
func (d *TCP) Close() (err error) {
log.Printf("[tcp] closed :%v\n", d.group.Port)
return d.l.Close()
}

func (d *Dispatcher) handleConn(conn net.Conn) error {
func (d *TCP) handleConn(conn net.Conn) error {
/*
https://github.com/shadowsocks/shadowsocks-org/blob/master/whitepaper/whitepaper.md
*/
Expand All @@ -77,7 +93,9 @@ func (d *Dispatcher) handleConn(conn net.Conn) error {
}

// get user's context (preference)
d.gMutex.RLock() // avoid insert old servers to the new userContextPool
userContext = d.group.UserContextPool.GetOrInsert(conn.RemoteAddr(), d.group.Servers)
d.gMutex.RUnlock()

// auth every server
server, _ = d.Auth(buf, data, userContext)
Expand Down Expand Up @@ -130,7 +148,7 @@ func relay(lc, rc net.Conn) error {
return <-ch
}

func (d *Dispatcher) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) {
func (d *TCP) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) {
if len(data) < BasicLen {
return nil, nil
}
Expand Down
38 changes: 28 additions & 10 deletions dispatcher/udp/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"github.com/Qv2ray/mmp-go/common/pool"
"github.com/Qv2ray/mmp-go/config"
"github.com/Qv2ray/mmp-go/dispatcher"
"github.com/pkg/errors"
"golang.org/x/net/dns/dnsmessage"
"log"
"net"
"sync"
"time"
)

Expand All @@ -25,17 +27,24 @@ func init() {
dispatcher.Register("udp", New)
}

type Dispatcher struct {
group *config.Group
c *net.UDPConn
nm *UDPConnMapping
type UDP struct {
gMutex sync.RWMutex
group *config.Group
c *net.UDPConn
nm *UDPConnMapping
}

func New(g *config.Group) (d dispatcher.Dispatcher) {
return &Dispatcher{group: g, nm: NewUDPConnMapping()}
return &UDP{group: g, nm: NewUDPConnMapping()}
}

func (d *Dispatcher) Listen() (err error) {
func (d *UDP) UpdateGroup(group *config.Group) {
d.gMutex.Lock()
defer d.gMutex.Unlock()
d.group = group
}

func (d *UDP) Listen() (err error) {
d.c, err = net.ListenUDP("udp", &net.UDPAddr{Port: d.group.Port})
if err != nil {
return
Expand All @@ -46,6 +55,12 @@ func (d *Dispatcher) Listen() (err error) {
for {
n, laddr, err := d.c.ReadFrom(buf[:])
if err != nil {
switch err := err.(type) {
case *net.OpError:
if errors.Is(err.Unwrap(), net.ErrClosed) {
return nil
}
}
log.Printf("[error] ReadFrom: %v", err)
continue
}
Expand Down Expand Up @@ -78,7 +93,7 @@ func addrLen(packet []byte) int {
return l
}

func (d *Dispatcher) handleConn(laddr net.Addr, data []byte, n int) (err error) {
func (d *UDP) handleConn(laddr net.Addr, data []byte, n int) (err error) {
// get conn or dial and relay
rc, err := d.GetOrBuildUCPConn(laddr, data[:n])
if err != nil {
Expand Down Expand Up @@ -111,7 +126,7 @@ func selectTimeout(packet []byte) time.Duration {
}

// connTimeout is the timeout of connection to build if not exists
func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDPConn, err error) {
func (d *UDP) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDPConn, err error) {
socketIdent := laddr.String()
d.nm.Lock()
var conn *UDPConn
Expand All @@ -122,7 +137,9 @@ func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDP
d.nm.Unlock()

// get user's context (preference)
d.gMutex.RLock() // avoid insert old servers to the new userContextPool
userContext := d.group.UserContextPool.GetOrInsert(laddr, d.group.Servers)
d.gMutex.RUnlock()

buf := pool.Get(len(data))
defer pool.Put(buf)
Expand Down Expand Up @@ -186,7 +203,7 @@ func relay(dst *net.UDPConn, laddr net.Addr, src *net.UDPConn, timeout time.Dura
}
}

func (d *Dispatcher) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) {
func (d *UDP) Auth(buf []byte, data []byte, userContext *config.UserContext) (hit *config.Server, content []byte) {
if len(data) < BasicLen {
return nil, nil
}
Expand All @@ -195,7 +212,8 @@ func (d *Dispatcher) Auth(buf []byte, data []byte, userContext *config.UserConte
})
}

func (d *Dispatcher) Close() (err error) {
func (d *UDP) Close() (err error) {
log.Printf("[udp] closed :%v\n", d.group.Port)
return d.c.Close()
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/Qv2ray/mmp-go
go 1.15

require (
github.com/pkg/errors v0.9.1
github.com/studentmain/smaead v0.0.0-20201230222852-75aa2464875d
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/studentmain/smaead v0.0.0-20201230222852-75aa2464875d h1:8cJZoaJdg0EFv+7ryIWRTnviorsmmHT5H06jx6621CI=
github.com/studentmain/smaead v0.0.0-20201230222852-75aa2464875d/go.mod h1:1jZXK8G4HFsNzvB6tVf8eQ4lKPb9ccsXoZxDah0Yjb0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
38 changes: 28 additions & 10 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,45 @@ import (
)

var protocols = [...]string{"tcp", "udp"}
var wg sync.WaitGroup

func main() {
// handle reload
go signalHandler()

mMutex.Lock()
conf := config.GetConfig()
var wg sync.WaitGroup
for i := range conf.Groups {
wg.Add(1)
go func(group *config.Group) {
err := listen(group, protocols[:])
if err != nil {
log.Fatalln(err)
}
wg.Done()
}(&conf.Groups[i])
go listen(&conf.Groups[i])
}
mMutex.Unlock()
wg.Wait()
}

func listen(group *config.Group, protocols []string) error {
func listen(group *config.Group) {
mMutex.Lock()
if _, ok := mPortDispatcher[group.Port]; !ok {
mPortDispatcher[group.Port] = new([2]dispatcher.Dispatcher)
}
mMutex.Unlock()
err := listenWithProtocols(group, protocols[:])
if err != nil {
mMutex.Lock()
// error but listening
if _, ok := mPortDispatcher[group.Port]; ok {
log.Fatalln(err)
}
mMutex.Unlock()
}
wg.Done()
}

func listenWithProtocols(group *config.Group, protocols []string) error {
ch := make(chan error, len(protocols))
for _, protocol := range protocols {
for i, protocol := range protocols {
d, _ := dispatcher.New(protocol, group)
(*mPortDispatcher[group.Port])[i] = d
go func() {
var err error
err = d.Listen()
Expand Down
Loading

0 comments on commit 670553d

Please sign in to comment.