diff --git a/app.go b/app.go index e1e607e6fa..9b14463dab 100644 --- a/app.go +++ b/app.go @@ -878,16 +878,18 @@ func (app *App) ShutdownWithTimeout(timeout time.Duration) error { // // ShutdownWithContext does not close keepalive connections so its recommended to set ReadTimeout to something else than 0. func (app *App) ShutdownWithContext(ctx context.Context) error { - if app.hooks != nil { - // TODO: check should be defered? - app.hooks.executeOnShutdownHooks() - } - app.mutex.Lock() defer app.mutex.Unlock() + if app.server == nil { return ErrNotRunning } + + // Execute shutdown hooks in a deferred function + if app.hooks != nil { + defer app.hooks.executeOnShutdownHooks() + } + return app.server.ShutdownWithContext(ctx) } diff --git a/app_test.go b/app_test.go index 9699c85bce..9fff684de8 100644 --- a/app_test.go +++ b/app_test.go @@ -20,6 +20,7 @@ import ( "regexp" "runtime" "strings" + "sync/atomic" "testing" "time" @@ -860,6 +861,12 @@ func Test_App_ShutdownWithContext(t *testing.T) { t.Parallel() app := New() + var shutdownHookCalled atomic.Int32 + app.Hooks().OnShutdown(func() error { + shutdownHookCalled.Store(1) + return nil + }) + app.Get("/", func(ctx Ctx) error { time.Sleep(5 * time.Second) return ctx.SendString("body") @@ -867,24 +874,27 @@ func Test_App_ShutdownWithContext(t *testing.T) { ln := fasthttputil.NewInmemoryListener() + serverErr := make(chan error, 1) go func() { - err := app.Listener(ln) - assert.NoError(t, err) + serverErr <- app.Listener(ln) }() - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) + clientDone := make(chan struct{}) go func() { conn, err := ln.Dial() assert.NoError(t, err) - - _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")) + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) assert.NoError(t, err) + close(clientDone) }() - time.Sleep(1 * time.Second) + <-clientDone + // Sleep to ensure the server has started processing the request + time.Sleep(100 * time.Millisecond) - shutdownErr := make(chan error) + shutdownErr := make(chan error, 1) go func() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -892,13 +902,19 @@ func Test_App_ShutdownWithContext(t *testing.T) { }() select { - case <-time.After(5 * time.Second): - t.Fatal("idle connections not closed on shutdown") + case <-time.After(2 * time.Second): + t.Fatal("shutdown did not complete in time") case err := <-shutdownErr: - if err == nil || !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded) - } + require.Error(t, err, "Expected shutdown to return an error due to timeout") + require.ErrorIs(t, err, context.DeadlineExceeded, "Expected DeadlineExceeded error") } + + assert.Equal(t, int32(1), shutdownHookCalled.Load(), "Shutdown hook was not called") + + err := <-serverErr + assert.NoError(t, err, "Server should have shut down without error") + // default: + // Server is still running, which is expected as the long-running request prevented full shutdown } // go test -run Test_App_Mixed_Routes_WithSameLen diff --git a/docs/api/fiber.md b/docs/api/fiber.md index 6892225e11..55566a109d 100644 --- a/docs/api/fiber.md +++ b/docs/api/fiber.md @@ -205,7 +205,7 @@ Shutdown gracefully shuts down the server without interrupting any active connec ShutdownWithTimeout will forcefully close any active connections after the timeout expires. -ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded. +ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded. Shutdown hooks will still be executed, even if an error occurs during the shutdown process, as they are deferred to ensure cleanup happens regardless of errors. ```go func (app *App) Shutdown() error