diff --git a/.github/workflows/qa.yaml b/.github/workflows/qa.yaml index 877c220..1e15020 100644 --- a/.github/workflows/qa.yaml +++ b/.github/workflows/qa.yaml @@ -39,6 +39,33 @@ jobs: run: sqlc vet - name: Diff run: sqlc diff + dbmate: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: pbx + POSTGRES_DB: asterisk + POSTGRES_PASSWORD: pbx + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + DATABASE_URL: postgres://pbx:pbx@127.0.0.1:5432/asterisk?sslmode=disable + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: 20 + - name: Install dbmate + run: npm install dbmate + - name: Run migrations + run: npx dbmate up docker: runs-on: ubuntu-latest steps: diff --git a/Makefile b/Makefile index 73772d8..a43b797 100644 --- a/Makefile +++ b/Makefile @@ -4,4 +4,4 @@ docs: swag init --generalInfo cmd/main.go --outputTypes=yaml swagger: - docker run --detach -p 4000:8080 -e API_URL=/doc/swagger.yaml --mount 'type=bind,src=$(shell pwd)/docs,dst=/usr/share/nginx/html/doc' swaggerapi/swagger-ui + docker run --detach --name eryth-swagger -p 4000:8080 -e API_URL=/doc/swagger.yaml --mount 'type=bind,src=$(shell pwd)/docs,dst=/usr/share/nginx/html/doc' swaggerapi/swagger-ui diff --git a/cmd/main.go b/cmd/main.go index da70f64..5124677 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "github.com/crazybolillo/eryth/internal/bouncer" "github.com/crazybolillo/eryth/internal/handler" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -75,6 +76,10 @@ func serve(ctx context.Context) error { endpoint := handler.Endpoint{Conn: conn} r.Mount("/endpoint", endpoint.Router()) + checker := &bouncer.Bouncer{Conn: conn} + authorization := handler.Authorization{Bouncer: checker} + r.Mount("/bouncer", authorization.Router()) + slog.Info("Listening on :8080") return http.ListenAndServe(":8080", r) } diff --git a/db/migrations/20240730043324_asterisk_v21_3_1.sql b/db/migrations/20240730043324_asterisk_v21_3_1.sql index ee989ec..cf4d1c2 100644 --- a/db/migrations/20240730043324_asterisk_v21_3_1.sql +++ b/db/migrations/20240730043324_asterisk_v21_3_1.sql @@ -1,4 +1,4 @@ --- migrate:upclear +-- migrate:up CREATE TYPE public.ast_bool_values AS ENUM ( '0', '1', diff --git a/db/migrations/20240730051335_ery_extension.sql b/db/migrations/20240730051335_ery_extension.sql new file mode 100644 index 0000000..dcabf2b --- /dev/null +++ b/db/migrations/20240730051335_ery_extension.sql @@ -0,0 +1,12 @@ +-- migrate:up +ALTER TABLE ps_endpoints ADD COLUMN sid SERIAL PRIMARY KEY; + +CREATE TABLE ery_extension ( + id SERIAL PRIMARY KEY, + endpoint_id SERIAL NOT NULL, + extension varchar UNIQUE, + FOREIGN KEY (endpoint_id) REFERENCES ps_endpoints(sid) +) + +-- migrate:down + diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 88898f2..1f76aeb 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,4 +1,18 @@ definitions: + bouncer.Response: + properties: + allow: + type: boolean + destination: + type: string + type: object + handler.AuthorizationRequest: + properties: + endpoint: + type: string + extension: + type: string + type: object handler.createEndpointRequest: properties: codecs: @@ -7,6 +21,8 @@ definitions: type: array context: type: string + extension: + type: string id: type: string max_contacts: @@ -18,13 +34,20 @@ definitions: transport: type: string type: object + handler.listEndpointsRequest: + properties: + endpoints: + items: + $ref: '#/definitions/sqlc.ListEndpointsRow' + type: array + type: object sqlc.ListEndpointsRow: properties: context: type: string - id: + extension: type: string - transport: + id: type: string type: object host: localhost:8080 @@ -34,6 +57,31 @@ info: title: Asterisk Administration API version: "1.0" paths: + /bouncer: + post: + consumes: + - application/json + parameters: + - description: Action to be reviewed + in: body + name: payload + required: true + schema: + $ref: '#/definitions/handler.AuthorizationRequest' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/bouncer.Response' + "400": + description: Bad Request + "500": + description: Internal Server Error + summary: Determine whether the specified action (call) is allowed or not. + tags: + - bouncer /endpoint: post: consumes: @@ -87,9 +135,7 @@ paths: "200": description: OK schema: - items: - $ref: '#/definitions/sqlc.ListEndpointsRow' - type: array + $ref: '#/definitions/handler.listEndpointsRequest' "400": description: Bad Request "500": diff --git a/internal/bouncer/bouncer.go b/internal/bouncer/bouncer.go new file mode 100644 index 0000000..7928f54 --- /dev/null +++ b/internal/bouncer/bouncer.go @@ -0,0 +1,43 @@ +package bouncer + +import ( + "context" + "github.com/crazybolillo/eryth/internal/db" + "github.com/crazybolillo/eryth/internal/sqlc" + "github.com/jackc/pgx/v5" + "log/slog" +) + +type Response struct { + Allow bool `json:"allow"` + Destination string `json:"destination"` +} + +type Bouncer struct { + *pgx.Conn +} + +func (b *Bouncer) Check(ctx context.Context, endpoint, dialed string) Response { + result := Response{ + Allow: false, + Destination: "", + } + + tx, err := b.Begin(ctx) + if err != nil { + slog.Error("Unable to start transaction", slog.String("reason", err.Error())) + return result + } + + queries := sqlc.New(tx) + destination, err := queries.GetEndpointByExtension(ctx, db.Text(dialed)) + if err != nil { + slog.Error("Failed to retrieve endpoint", slog.String("dialed", dialed), slog.String("reason", err.Error())) + return result + } + + return Response{ + Allow: true, + Destination: destination, + } +} diff --git a/internal/handler/authorization.go b/internal/handler/authorization.go new file mode 100644 index 0000000..f01a8e9 --- /dev/null +++ b/internal/handler/authorization.go @@ -0,0 +1,63 @@ +package handler + +import ( + "context" + "encoding/json" + "github.com/crazybolillo/eryth/internal/bouncer" + "github.com/go-chi/chi/v5" + "log/slog" + "net/http" +) + +type CallBouncer interface { + Check(ctx context.Context, endpoint, dialed string) bouncer.Response +} + +type Authorization struct { + Bouncer CallBouncer +} + +type AuthorizationRequest struct { + From string `json:"endpoint"` + Extension string `json:"extension"` +} + +func (e *Authorization) Router() chi.Router { + r := chi.NewRouter() + r.Post("/", e.post) + + return r +} + +// @Summary Determine whether the specified action (call) is allowed or not. +// @Accept json +// @Produce json +// @Param payload body AuthorizationRequest true "Action to be reviewed" +// @Success 200 {object} bouncer.Response +// @Failure 400 +// @Failure 500 +// @Tags bouncer +// @Router /bouncer [post] +func (e *Authorization) post(w http.ResponseWriter, r *http.Request) { + var payload AuthorizationRequest + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&payload) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + response := e.Bouncer.Check(r.Context(), payload.From, payload.Extension) + content, err := json.Marshal(response) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(content) + if err != nil { + slog.Error("Failed to write response", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/handler/endpoint.go b/internal/handler/endpoint.go index a08315c..fb80567 100644 --- a/internal/handler/endpoint.go +++ b/internal/handler/endpoint.go @@ -22,10 +22,15 @@ type createEndpointRequest struct { ID string `json:"id"` Password string `json:"password"` Realm string `json:"realm,omitempty"` - Transport string `json:"transport"` + Transport string `json:"transport,omitempty"` Context string `json:"context"` Codecs []string `json:"codecs"` MaxContacts int32 `json:"max_contacts,omitempty"` + Extension string `json:"extension,omitempty"` +} + +type listEndpointsRequest struct { + Endpoints []sqlc.ListEndpointsRow `json:"endpoints"` } func (e *Endpoint) Router() chi.Router { @@ -40,7 +45,7 @@ func (e *Endpoint) Router() chi.Router { // @Summary List existing endpoints. // @Param limit query int false "Limit the amount of endpoints returned" default(15) // @Produce json -// @Success 200 {object} []sqlc.ListEndpointsRow +// @Success 200 {object} listEndpointsRequest // @Failure 400 // @Failure 500 // @Tags endpoints @@ -60,11 +65,18 @@ func (e *Endpoint) list(w http.ResponseWriter, r *http.Request) { queries := sqlc.New(e.Conn) endpoints, err := queries.ListEndpoints(r.Context(), int32(limit)) if err != nil { + slog.Error("Query execution failed", slog.String("path", r.URL.Path), slog.String("msg", err.Error())) w.WriteHeader(http.StatusInternalServerError) return } + if endpoints == nil { + endpoints = []sqlc.ListEndpointsRow{} + } - content, err := json.Marshal(endpoints) + response := listEndpointsRequest{ + Endpoints: endpoints, + } + content, err := json.Marshal(response) if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -120,7 +132,7 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { return } - err = queries.NewEndpoint(r.Context(), sqlc.NewEndpointParams{ + sid, err := queries.NewEndpoint(r.Context(), sqlc.NewEndpointParams{ ID: payload.ID, Transport: db.Text(payload.Transport), Context: db.Text(payload.Context), @@ -140,6 +152,17 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { return } + if payload.Extension != "" { + err = queries.NewExtension(r.Context(), sqlc.NewExtensionParams{ + EndpointID: sid, + Extension: db.Text(payload.Extension), + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + } + err = tx.Commit(r.Context()) if err != nil { w.WriteHeader(http.StatusInternalServerError) diff --git a/internal/sqlc/models.go b/internal/sqlc/models.go index c973113..7f41e32 100644 --- a/internal/sqlc/models.go +++ b/internal/sqlc/models.go @@ -1589,6 +1589,12 @@ type Cdr struct { ID int64 `json:"id"` } +type EryExtension struct { + ID int32 `json:"id"` + EndpointID int32 `json:"endpoint_id"` + Extension pgtype.Text `json:"extension"` +} + type Extension struct { ID int64 `json:"id"` Context string `json:"context"` @@ -1887,6 +1893,7 @@ type PsEndpoint struct { SecurityMechanisms pgtype.Text `json:"security_mechanisms"` SendAoc NullAstBoolValues `json:"send_aoc"` OverlapContext pgtype.Text `json:"overlap_context"` + Sid int32 `json:"sid"` } type PsEndpointIDIp struct { diff --git a/internal/sqlc/queries.sql.go b/internal/sqlc/queries.sql.go index 942645c..5c8c1a9 100644 --- a/internal/sqlc/queries.sql.go +++ b/internal/sqlc/queries.sql.go @@ -38,18 +38,39 @@ func (q *Queries) DeleteEndpoint(ctx context.Context, id string) error { return err } -const listEndpoints = `-- name: ListEndpoints :many +const getEndpointByExtension = `-- name: GetEndpointByExtension :one SELECT - id, context, transport + ps_endpoints.id FROM ps_endpoints +INNER JOIN + ery_extension ee on ps_endpoints.sid = ee.endpoint_id +WHERE + ee.extension = $1 +` + +func (q *Queries) GetEndpointByExtension(ctx context.Context, extension pgtype.Text) (string, error) { + row := q.db.QueryRow(ctx, getEndpointByExtension, extension) + var id string + err := row.Scan(&id) + return id, err +} + +const listEndpoints = `-- name: ListEndpoints :many +SELECT + pe.id, pe.context, ee.extension +FROM + ps_endpoints pe +LEFT JOIN + ery_extension ee +ON ee.endpoint_id = pe.sid LIMIT $1 ` type ListEndpointsRow struct { ID string `json:"id"` Context pgtype.Text `json:"context"` - Transport pgtype.Text `json:"transport"` + Extension pgtype.Text `json:"extension"` } func (q *Queries) ListEndpoints(ctx context.Context, limit int32) ([]ListEndpointsRow, error) { @@ -61,7 +82,7 @@ func (q *Queries) ListEndpoints(ctx context.Context, limit int32) ([]ListEndpoin var items []ListEndpointsRow for rows.Next() { var i ListEndpointsRow - if err := rows.Scan(&i.ID, &i.Context, &i.Transport); err != nil { + if err := rows.Scan(&i.ID, &i.Context, &i.Extension); err != nil { return nil, err } items = append(items, i) @@ -89,11 +110,12 @@ func (q *Queries) NewAOR(ctx context.Context, arg NewAORParams) error { return err } -const newEndpoint = `-- name: NewEndpoint :exec +const newEndpoint = `-- name: NewEndpoint :one INSERT INTO ps_endpoints (id, transport, aors, auth, context, disallow, allow) VALUES ($1, $2, $1, $1, $3, 'all', $4) +RETURNING sid ` type NewEndpointParams struct { @@ -103,13 +125,32 @@ type NewEndpointParams struct { Allow pgtype.Text `json:"allow"` } -func (q *Queries) NewEndpoint(ctx context.Context, arg NewEndpointParams) error { - _, err := q.db.Exec(ctx, newEndpoint, +func (q *Queries) NewEndpoint(ctx context.Context, arg NewEndpointParams) (int32, error) { + row := q.db.QueryRow(ctx, newEndpoint, arg.ID, arg.Transport, arg.Context, arg.Allow, ) + var sid int32 + err := row.Scan(&sid) + return sid, err +} + +const newExtension = `-- name: NewExtension :exec +INSERT INTO ery_extension + (endpoint_id, extension) +VALUES + ($1, $2) +` + +type NewExtensionParams struct { + EndpointID int32 `json:"endpoint_id"` + Extension pgtype.Text `json:"extension"` +} + +func (q *Queries) NewExtension(ctx context.Context, arg NewExtensionParams) error { + _, err := q.db.Exec(ctx, newExtension, arg.EndpointID, arg.Extension) return err } diff --git a/queries.sql b/queries.sql index 7e19f53..128ebc1 100644 --- a/queries.sql +++ b/queries.sql @@ -10,11 +10,12 @@ INSERT INTO ps_aors VALUES ($1, $2); --- name: NewEndpoint :exec +-- name: NewEndpoint :one INSERT INTO ps_endpoints (id, transport, aors, auth, context, disallow, allow) VALUES - ($1, $2, $1, $1, $3, 'all', $4); + ($1, $2, $1, $1, $3, 'all', $4) +RETURNING sid; -- name: DeleteEndpoint :exec DELETE FROM ps_endpoints WHERE id = $1; @@ -27,7 +28,26 @@ DELETE FROM ps_auths WHERE id = $1; -- name: ListEndpoints :many SELECT - id, context, transport + pe.id, pe.context, ee.extension FROM - ps_endpoints + ps_endpoints pe +LEFT JOIN + ery_extension ee +ON ee.endpoint_id = pe.sid LIMIT $1; + +-- name: NewExtension :exec +INSERT INTO ery_extension + (endpoint_id, extension) +VALUES + ($1, $2); + +-- name: GetEndpointByExtension :one +SELECT + ps_endpoints.id +FROM + ps_endpoints +INNER JOIN + ery_extension ee on ps_endpoints.sid = ee.endpoint_id +WHERE + ee.extension = $1;