diff --git a/.github/workflows/qa.yaml b/.github/workflows/qa.yaml index 6544651..2e10fff 100644 --- a/.github/workflows/qa.yaml +++ b/.github/workflows/qa.yaml @@ -20,6 +20,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 + - name: Start Postgresql + run: make db - name: Run tests run: go test -v ./... swagger: @@ -41,7 +43,7 @@ jobs: - uses: actions/checkout@v4 - uses: sqlc-dev/setup-sqlc@v4 with: - sqlc-version: '1.26.0' + sqlc-version: '1.27.0' - name: Vet run: sqlc vet - name: Diff diff --git a/Makefile b/Makefile index a43b797..81fe421 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,10 @@ -.PHONY: swagger docs +.PHONY: swagger docs db docs: swag init --generalInfo cmd/main.go --outputTypes=yaml swagger: docker run --detach --name eryth-swagger -p 4000:8080 -e API_URL=/doc/swagger.yaml --mount 'type=bind,src=$(shell pwd)/docs,dst=/usr/share/nginx/html/doc' swaggerapi/swagger-ui + +db: + docker compose up --wait db diff --git a/cmd/main.go b/cmd/main.go index 5be8fa2..0c132ed 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/crazybolillo/eryth/internal/bouncer" "github.com/crazybolillo/eryth/internal/handler" + "github.com/crazybolillo/eryth/internal/service" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" @@ -56,7 +57,7 @@ func serve(ctx context.Context) error { conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL")) if err != nil { - slog.Error("failed to establish database connection") + slog.Error("failed to establish database connection", slog.String("reason", err.Error())) return err } defer conn.Close(ctx) @@ -73,7 +74,7 @@ func serve(ctx context.Context) error { })) r.Use(middleware.AllowContentEncoding("application/json")) - endpoint := handler.Endpoint{Conn: conn} + endpoint := handler.Endpoint{Service: &service.EndpointService{Cursor: conn}} r.Mount("/endpoints", endpoint.Router()) checker := &bouncer.Bouncer{Conn: conn} diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..f00084d --- /dev/null +++ b/compose.yaml @@ -0,0 +1,17 @@ +name: eryth +services: + db: + image: postgres:15-alpine + ports: + - '54321:5432' + environment: + - POSTGRES_USER=go + - POSTGRES_PASSWORD=go + - POSTGRES_DB=eryth + volumes: + - ./db/migrations/:/docker-entrypoint-initdb.d + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U go" ] + interval: 1s + timeout: 1s + retries: 10 diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 7df916f..7fc0277 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -15,8 +15,10 @@ definitions: extension: type: string type: object - handler.createEndpointRequest: + model.Endpoint: properties: + accountCode: + type: string codecs: items: type: string @@ -31,17 +33,24 @@ definitions: type: string maxContacts: type: integer - password: - type: string + sid: + type: integer transport: type: string type: object - handler.getEndpointResponse: + model.EndpointPage: properties: - codecs: + endpoints: items: - type: string + $ref: '#/definitions/model.EndpointPageEntry' type: array + retrieved: + type: integer + total: + type: integer + type: object + model.EndpointPageEntry: + properties: context: type: string displayName: @@ -50,15 +59,17 @@ definitions: type: string id: type: string - maxContacts: - type: integer sid: type: integer - transport: - type: string type: object - handler.listEndpointEntry: + model.NewEndpoint: properties: + accountCode: + type: string + codecs: + items: + type: string + type: array context: type: string displayName: @@ -67,21 +78,14 @@ definitions: type: string id: type: string - sid: - type: integer - type: object - handler.listEndpointsResponse: - properties: - endpoints: - items: - $ref: '#/definitions/handler.listEndpointEntry' - type: array - retrieved: - type: integer - total: + maxContacts: type: integer + password: + type: string + transport: + type: string type: object - handler.updateEndpointRequest: + model.PatchedEndpoint: properties: codecs: items: @@ -152,7 +156,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/handler.listEndpointsResponse' + $ref: '#/definitions/model.EndpointPage' "400": description: Bad Request "500": @@ -169,12 +173,12 @@ paths: name: payload required: true schema: - $ref: '#/definitions/handler.createEndpointRequest' + $ref: '#/definitions/model.NewEndpoint' responses: "201": description: Created schema: - $ref: '#/definitions/handler.getEndpointResponse' + $ref: '#/definitions/model.Endpoint' "400": description: Bad Request "500": @@ -213,7 +217,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/handler.getEndpointResponse' + $ref: '#/definitions/model.Endpoint' "400": description: Bad Request "500": @@ -232,7 +236,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/handler.updateEndpointRequest' + $ref: '#/definitions/model.PatchedEndpoint' "400": description: Bad Request "404": diff --git a/internal/handler/endpoint.go b/internal/handler/endpoint.go index 44b7089..93dc753 100644 --- a/internal/handler/endpoint.go +++ b/internal/handler/endpoint.go @@ -1,72 +1,20 @@ package handler import ( - "crypto/md5" - "encoding/hex" "encoding/json" "errors" - "fmt" - "github.com/crazybolillo/eryth/internal/db" + "github.com/crazybolillo/eryth/internal/model" "github.com/crazybolillo/eryth/internal/query" - "github.com/crazybolillo/eryth/internal/sqlc" + "github.com/crazybolillo/eryth/internal/service" "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5" "log/slog" "net/http" "strconv" - "strings" ) -const defaultRealm = "asterisk" - type Endpoint struct { - *pgx.Conn -} - -type createEndpointRequest struct { - ID string `json:"id"` - Password string `json:"password"` - Transport string `json:"transport,omitempty"` - Context string `json:"context"` - Codecs []string `json:"codecs"` - MaxContacts int32 `json:"maxContacts,omitempty"` - Extension string `json:"extension,omitempty"` - DisplayName string `json:"displayName"` -} - -type listEndpointEntry struct { - Sid int32 `json:"sid"` - ID string `json:"id"` - Extension string `json:"extension"` - Context string `json:"context"` - DisplayName string `json:"displayName"` -} - -type listEndpointsResponse struct { - Total int64 `json:"total"` - Retrieved int `json:"retrieved"` - Endpoints []listEndpointEntry `json:"endpoints"` -} - -type getEndpointResponse struct { - Sid int32 `json:"sid"` - ID string `json:"id"` - DisplayName string `json:"displayName"` - Transport string `json:"transport"` - Context string `json:"context"` - Codecs []string `json:"codecs"` - MaxContacts int32 `json:"maxContacts"` - Extension string `json:"extension"` -} - -type updateEndpointRequest struct { - Password *string `json:"password,omitempty"` - DisplayName *string `json:"displayName,omitempty"` - Transport *string `json:"transport,omitempty"` - Context *string `json:"context,omitempty"` - Codecs []string `json:"codecs,omitempty"` - MaxContacts *int32 `json:"maxContacts,omitempty"` - Extension *string `json:"extension,omitempty"` + Service *service.EndpointService } func (e *Endpoint) Router() chi.Router { @@ -80,36 +28,10 @@ func (e *Endpoint) Router() chi.Router { return r } -// displayNameFromClid extracts the display name from a Caller ID. It is expected for the Caller ID to be in -// the following format: "Display Name" -// If no display name is found, an empty string is returned. -func displayNameFromClid(callerID string) string { - if callerID == "" { - return "" - } - - start := strings.Index(callerID, `"`) - if start != 0 { - return "" - } - - end := strings.LastIndex(callerID, `"`) - if end == -1 || end < 1 { - return "" - } - - return callerID[1:end] -} - -func hashPassword(user, password, realm string) string { - hash := md5.Sum([]byte(user + ":" + realm + ":" + password)) - return hex.EncodeToString(hash[:]) -} - // @Summary Get information from a specific endpoint. // @Param sid path int true "Requested endpoint's sid" // @Produce json -// @Success 200 {object} getEndpointResponse +// @Success 200 {object} model.Endpoint // @Failure 400 // @Failure 500 // @Tags endpoints @@ -126,35 +48,16 @@ func (e *Endpoint) get(w http.ResponseWriter, r *http.Request) { return } - tx, err := e.Begin(r.Context()) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - defer tx.Rollback(r.Context()) - - queries := sqlc.New(tx) - - row, err := queries.GetEndpointByID(r.Context(), int32(id)) + endpoint, err := e.Service.Read(r.Context(), int32(id)) if errors.Is(err, pgx.ErrNoRows) { w.WriteHeader(http.StatusNotFound) return } else if err != nil { w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to retrieve endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + slog.Error("Failed to read endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) return } - endpoint := getEndpointResponse{ - Sid: int32(id), - ID: row.ID, - Transport: row.Transport.String, - Context: row.Context.String, - Codecs: strings.Split(row.Allow.String, ","), - MaxContacts: row.MaxContacts.Int32, - Extension: row.Extension.String, - DisplayName: displayNameFromClid(row.Callerid.String), - } content, err := json.Marshal(endpoint) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -173,7 +76,7 @@ func (e *Endpoint) get(w http.ResponseWriter, r *http.Request) { // @Param page query int false "Zero based page to fetch" default(0) // @Param pageSize query int false "Max amount of results to be returned" default(10) // @Produce json -// @Success 200 {object} listEndpointsResponse +// @Success 200 {object} model.EndpointPage // @Failure 400 // @Failure 500 // @Tags endpoints @@ -191,43 +94,14 @@ func (e *Endpoint) list(w http.ResponseWriter, r *http.Request) { return } - queries := sqlc.New(e.Conn) - rows, err := queries.ListEndpoints(r.Context(), sqlc.ListEndpointsParams{ - Limit: int32(pageSize), - Offset: int32(page * pageSize), - }) + result, err := e.Service.Paginate(r.Context(), page, pageSize) if err != nil { - slog.Error("Query execution failed", slog.String("path", r.URL.Path), slog.String("msg", err.Error())) - w.WriteHeader(http.StatusInternalServerError) - return - } - if rows == nil { - rows = []sqlc.ListEndpointsRow{} - } - total, err := queries.CountEndpoints(r.Context()) - if err != nil { - slog.Error("Query execution failed", slog.String("path", r.URL.Path), slog.String("msg", err.Error())) + slog.Error("Failed to list endpoints", slog.String("path", r.URL.Path), slog.String("msg", err.Error())) w.WriteHeader(http.StatusInternalServerError) return } - endpoints := make([]listEndpointEntry, len(rows)) - for idx := range len(rows) { - row := rows[idx] - endpoints[idx] = listEndpointEntry{ - Sid: row.Sid, - ID: row.ID, - Extension: row.Extension.String, - Context: row.Context.String, - DisplayName: displayNameFromClid(row.Callerid.String), - } - } - response := listEndpointsResponse{ - Total: total, - Retrieved: len(rows), - Endpoints: endpoints, - } - content, err := json.Marshal(response) + content, err := json.Marshal(result) if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -236,22 +110,21 @@ func (e *Endpoint) list(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, err = w.Write(content) if err != nil { - slog.Error("Failed to write response", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + slog.Error("Failed to marshall endpoint list", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) } - w.WriteHeader(http.StatusOK) } // @Summary Create a new endpoint. // @Accept json -// @Param payload body createEndpointRequest true "Endpoint's information" -// @Success 201 {object} getEndpointResponse +// @Param payload body model.NewEndpoint true "Endpoint's information" +// @Success 201 {object} model.Endpoint // @Failure 400 // @Failure 500 // @Tags endpoints // @Router /endpoints [post] func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) - payload := createEndpointRequest{ + payload := model.NewEndpoint{ MaxContacts: 1, } @@ -261,92 +134,13 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { return } - tx, err := e.Begin(r.Context()) + endpoint, err := e.Service.Create(r.Context(), payload) if err != nil { w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to create endpoint", slog.String("reason", err.Error())) return } - defer tx.Rollback(r.Context()) - queries := sqlc.New(tx) - - err = queries.NewMD5Auth(r.Context(), sqlc.NewMD5AuthParams{ - ID: payload.ID, - Username: db.Text(payload.ID), - Realm: db.Text(defaultRealm), - Md5Cred: db.Text(hashPassword(payload.ID, payload.Password, defaultRealm)), - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - sid, err := queries.NewEndpoint(r.Context(), sqlc.NewEndpointParams{ - ID: payload.ID, - Transport: db.Text(payload.Transport), - Context: db.Text(payload.Context), - Allow: db.Text(strings.Join(payload.Codecs, ",")), - Callerid: db.Text(fmt.Sprintf(`"%s" <%s>`, payload.DisplayName, payload.ID)), - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - err = queries.NewAOR(r.Context(), sqlc.NewAORParams{ - ID: payload.ID, - MaxContacts: db.Int4(payload.MaxContacts), - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - if payload.Extension != "" { - err = queries.NewExtension(r.Context(), sqlc.NewExtensionParams{ - EndpointID: sid, - Extension: db.Text(payload.Extension), - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - } - - err = tx.Commit(r.Context()) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - // TODO: Duplicate code, same as when fetching endpoint. Probably should put this into a service layer. - tx, err = e.Begin(r.Context()) - queries = sqlc.New(tx) - if err != nil { - slog.Error("Failed to create new transaction", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - w.WriteHeader(http.StatusInternalServerError) - return - } - res, err := queries.GetEndpointByID(r.Context(), sid) - if err != nil { - slog.Error( - "Failed to retrieve created endpoint", - slog.String("path", r.URL.Path), slog.String("reason", err.Error()), slog.Int("sid", int(sid)), - ) - w.WriteHeader(http.StatusInternalServerError) - return - } - - endpoint := getEndpointResponse{ - Sid: sid, - ID: res.ID, - Transport: res.Transport.String, - Context: res.Context.String, - Codecs: strings.Split(res.Allow.String, ","), - MaxContacts: res.MaxContacts.Int32, - Extension: res.Extension.String, - DisplayName: displayNameFromClid(res.Callerid.String), - } content, err := json.Marshal(endpoint) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -354,12 +148,13 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { return } + w.WriteHeader(http.StatusCreated) w.Header().Set("Content-Type", "application/json") _, err = w.Write(content) if err != nil { slog.Error("Failed to write response", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) } - w.WriteHeader(http.StatusCreated) + } // @Summary Delete an endpoint and its associated resources. @@ -377,37 +172,10 @@ func (e *Endpoint) delete(w http.ResponseWriter, r *http.Request) { return } - tx, err := e.Begin(r.Context()) + err = e.Service.Delete(r.Context(), int32(sid)) if err != nil { w.WriteHeader(http.StatusInternalServerError) - return - } - defer tx.Rollback(r.Context()) - - queries := sqlc.New(tx) - - id, err := queries.DeleteEndpoint(r.Context(), int32(sid)) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - err = queries.DeleteAOR(r.Context(), id) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - err = queries.DeleteAuth(r.Context(), id) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - err = tx.Commit(r.Context()) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return + slog.Error("Failed to delete endpoint", slog.String("reason", err.Error())) } w.WriteHeader(http.StatusNoContent) @@ -415,7 +183,7 @@ func (e *Endpoint) delete(w http.ResponseWriter, r *http.Request) { // @Summary Update the specified endpoint. Omitted or null fields will remain unchanged. // @Param sid path int true "Sid of the endpoint to be updated" -// @Success 200 {object} updateEndpointRequest +// @Success 200 {object} model.PatchedEndpoint // @Failure 400 // @Failure 404 // @Failure 500 @@ -423,7 +191,7 @@ func (e *Endpoint) delete(w http.ResponseWriter, r *http.Request) { // @Router /endpoints/{sid} [patch] func (e *Endpoint) update(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) - var payload updateEndpointRequest + var payload model.PatchedEndpoint err := decoder.Decode(&payload) if err != nil { @@ -438,143 +206,14 @@ func (e *Endpoint) update(w http.ResponseWriter, r *http.Request) { return } - tx, err := e.Begin(r.Context()) + endpoint, err := e.Service.Update(r.Context(), int32(sid), payload) if err != nil { - slog.Error("Failed to start transaction", slog.String("reason", err.Error()), slog.String("path", r.URL.Path)) - w.WriteHeader(http.StatusInternalServerError) - return - } - - queries := sqlc.New(tx) - endpoint, err := queries.GetEndpointByID(r.Context(), int32(sid)) - if errors.Is(err, pgx.ErrNoRows) { - w.WriteHeader(http.StatusNotFound) - return - } else if err != nil { w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to retrieve endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + slog.Error("Failed to update endpoint", slog.String("reason", err.Error())) return } - // Sorry for the incoming boilerplate but no dynamic SQL yet - var patchedEndpoint = sqlc.UpdateEndpointBySidParams{Sid: int32(sid)} - if payload.DisplayName != nil { - if *payload.DisplayName == "" { - patchedEndpoint.Callerid = db.Text("") - } else { - patchedEndpoint.Callerid = db.Text(fmt.Sprintf(`"%s" <%s>`, *payload.DisplayName, endpoint.ID)) - } - } else { - patchedEndpoint.Callerid = endpoint.Callerid - } - if payload.Context != nil { - patchedEndpoint.Context = db.Text(*payload.Context) - } else { - patchedEndpoint.Context = endpoint.Context - } - if payload.Transport != nil { - patchedEndpoint.Transport = db.Text(*payload.Transport) - } else { - patchedEndpoint.Transport = endpoint.Transport - } - if payload.Codecs != nil { - patchedEndpoint.Allow = db.Text(strings.Join(payload.Codecs, ",")) - } else { - patchedEndpoint.Allow = endpoint.Allow - } - err = queries.UpdateEndpointBySid(r.Context(), patchedEndpoint) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to update endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - return - } - - if payload.MaxContacts != nil { - err = queries.UpdateAORById( - r.Context(), - sqlc.UpdateAORByIdParams{ - ID: endpoint.ID, - MaxContacts: db.Int4(*payload.MaxContacts), - }, - ) - } - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to update AOR", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - return - } - - if payload.Extension != nil { - err = queries.UpdateExtensionByEndpointId( - r.Context(), - sqlc.UpdateExtensionByEndpointIdParams{ - EndpointID: int32(sid), - Extension: db.Text(*payload.Extension), - }, - ) - } - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to update extension", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - return - } - - if payload.Password != nil { - if len(*payload.Password) < 12 { - w.WriteHeader(http.StatusBadRequest) - slog.Info("Invalid password provided", slog.String("path", r.URL.Path)) - return - } - err = queries.UpdateMD5AuthById( - r.Context(), - sqlc.UpdateMD5AuthByIdParams{ - ID: endpoint.ID, - Md5Cred: db.Text(hashPassword(endpoint.ID, *payload.Password, defaultRealm)), - }, - ) - } - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to update password", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - return - } - - err = tx.Commit(r.Context()) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - slog.Error("Failed to commit update", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - return - } - - // TODO: Duplicate code, same as when fetching endpoint. Probably should put this into a service layer. - tx, err = e.Begin(r.Context()) - queries = sqlc.New(tx) - if err != nil { - slog.Error("Failed to create new transaction", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) - w.WriteHeader(http.StatusInternalServerError) - return - } - res, err := queries.GetEndpointByID(r.Context(), int32(sid)) - if err != nil { - slog.Error( - "Failed to retrieve created endpoint", - slog.String("path", r.URL.Path), slog.String("reason", err.Error()), slog.Int("sid", int(sid)), - ) - w.WriteHeader(http.StatusInternalServerError) - return - } - - result := getEndpointResponse{ - Sid: int32(sid), - ID: res.ID, - Transport: res.Transport.String, - Context: res.Context.String, - Codecs: strings.Split(res.Allow.String, ","), - MaxContacts: res.MaxContacts.Int32, - Extension: res.Extension.String, - DisplayName: displayNameFromClid(res.Callerid.String), - } - content, err := json.Marshal(result) + content, err := json.Marshal(endpoint) if err != nil { w.WriteHeader(http.StatusInternalServerError) slog.Error("Failed to marshall response", slog.String("path", r.URL.Path)) @@ -586,5 +225,4 @@ func (e *Endpoint) update(w http.ResponseWriter, r *http.Request) { if err != nil { slog.Error("Failed to write response", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) } - w.WriteHeader(http.StatusOK) } diff --git a/internal/handler/endpoint_test.go b/internal/handler/endpoint_test.go index e32ba43..d376d9c 100644 --- a/internal/handler/endpoint_test.go +++ b/internal/handler/endpoint_test.go @@ -1,44 +1,220 @@ package handler -import "testing" +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/crazybolillo/eryth/internal/model" + "github.com/crazybolillo/eryth/internal/service" + "github.com/jackc/pgx/v5" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) -func TestDisplayNameFromClid(t *testing.T) { - cases := map[string]struct { - callerID string - displayName string +func TestEndpointAPI(t *testing.T) { + cases := []struct { + name string + test func(handler http.Handler) func(*testing.T) }{ - "valid": { - `"Kiwi Snow" `, - "Kiwi Snow", - }, - "empty": { - "", - "", - }, - "single_colon": { - `"`, - "", - }, - "empty_quotes": { - `""`, - "", - }, - "missing_start_quote": { - `John Smith" `, - "", - }, - "missing_end_quote": { - `"John Smith `, - "", - }, - } - - for name, tt := range cases { - t.Run(name, func(t *testing.T) { - got := displayNameFromClid(tt.callerID) - if got != tt.displayName { - t.Errorf("got %q, want %q", got, tt.displayName) - } + {"Create", MustCreate}, + {"Delete", MustDelete}, + {"Read", MustRead}, + {"Update", MustUpdate}, + } + + conn, err := pgx.Connect(context.Background(), "postgres://go:go@127.0.0.1:54321/eryth") + if err != nil { + t.Fatalf( + "Connection to test database failed: %s. Try running 'make db' and run the tests again", + err, + ) + } + defer func(conn *pgx.Conn, ctx context.Context) { + err := conn.Close(ctx) + if err != nil { + t.Error("Failed to close db connection") + } + }(conn, context.Background()) + + for _, tt := range cases { + tx, err := conn.Begin(context.Background()) + if err != nil { + t.Fatalf("Transaction start failed: %s", err) + } + + handler := Endpoint{Service: &service.EndpointService{Cursor: tx}} + t.Run(tt.name, tt.test(handler.Router())) + + err = tx.Rollback(context.Background()) + if err != nil { + t.Fatalf("Failed to rollback transaction: %s", err) + } + } +} + +func createEndpoint(t *testing.T, handler http.Handler, endpoint model.NewEndpoint) *httptest.ResponseRecorder { + payload, err := json.Marshal(endpoint) + if err != nil { + t.Errorf("failed to marshal new endpoint: %s", err) + } + + req := httptest.NewRequest("POST", "/", bytes.NewReader(payload)) + req.Header.Set("Content-Type", "application/json") + res := httptest.NewRecorder() + handler.ServeHTTP(res, req) + + return res +} + +func readEndpoint(handler http.Handler, sid int32) *httptest.ResponseRecorder { + req := httptest.NewRequest("GET", fmt.Sprintf("/%d", sid), nil) + res := httptest.NewRecorder() + handler.ServeHTTP(res, req) + + return res +} + +func updateEndpoint(t *testing.T, handler http.Handler, sid int32, endpoint model.PatchedEndpoint) *httptest.ResponseRecorder { + payload, err := json.Marshal(endpoint) + if err != nil { + t.Errorf("failed to marshal new endpoint: %s", err) + } + req := httptest.NewRequest("PATCH", fmt.Sprintf("/%d", sid), bytes.NewReader(payload)) + res := httptest.NewRecorder() + handler.ServeHTTP(res, req) + + return res +} + +func parseEndpoint(t *testing.T, content *bytes.Buffer) model.Endpoint { + var createdEndpoint model.Endpoint + decoder := json.NewDecoder(content) + err := decoder.Decode(&createdEndpoint) + if err != nil { + t.Errorf("failed to parse endpoint: %s", err) + } + + return createdEndpoint +} + +func MustCreate(handler http.Handler) func(*testing.T) { + return func(t *testing.T) { + endpoint := model.NewEndpoint{ + ID: "zinniaelegans", + Password: "verylongandsafepassword", + Context: "flowers", + Codecs: []string{"ulaw", "g722"}, + Extension: "1234", + DisplayName: "Zinnia Elegans", + MaxContacts: 10, + } + res := createEndpoint(t, handler, endpoint) + if res.Code != http.StatusCreated { + t.Errorf("invalid http code, got %d, want %d", res.Code, http.StatusCreated) + } + got := parseEndpoint(t, res.Body) + + want := model.Endpoint{ + Sid: got.Sid, + AccountCode: "zinniaelegans", + ID: endpoint.ID, + DisplayName: endpoint.DisplayName, + Transport: endpoint.Transport, + Context: endpoint.Context, + Codecs: endpoint.Codecs, + MaxContacts: 10, + Extension: "1234", + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("Created endpoint does not match request, got %v, want %v", got, want) + } + } +} + +func MustRead(handler http.Handler) func(*testing.T) { + return func(t *testing.T) { + endpoint := model.NewEndpoint{ + ID: "kiwi", + Password: "kiwipassword123", + Context: "fruits", + Codecs: nil, + Extension: "9000", + DisplayName: "Blue Kiwi", + } + res := createEndpoint(t, handler, endpoint) + want := parseEndpoint(t, res.Body) + + res = readEndpoint(handler, want.Sid) + if res.Code != http.StatusOK { + t.Errorf("invalid http code, got %d, want %d", res.Code, http.StatusOK) + } + got := parseEndpoint(t, res.Body) + + if !reflect.DeepEqual(want, got) { + t.Errorf("read endpoint does not match want, got %v, want %v", got, want) + } + } +} + +func MustDelete(handler http.Handler) func(*testing.T) { + return func(t *testing.T) { + endpoint := model.NewEndpoint{ + ID: "testuser", + Password: "testpassword123$", + Context: "internal", + Codecs: nil, + Extension: "4000", + DisplayName: "Mr. Test User", + } + + res := createEndpoint(t, handler, endpoint) + createdEndpoint := parseEndpoint(t, res.Body) + + req := httptest.NewRequest("DELETE", fmt.Sprintf("/%d", createdEndpoint.Sid), nil) + res = httptest.NewRecorder() + handler.ServeHTTP(res, req) + + if res.Code != http.StatusNoContent { + t.Errorf("invalid http code, got %d, want %d", res.Code, http.StatusNoContent) + } + + res = readEndpoint(handler, createdEndpoint.Sid) + if res.Code != http.StatusNotFound { + t.Errorf("invalid http code, got %d, want %d", res.Code, http.StatusNotFound) + } + } +} + +func MustUpdate(handler http.Handler) func(*testing.T) { + return func(t *testing.T) { + endpoint := model.NewEndpoint{ + ID: "big_chungus", + Password: "big_chungus_password", + Context: "memes", + Codecs: []string{"ulaw", "opus"}, + Extension: "5061", + DisplayName: "Big Chungus", + } + res := createEndpoint(t, handler, endpoint) + want := parseEndpoint(t, res.Body) + want.MaxContacts = 5 + want.Extension = "6072" + + res = updateEndpoint(t, handler, want.Sid, model.PatchedEndpoint{ + MaxContacts: &want.MaxContacts, + Extension: &want.Extension, }) + if res.Code != http.StatusOK { + t.Errorf("invalid http code, got %d, want %d", res.Code, http.StatusOK) + } + got := parseEndpoint(t, res.Body) + + if !reflect.DeepEqual(got, want) { + t.Errorf("inconsistent update result, got %v, want %v", got, want) + } } } diff --git a/internal/model/endpoint.go b/internal/model/endpoint.go new file mode 100644 index 0000000..5e3cf05 --- /dev/null +++ b/internal/model/endpoint.go @@ -0,0 +1,49 @@ +package model + +type Endpoint struct { + Sid int32 `json:"sid"` + ID string `json:"id"` + AccountCode string `json:"accountCode"` + DisplayName string `json:"displayName"` + Transport string `json:"transport"` + Context string `json:"context"` + Codecs []string `json:"codecs"` + MaxContacts int32 `json:"maxContacts"` + Extension string `json:"extension"` +} + +type NewEndpoint struct { + ID string `json:"id"` + AccountCode string `json:"accountCode"` + Password string `json:"password"` + Transport string `json:"transport"` + Context string `json:"context"` + Codecs []string `json:"codecs"` + MaxContacts int32 `json:"maxContacts"` + Extension string `json:"extension"` + DisplayName string `json:"displayName"` +} + +type PatchedEndpoint struct { + Password *string `json:"password,"` + DisplayName *string `json:"displayName,"` + Transport *string `json:"transport,"` + Context *string `json:"context,"` + Codecs []string `json:"codecs,"` + MaxContacts *int32 `json:"maxContacts,"` + Extension *string `json:"extension,"` +} + +type EndpointPageEntry struct { + Sid int32 `json:"sid"` + ID string `json:"id"` + Extension string `json:"extension"` + Context string `json:"context"` + DisplayName string `json:"displayName"` +} + +type EndpointPage struct { + Total int64 `json:"total"` + Retrieved int `json:"retrieved"` + Endpoints []EndpointPageEntry `json:"endpoints"` +} diff --git a/internal/service/endpoint.go b/internal/service/endpoint.go new file mode 100644 index 0000000..af480d6 --- /dev/null +++ b/internal/service/endpoint.go @@ -0,0 +1,282 @@ +package service + +import ( + "cmp" + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "github.com/crazybolillo/eryth/internal/db" + "github.com/crazybolillo/eryth/internal/model" + "github.com/crazybolillo/eryth/internal/sqlc" + "strings" +) + +const defaultRealm = "asterisk" + +type EndpointService struct { + Cursor +} + +func hashPassword(user, password, realm string) string { + hash := md5.Sum([]byte(user + ":" + realm + ":" + password)) + return hex.EncodeToString(hash[:]) +} + +// displayNameFromClid extracts the display name from a Caller ID. It is expected for the Caller ID to be in +// the following format: "Display Name" +// If no display name is found, an empty string is returned. +func displayNameFromClid(callerID string) string { + if callerID == "" { + return "" + } + + start := strings.Index(callerID, `"`) + if start != 0 { + return "" + } + + end := strings.LastIndex(callerID, `"`) + if end == -1 || end < 1 { + return "" + } + + return callerID[1:end] +} + +func (e *EndpointService) Create(ctx context.Context, payload model.NewEndpoint) (model.Endpoint, error) { + tx, err := e.Begin(ctx) + if err != nil { + return model.Endpoint{}, err + } + defer tx.Rollback(ctx) + + queries := sqlc.New(tx) + err = queries.NewMD5Auth(ctx, sqlc.NewMD5AuthParams{ + ID: payload.ID, + Username: db.Text(payload.ID), + Realm: db.Text(defaultRealm), + Md5Cred: db.Text(hashPassword(payload.ID, payload.Password, defaultRealm)), + }) + if err != nil { + return model.Endpoint{}, err + } + + sid, err := queries.NewEndpoint(ctx, sqlc.NewEndpointParams{ + ID: payload.ID, + Accountcode: db.Text(cmp.Or(payload.AccountCode, payload.ID)), + Transport: db.Text(payload.Transport), + Context: db.Text(payload.Context), + Allow: db.Text(strings.Join(payload.Codecs, ",")), + Callerid: db.Text(fmt.Sprintf(`"%s" <%s>`, payload.DisplayName, payload.ID)), + }) + if err != nil { + return model.Endpoint{}, err + } + + err = queries.NewAOR(ctx, sqlc.NewAORParams{ + ID: payload.ID, + MaxContacts: db.Int4(payload.MaxContacts), + }) + if err != nil { + return model.Endpoint{}, err + } + + if payload.Extension != "" { + err = queries.NewExtension(ctx, sqlc.NewExtensionParams{ + EndpointID: sid, + Extension: db.Text(payload.Extension), + }) + if err != nil { + return model.Endpoint{}, err + } + } + + err = tx.Commit(ctx) + if err != nil { + return model.Endpoint{}, err + } + + return e.Read(ctx, sid) +} + +func (e *EndpointService) Read(ctx context.Context, sid int32) (model.Endpoint, error) { + queries := sqlc.New(e.Cursor) + row, err := queries.GetEndpointByID(ctx, sid) + if err != nil { + return model.Endpoint{}, err + } + + endpoint := model.Endpoint{ + Sid: sid, + ID: row.ID, + AccountCode: row.Accountcode.String, + DisplayName: displayNameFromClid(row.Callerid.String), + Transport: row.Transport.String, + Context: row.Context.String, + Codecs: strings.Split(row.Allow.String, ","), + MaxContacts: row.MaxContacts.Int32, + Extension: row.Extension.String, + } + return endpoint, nil +} + +func (e *EndpointService) Update(ctx context.Context, sid int32, payload model.PatchedEndpoint) (model.Endpoint, error) { + tx, err := e.Begin(ctx) + if err != nil { + return model.Endpoint{}, err + } + + queries := sqlc.New(tx) + endpoint, err := queries.GetEndpointByID(ctx, sid) + if err != nil { + return model.Endpoint{}, err + } + + // Sorry for the incoming boilerplate but no dynamic SQL yet + var patchedEndpoint = sqlc.UpdateEndpointBySidParams{Sid: int32(sid)} + if payload.DisplayName != nil { + if *payload.DisplayName == "" { + patchedEndpoint.Callerid = db.Text("") + } else { + patchedEndpoint.Callerid = db.Text(fmt.Sprintf(`"%s" <%s>`, *payload.DisplayName, endpoint.ID)) + } + } else { + patchedEndpoint.Callerid = endpoint.Callerid + } + if payload.Context != nil { + patchedEndpoint.Context = db.Text(*payload.Context) + } else { + patchedEndpoint.Context = endpoint.Context + } + if payload.Transport != nil { + patchedEndpoint.Transport = db.Text(*payload.Transport) + } else { + patchedEndpoint.Transport = endpoint.Transport + } + if payload.Codecs != nil { + patchedEndpoint.Allow = db.Text(strings.Join(payload.Codecs, ",")) + } else { + patchedEndpoint.Allow = endpoint.Allow + } + err = queries.UpdateEndpointBySid(ctx, patchedEndpoint) + if err != nil { + return model.Endpoint{}, err + } + + if payload.MaxContacts != nil { + err = queries.UpdateAORById( + ctx, + sqlc.UpdateAORByIdParams{ + ID: endpoint.ID, + MaxContacts: db.Int4(*payload.MaxContacts), + }, + ) + } + if err != nil { + return model.Endpoint{}, err + } + + if payload.Extension != nil { + err = queries.UpdateExtensionByEndpointId( + ctx, + sqlc.UpdateExtensionByEndpointIdParams{ + EndpointID: sid, + Extension: db.Text(*payload.Extension), + }, + ) + } + if err != nil { + return model.Endpoint{}, err + } + + if payload.Password != nil { + err = queries.UpdateMD5AuthById( + ctx, + sqlc.UpdateMD5AuthByIdParams{ + ID: endpoint.ID, + Md5Cred: db.Text(hashPassword(endpoint.ID, *payload.Password, defaultRealm)), + }, + ) + } + if err != nil { + return model.Endpoint{}, err + } + + err = tx.Commit(ctx) + if err != nil { + return model.Endpoint{}, err + } + + return e.Read(ctx, sid) +} + +func (e *EndpointService) Delete(ctx context.Context, sid int32) error { + tx, err := e.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + queries := sqlc.New(tx) + + id, err := queries.DeleteEndpoint(ctx, sid) + if err != nil { + return err + } + + err = queries.DeleteAOR(ctx, id) + if err != nil { + return err + } + + err = queries.DeleteAuth(ctx, id) + if err != nil { + return err + } + + err = tx.Commit(ctx) + if err != nil { + return err + } + + return nil +} + +func (e *EndpointService) Paginate(ctx context.Context, page, size int) (model.EndpointPage, error) { + queries := sqlc.New(e.Cursor) + rows, err := queries.ListEndpoints(ctx, sqlc.ListEndpointsParams{ + Limit: int32(size), + Offset: int32(page * size), + }) + if err != nil { + return model.EndpointPage{}, err + } + + count, err := queries.CountEndpoints(ctx) + if err != nil { + return model.EndpointPage{}, err + } + + if rows == nil { + rows = []sqlc.ListEndpointsRow{} + } + + endpoints := make([]model.EndpointPageEntry, len(rows)) + for idx := range len(rows) { + row := rows[idx] + endpoints[idx] = model.EndpointPageEntry{ + Sid: row.Sid, + ID: row.ID, + Extension: row.Extension.String, + Context: row.Context.String, + DisplayName: displayNameFromClid(row.Callerid.String), + } + } + + return model.EndpointPage{ + Total: count, + Retrieved: len(rows), + Endpoints: endpoints, + }, nil +} diff --git a/internal/service/endpoint_test.go b/internal/service/endpoint_test.go new file mode 100644 index 0000000..8a17077 --- /dev/null +++ b/internal/service/endpoint_test.go @@ -0,0 +1,44 @@ +package service + +import "testing" + +func TestDisplayNameFromClid(t *testing.T) { + cases := map[string]struct { + callerID string + displayName string + }{ + "valid": { + `"Kiwi Snow" `, + "Kiwi Snow", + }, + "empty": { + "", + "", + }, + "single_colon": { + `"`, + "", + }, + "empty_quotes": { + `""`, + "", + }, + "missing_start_quote": { + `John Smith" `, + "", + }, + "missing_end_quote": { + `"John Smith `, + "", + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + got := displayNameFromClid(tt.callerID) + if got != tt.displayName { + t.Errorf("got %q, want %q", got, tt.displayName) + } + }) + } +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 0000000..cd59732 --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,12 @@ +package service + +import ( + "context" + "github.com/crazybolillo/eryth/internal/sqlc" + "github.com/jackc/pgx/v5" +) + +type Cursor interface { + Begin(ctx context.Context) (pgx.Tx, error) + sqlc.DBTX +} diff --git a/internal/sqlc/db.go b/internal/sqlc/db.go index 278c210..b931bc5 100644 --- a/internal/sqlc/db.go +++ b/internal/sqlc/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 package sqlc diff --git a/internal/sqlc/models.go b/internal/sqlc/models.go index 7f41e32..5e93312 100644 --- a/internal/sqlc/models.go +++ b/internal/sqlc/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 package sqlc diff --git a/internal/sqlc/queries.sql.go b/internal/sqlc/queries.sql.go index b0750f5..a34dcbc 100644 --- a/internal/sqlc/queries.sql.go +++ b/internal/sqlc/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 // source: queries.sql package sqlc @@ -83,7 +83,7 @@ func (q *Queries) GetEndpointByExtension(ctx context.Context, arg GetEndpointByE const getEndpointByID = `-- name: GetEndpointByID :one SELECT - pe.id, pe.callerid, pe.context, ee.extension, pe.transport, aor.max_contacts, pe.allow + pe.id, pe.accountcode, pe.callerid, pe.context, ee.extension, pe.transport, aor.max_contacts, pe.allow FROM ps_endpoints pe INNER JOIN @@ -96,6 +96,7 @@ WHERE type GetEndpointByIDRow struct { ID string `json:"id"` + Accountcode pgtype.Text `json:"accountcode"` Callerid pgtype.Text `json:"callerid"` Context pgtype.Text `json:"context"` Extension pgtype.Text `json:"extension"` @@ -109,6 +110,7 @@ func (q *Queries) GetEndpointByID(ctx context.Context, sid int32) (GetEndpointBy var i GetEndpointByIDRow err := row.Scan( &i.ID, + &i.Accountcode, &i.Callerid, &i.Context, &i.Extension, @@ -188,18 +190,19 @@ func (q *Queries) NewAOR(ctx context.Context, arg NewAORParams) error { const newEndpoint = `-- name: NewEndpoint :one INSERT INTO ps_endpoints - (id, transport, aors, auth, context, disallow, allow, callerid) + (id, transport, aors, auth, context, disallow, allow, callerid, accountcode) VALUES - ($1, $2, $1, $1, $3, 'all', $4, $5) + ($1, $2, $1, $1, $3, 'all', $4, $5, $6) RETURNING sid ` type NewEndpointParams struct { - ID string `json:"id"` - Transport pgtype.Text `json:"transport"` - Context pgtype.Text `json:"context"` - Allow pgtype.Text `json:"allow"` - Callerid pgtype.Text `json:"callerid"` + ID string `json:"id"` + Transport pgtype.Text `json:"transport"` + Context pgtype.Text `json:"context"` + Allow pgtype.Text `json:"allow"` + Callerid pgtype.Text `json:"callerid"` + Accountcode pgtype.Text `json:"accountcode"` } func (q *Queries) NewEndpoint(ctx context.Context, arg NewEndpointParams) (int32, error) { @@ -209,6 +212,7 @@ func (q *Queries) NewEndpoint(ctx context.Context, arg NewEndpointParams) (int32 arg.Context, arg.Allow, arg.Callerid, + arg.Accountcode, ) var sid int32 err := row.Scan(&sid) diff --git a/queries.sql b/queries.sql index b76dbc6..a638f58 100644 --- a/queries.sql +++ b/queries.sql @@ -12,9 +12,9 @@ VALUES -- name: NewEndpoint :one INSERT INTO ps_endpoints - (id, transport, aors, auth, context, disallow, allow, callerid) + (id, transport, aors, auth, context, disallow, allow, callerid, accountcode) VALUES - ($1, $2, $1, $1, $3, 'all', $4, $5) + ($1, $2, $1, $1, $3, 'all', $4, $5, $6) RETURNING sid; -- name: DeleteEndpoint :one @@ -56,7 +56,7 @@ WHERE -- name: GetEndpointByID :one SELECT - pe.id, pe.callerid, pe.context, ee.extension, pe.transport, aor.max_contacts, pe.allow + pe.id, pe.accountcode, pe.callerid, pe.context, ee.extension, pe.transport, aor.max_contacts, pe.allow FROM ps_endpoints pe INNER JOIN