Skip to content

Commit

Permalink
feat: Optimize Dockerfile, refactor copy function, and implement grac…
Browse files Browse the repository at this point in the history
…eful shutdown
  • Loading branch information
go-bai committed Oct 13, 2024
1 parent cd36cd7 commit f3bc5a6
Show file tree
Hide file tree
Showing 10 changed files with 287 additions and 108 deletions.
14 changes: 9 additions & 5 deletions .github/workflows/build-docker-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ on:
tags:
- '*'

jobs:
env:
GHCR_REGISTRY: ghcr.io
GHCR_USER: go-bai
APP_NAME: http-proxy

jobs:
build:
runs-on: ubuntu-latest
steps:
Expand All @@ -18,18 +22,18 @@ jobs:
run: docker buildx create
--name=multi-builder
--driver=docker-container
--platform=linux/arm64,linux/amd64
--use
--bootstrap
- name: Docker Login
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
GHCR_TOKEN: ${{ secrets.GHCR_TOKEN }}
run: |
docker login -u $DOCKER_USER -p $DOCKER_PASSWORD
echo ${{ env.GHCR_TOKEN }} | docker login ${{ env.GHCR_REGISTRY }} --username ${{ env.GHCR_USER }} --password-stdin
- name: Build and Push the Docker image
run: docker buildx build
--platform=linux/arm64,linux/amd64
--tag ${{ secrets.DOCKER_USER }}/http-proxy:$GITHUB_REF_NAME
--tag ${{ env.GHCR_REGISTRY }}/${{ env.GHCR_USER }}/${{ env.APP_NAME }}:$GITHUB_REF_NAME
--file Dockerfile
--push .
- name: Docker Logout
Expand Down
11 changes: 7 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
FROM golang:1.20-alpine AS builder
RUN apk --no-cache add tzdata
RUN apk add upx
FROM --platform=$BUILDPLATFORM golang:1.23.1-alpine AS builder
ARG TARGETOS
ARG TARGETARCH
ENV GO111MODULE=on \
CGO_ENABLED=0
WORKDIR /build
RUN apk --no-cache add tzdata
COPY . .
RUN go mod tidy && go build -ldflags "-s -w" -o main && upx -9 main
RUN GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags "-s -w" -o main


FROM scratch
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=builder /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /build/main /

ENTRYPOINT ["/main"]
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module github.com/go-bai/http-proxy

go 1.20
go 1.23.1

require github.com/google/uuid v1.3.0
139 changes: 41 additions & 98 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,51 @@ package main

import (
"crypto/tls"
"encoding/base64"
"io"
"log"
"net"
"net/http"
"net/netip"
"os"
"strings"
"runtime"
"sync/atomic"
"time"

"github.com/go-bai/http-proxy/pkg"
"github.com/google/uuid"
)

func handleTunneling(w http.ResponseWriter, r *http.Request) {
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
return
}
clientConn, _, err := hijacker.Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
}
go transfer(destConn, clientConn)
go transfer(clientConn, destConn)
}

func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer destination.Close()
defer source.Close()
io.Copy(destination, source)
}

func handleHTTP(w http.ResponseWriter, req *http.Request) {
resp, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer resp.Body.Close()
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}

func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
var (
addr = ":38888"
auth = "on"
pass = ""
)

func basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
if proxyAuth == "" {
return
}
const (
authOn = "on"
authOff = "off"
)

if !strings.HasPrefix(proxyAuth, "Basic ") {
return
func init() {
addrEnv, b := os.LookupEnv("HTTP_PROXY_ADDR")
if b {
addr = addrEnv
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic "))
if err != nil {
return
authEnv, b := os.LookupEnv("HTTP_PROXY_AUTH")
if b {
auth = authEnv
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
if auth == authOn {
passEnv, b := os.LookupEnv("HTTP_PROXY_PASS")
if b {
pass = passEnv
} else {
pass = uuid.New().String()
}
}

return cs[:s], cs[s+1:], true
}

func handler(w http.ResponseWriter, r *http.Request) {
if auth == authOn {
_, p, ok := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
_, p, ok := pkg.BasicProxyAuth(r.Header.Get("Proxy-Authorization"))
if !ok {
w.Header().Set("Proxy-Authenticate", `Basic realm=Restricted`)
http.Error(w, "proxy auth required", http.StatusProxyAuthRequired)
Expand All @@ -108,39 +70,9 @@ func handler(w http.ResponseWriter, r *http.Request) {
log.Printf("%-15s %-7s %s %s", addrPort.Addr(), r.Method, r.Host, r.URL.Path)

if r.Method == http.MethodConnect {
handleTunneling(w, r)
pkg.HandleTunneling(w, r)
} else {
handleHTTP(w, r)
}
}

var (
addr = ":38888"
auth = "on"
pass = ""
)

const (
authOn = "on"
authOff = "off"
)

func init() {
addrEnv, b := os.LookupEnv("HTTP_PROXY_ADDR")
if b {
addr = addrEnv
}
authEnv, b := os.LookupEnv("HTTP_PROXY_AUTH")
if b {
auth = authEnv
}
if auth == authOn {
passEnv, b := os.LookupEnv("HTTP_PROXY_PASS")
if b {
pass = passEnv
} else {
pass = uuid.New().String()
}
pkg.HandleHTTP(w, r)
}
}

Expand All @@ -151,12 +83,23 @@ func main() {
log.Printf("Password: %s\n", pass)
}

go func() {
for {
time.Sleep(1 * time.Minute)
log.Printf("active connections: %d, goroutine number: %d", atomic.LoadInt64(&pkg.ActiveConnections), runtime.NumGoroutine())
}
}()

server := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(handler),
// Disable HTTP/2.
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
// set timeout for read, write and idle, prevent slowloris attack
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}

log.Fatal(server.ListenAndServe())
pkg.ListenAndServeWithGracefulShutdown(server)
}
27 changes: 27 additions & 0 deletions pkg/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package pkg

import (
"encoding/base64"
"strings"
)

func BasicProxyAuth(proxyAuth string) (username, password string, ok bool) {
if proxyAuth == "" {
return
}

if !strings.HasPrefix(proxyAuth, "Basic ") {
return
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic "))
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}

return cs[:s], cs[s+1:], true
}
27 changes: 27 additions & 0 deletions pkg/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package pkg

import (
"io"
"net/http"
)

func HandleHTTP(w http.ResponseWriter, req *http.Request) {
resp, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer resp.Body.Close()
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
// ignore errors
io.Copy(w, resp.Body)
}

func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
53 changes: 53 additions & 0 deletions pkg/serve.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package pkg

import (
"context"
"errors"
"log"
"net/http"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
)

func ListenAndServeWithGracefulShutdown(server *http.Server) {
shutdownCtx, shutdownCancelFunc = context.WithCancel(context.Background())
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("listen: %s\n", err)
}
}()

<-stop
log.Println("shutting down server...")
shutdownCancelFunc()

// create a deadline for graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// attempt graceful shutdown
if err := server.Shutdown(ctx); err != nil {
log.Printf("error during server shutdown: %s", err.Error())
}

// wait for hijacked connections to finish
for atomic.LoadInt64(&ActiveConnections) > 0 {
log.Printf("waiting for %d hijacked connections to finish", atomic.LoadInt64(&ActiveConnections))
time.Sleep(200 * time.Millisecond)
}

// wait for the server goroutine to finish
wg.Wait()

log.Println("server gracefully stopped")
}
Loading

0 comments on commit f3bc5a6

Please sign in to comment.