Skip to content

Commit

Permalink
refactor: Server.Close 统一调用 http.Server.Shutdown 方法
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Mar 4, 2024
1 parent 48bd08e commit 1a1951e
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 48 deletions.
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ type Server interface {
// Serve 开始 HTTP 服务
//
// 这是个阻塞方法,会等待 [Server.Close] 执行完之后才返回。
// 始终返回非空的错误对象,如果是被 [Server.Close] 关闭的,也将返回 [http.ErrServerClosed]。
// 始终返回非空的错误对象,如果是由 [Server.Close] 关闭的,将返回 [http.ErrServerClosed]。
Serve() error

// Close 触发关闭服务事件
//
// 需要等到 [Server.Serve] 返回才表示服务结束
// 只是触发事件,需要等到 [Server.Serve] 返回才表示服务真正结束
// 调用此方法表示 [Server] 的生命周期结束,对象将处于不可用状态。
Close(shutdownTimeout time.Duration)

Expand Down
24 changes: 16 additions & 8 deletions server/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ type App struct {
restartLock sync.Mutex
}

func (app *App) getServer() web.Server {
app.srvLock.RLock()
s := app.srv
app.srvLock.RUnlock()
return s
}

func (app *App) setServer(s web.Server) {
app.srvLock.Lock()
app.srv = s
app.srvLock.Unlock()
}

func (app *App) init() (err error) {
if app.NewServer == nil {
panic("app.NewServer 不能为空")
Expand All @@ -46,7 +59,7 @@ func (app *App) Exec() error {

RESTART:
app.restart = false
err := app.srv.Serve()
err := app.getServer().Serve()
if app.restart { // 等待 Serve 过程中,如果调用 RestartServer,会将 app.restart 设置为 true。
goto RESTART
}
Expand All @@ -62,19 +75,14 @@ func (app *App) RestartServer() {

app.restart = true

// 先拿到旧服务,以便在新服务初始化失败时能正确输出日志。
app.srvLock.RLock()
old := app.srv
app.srvLock.RUnlock()
old := app.getServer() // 先拿到旧服务,以便在新服务初始化失败时能正确输出日志。

srv, err := app.NewServer()
if err != nil {
old.Logs().ERROR().Error(err)
return
}
app.srvLock.Lock()
app.srv = srv
app.srvLock.Unlock()
app.setServer(srv)

old.Close(app.ShutdownTimeout) // 新服务声明成功,尝试关闭旧服务。
}
8 changes: 4 additions & 4 deletions server/app/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,23 @@ func TestCLI(t *testing.T) {
time.Sleep(500 * time.Millisecond) // 等待 go func 启动完成

// restart1
s1 := cmd.app.srv
s1 := cmd.app.getServer()
t1 := s1.Uptime()
cmd.Name = "restart1"
cmd.RestartServer()
time.Sleep(shutdownTimeout + 500*time.Millisecond) // 此值要大于 CLI.ShutdownTimeout
s2 := cmd.app.srv
s2 := cmd.app.getServer()
t2 := s2.Uptime()
a.True(t2.After(t1)).NotEqual(s1, s2)

// restart2
cmd.Name = "restart2"
cmd.RestartServer()
time.Sleep(shutdownTimeout + 500*time.Millisecond) // 此值要大于 CLI.ShutdownTimeout
t3 := cmd.app.srv.Uptime()
t3 := cmd.app.getServer().Uptime()
a.True(t3.After(t2))

cmd.app.srv.Close(0)
cmd.app.getServer().Close(0)
<-exit
}

Expand Down
10 changes: 5 additions & 5 deletions server/app/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,22 @@ func TestSignalHUP(t *testing.T) {
}()
time.Sleep(2000 * time.Millisecond) // 等待 go func 启动完成
a.NotNil(cmd.app).
NotNil(cmd.app.srv)
NotNil(cmd.app.getServer())

p, err := os.FindProcess(os.Getpid())
a.NotError(err).NotNil(p)

// hup1
t1 := cmd.app.srv.Uptime()
t1 := cmd.app.getServer().Uptime()
a.NotError(p.Signal(syscall.SIGHUP)).Wait(500 * time.Millisecond) // 此值要大于 CLI.ShutdownTimeout
t2 := cmd.app.srv.Uptime()
t2 := cmd.app.getServer().Uptime()
a.True(t2.After(t1))

// hup2
a.NotError(p.Signal(syscall.SIGHUP)).Wait(500 * time.Millisecond) // 此值要大于 CLI.ShutdownTimeout
t3 := cmd.app.srv.Uptime()
t3 := cmd.app.getServer().Uptime()
a.True(t3.After(t2))

cmd.app.srv.Close(0)
cmd.app.getServer().Close(0)
<-exit
}
43 changes: 20 additions & 23 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import (
type (
httpServer struct {
*web.InternalServer
httpServer *http.Server
state web.State
closed chan struct{}
hs *http.Server
state web.State
closed chan struct{}
}

service struct {
Expand All @@ -45,16 +45,16 @@ type (

func newHTTPServer(name, version string, o *Options, s web.Server) *httpServer {
srv := &httpServer{
httpServer: o.HTTPServer,
state: web.Stopped,
closed: make(chan struct{}, 1),
hs: o.HTTPServer,
state: web.Stopped,
closed: make(chan struct{}, 1),
}
if s == nil {
s = srv
}

srv.InternalServer = o.internalServer(name, version, s)
srv.httpServer.Handler = srv
srv.hs.Handler = srv

for _, f := range o.Init { // NOTE: 需要保证在最后
f(srv)
Expand Down Expand Up @@ -83,10 +83,10 @@ func (srv *httpServer) Serve() (err error) {
}
srv.state = web.Running

if c := srv.httpServer.TLSConfig; c != nil && (len(c.Certificates) > 0 || c.GetCertificate != nil) {
err = srv.httpServer.ListenAndServeTLS("", "")
if c := srv.hs.TLSConfig; c != nil && (len(c.Certificates) > 0 || c.GetCertificate != nil) {
err = srv.hs.ListenAndServeTLS("", "")
} else {
err = srv.httpServer.ListenAndServe()
err = srv.hs.ListenAndServe()
}

if errors.Is(err, http.ErrServerClosed) {
Expand All @@ -101,28 +101,25 @@ func (srv *httpServer) Close(shutdownTimeout time.Duration) {
if srv.State() != web.Running {
return
}
srv.state = web.Stopped // 调用 Close 即设置状态

ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)

defer func() {
srv.InternalServer.Close()
srv.state = web.Stopped
srv.closed <- struct{}{} // NOTE: 保证最后执行
}()
cancel()

if shutdownTimeout == 0 {
if err := srv.httpServer.Close(); err != nil {
srv.Logs().ERROR().Error(err)
}
return
}
// [http.Server.Shutdown] 会让 [http.Server.ListenAndServe] 等方法直接返回,
// 所以由 srv.close 保证在当前函数返回之后再通知 [Server.Serve] 退出。
srv.closed <- struct{}{}
}()

c, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
if err := srv.httpServer.Shutdown(c); err != nil && !errors.Is(err, context.DeadlineExceeded) {
if err := srv.hs.Shutdown(ctx); err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
srv.Logs().ERROR().Error(err)
}
}

// NewService 将 [web.Server] 作为微服务节点
// NewService 声明微服务节点
func NewService(name, version string, o *Options) (web.Server, error) {
o, err := sanitizeOptions(o, typeService)
if err != nil {
Expand Down
15 changes: 10 additions & 5 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func TestNew(t *testing.T) {
Equal(srv.Location(), time.Local)

s, ok := srv.(*httpServer)
a.True(ok).Equal(s.httpServer.Handler, s).
Equal(s.httpServer.Addr, "")
a.True(ok).Equal(s.hs.Handler, s).
Equal(s.hs.Addr, "")

d, ok := srv.Cache().(cache.Driver)
a.True(ok).
Expand Down Expand Up @@ -205,12 +205,17 @@ func TestHTTPServer_Close(t *testing.T) {

servertest.Get(a, "http://localhost:8080/test").Do(nil).Status(http.StatusAccepted)

// 连接被关闭,返回错误内容
// 尝试关闭连接
a.Equal(0, c)
resp, err := http.Get("http://localhost:8080/close")
time.Sleep(500 * time.Microsecond) // Handle 中的 Server.Close 是触发关闭服务,这里要等待真正完成
a.Error(err).Nil(resp).True(c > 0)
a.Wait(500 * time.Microsecond). // Handle 中的 Server.Close 是触发关闭服务,这里要等待真正完成
NotError(err).NotNil(resp).True(c > 0)

// 连接被关闭,返回错误内容
resp, err = http.Get("http://localhost:8080/close")
a.Error(err).Nil(resp)

// 连接被关闭,返回错误内容
resp, err = http.Get("http://localhost:8080/test")
a.Error(err).Nil(resp)
}
Expand Down
2 changes: 1 addition & 1 deletion web.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

// Version 当前框架的版本
const Version = "0.87.1"
const Version = "0.87.2"

type (
Logger = logs.Logger
Expand Down

0 comments on commit 1a1951e

Please sign in to comment.