diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 1f76aeb..9936f8d 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -3,6 +3,8 @@ definitions: properties: allow: type: boolean + callerid: + type: string destination: type: string type: object @@ -21,6 +23,8 @@ definitions: type: array context: type: string + display_name: + type: string extension: type: string id: @@ -34,22 +38,24 @@ definitions: transport: type: string type: object - handler.listEndpointsRequest: - properties: - endpoints: - items: - $ref: '#/definitions/sqlc.ListEndpointsRow' - type: array - type: object - sqlc.ListEndpointsRow: + handler.listEndpointEntry: properties: context: type: string + display_name: + type: string extension: type: string id: type: string type: object + handler.listEndpointsRequest: + properties: + endpoints: + items: + $ref: '#/definitions/handler.listEndpointEntry' + type: array + type: object host: localhost:8080 info: contact: {} @@ -79,7 +85,8 @@ paths: description: Bad Request "500": description: Internal Server Error - summary: Determine whether the specified action (call) is allowed or not. + summary: Determine whether the specified action (call) is allowed or not and + provide details on how tags: - bouncer /endpoint: diff --git a/internal/bouncer/bouncer.go b/internal/bouncer/bouncer.go index 7928f54..d5df61f 100644 --- a/internal/bouncer/bouncer.go +++ b/internal/bouncer/bouncer.go @@ -11,6 +11,7 @@ import ( type Response struct { Allow bool `json:"allow"` Destination string `json:"destination"` + CallerID string `json:"callerid"` } type Bouncer struct { @@ -30,7 +31,10 @@ func (b *Bouncer) Check(ctx context.Context, endpoint, dialed string) Response { } queries := sqlc.New(tx) - destination, err := queries.GetEndpointByExtension(ctx, db.Text(dialed)) + row, err := queries.GetEndpointByExtension(ctx, sqlc.GetEndpointByExtensionParams{ + ID: endpoint, + Extension: db.Text(dialed), + }) if err != nil { slog.Error("Failed to retrieve endpoint", slog.String("dialed", dialed), slog.String("reason", err.Error())) return result @@ -38,6 +42,7 @@ func (b *Bouncer) Check(ctx context.Context, endpoint, dialed string) Response { return Response{ Allow: true, - Destination: destination, + Destination: row.ID, + CallerID: row.Callerid.String, } } diff --git a/internal/handler/authorization.go b/internal/handler/authorization.go index f01a8e9..b45b60d 100644 --- a/internal/handler/authorization.go +++ b/internal/handler/authorization.go @@ -29,7 +29,8 @@ func (e *Authorization) Router() chi.Router { return r } -// @Summary Determine whether the specified action (call) is allowed or not. +// @Summary Determine whether the specified action (call) is allowed or not and provide details on how +// to accomplish it. // @Accept json // @Produce json // @Param payload body AuthorizationRequest true "Action to be reviewed" diff --git a/internal/handler/endpoint.go b/internal/handler/endpoint.go index fb80567..0dff0e3 100644 --- a/internal/handler/endpoint.go +++ b/internal/handler/endpoint.go @@ -4,6 +4,7 @@ import ( "crypto/md5" "encoding/hex" "encoding/json" + "fmt" "github.com/crazybolillo/eryth/internal/db" "github.com/crazybolillo/eryth/internal/sqlc" "github.com/go-chi/chi/v5" @@ -27,10 +28,18 @@ type createEndpointRequest struct { Codecs []string `json:"codecs"` MaxContacts int32 `json:"max_contacts,omitempty"` Extension string `json:"extension,omitempty"` + DisplayName string `json:"display_name"` +} + +type listEndpointEntry struct { + ID string `json:"id"` + Extension string `json:"extension"` + Context string `json:"context"` + DisplayName string `json:"display_name"` } type listEndpointsRequest struct { - Endpoints []sqlc.ListEndpointsRow `json:"endpoints"` + Endpoints []listEndpointEntry `json:"endpoints"` } func (e *Endpoint) Router() chi.Router { @@ -42,6 +51,27 @@ func (e *Endpoint) Router() chi.Router { return r } +// displayNameFromClid extracts the display name from a Caller ID. It is expected for the Caller ID to be in +// the following format: "Display Name" +// If no display name is found, an empty string is returned. +func displayNameFromClid(callerID string) string { + if callerID == "" { + return "" + } + + start := strings.Index(callerID, `"`) + if start != 0 { + return "" + } + + end := strings.LastIndex(callerID, `"`) + if end == -1 || end < 1 { + return "" + } + + return callerID[1:end] +} + // @Summary List existing endpoints. // @Param limit query int false "Limit the amount of endpoints returned" default(15) // @Produce json @@ -63,16 +93,26 @@ func (e *Endpoint) list(w http.ResponseWriter, r *http.Request) { } queries := sqlc.New(e.Conn) - endpoints, err := queries.ListEndpoints(r.Context(), int32(limit)) + rows, err := queries.ListEndpoints(r.Context(), int32(limit)) if err != nil { slog.Error("Query execution failed", slog.String("path", r.URL.Path), slog.String("msg", err.Error())) w.WriteHeader(http.StatusInternalServerError) return } - if endpoints == nil { - endpoints = []sqlc.ListEndpointsRow{} + if rows == nil { + rows = []sqlc.ListEndpointsRow{} } + endpoints := make([]listEndpointEntry, len(rows)) + for idx := range len(rows) { + row := rows[idx] + endpoints[idx] = listEndpointEntry{ + ID: row.ID, + Extension: row.Extension.String, + Context: row.Context.String, + DisplayName: displayNameFromClid(row.Callerid.String), + } + } response := listEndpointsRequest{ Endpoints: endpoints, } @@ -137,6 +177,7 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { Transport: db.Text(payload.Transport), Context: db.Text(payload.Context), Allow: db.Text(strings.Join(payload.Codecs, ",")), + Callerid: db.Text(fmt.Sprintf(`"%s" <%s>`, payload.DisplayName, payload.ID)), }) if err != nil { w.WriteHeader(http.StatusInternalServerError) diff --git a/internal/handler/endpoint_test.go b/internal/handler/endpoint_test.go new file mode 100644 index 0000000..e32ba43 --- /dev/null +++ b/internal/handler/endpoint_test.go @@ -0,0 +1,44 @@ +package handler + +import "testing" + +func TestDisplayNameFromClid(t *testing.T) { + cases := map[string]struct { + callerID string + displayName string + }{ + "valid": { + `"Kiwi Snow" `, + "Kiwi Snow", + }, + "empty": { + "", + "", + }, + "single_colon": { + `"`, + "", + }, + "empty_quotes": { + `""`, + "", + }, + "missing_start_quote": { + `John Smith" `, + "", + }, + "missing_end_quote": { + `"John Smith `, + "", + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + got := displayNameFromClid(tt.callerID) + if got != tt.displayName { + t.Errorf("got %q, want %q", got, tt.displayName) + } + }) + } +} diff --git a/internal/sqlc/queries.sql.go b/internal/sqlc/queries.sql.go index 5c8c1a9..697dd44 100644 --- a/internal/sqlc/queries.sql.go +++ b/internal/sqlc/queries.sql.go @@ -40,25 +40,37 @@ func (q *Queries) DeleteEndpoint(ctx context.Context, id string) error { const getEndpointByExtension = `-- name: GetEndpointByExtension :one SELECT - ps_endpoints.id + dest.id, src.callerid FROM - ps_endpoints + ps_endpoints dest INNER JOIN - ery_extension ee on ps_endpoints.sid = ee.endpoint_id + ery_extension ee ON dest.sid = ee.endpoint_id +INNER JOIN + ps_endpoints src ON src.id = $1 WHERE - ee.extension = $1 + ee.extension = $2 ` -func (q *Queries) GetEndpointByExtension(ctx context.Context, extension pgtype.Text) (string, error) { - row := q.db.QueryRow(ctx, getEndpointByExtension, extension) - var id string - err := row.Scan(&id) - return id, err +type GetEndpointByExtensionParams struct { + ID string `json:"id"` + Extension pgtype.Text `json:"extension"` +} + +type GetEndpointByExtensionRow struct { + ID string `json:"id"` + Callerid pgtype.Text `json:"callerid"` +} + +func (q *Queries) GetEndpointByExtension(ctx context.Context, arg GetEndpointByExtensionParams) (GetEndpointByExtensionRow, error) { + row := q.db.QueryRow(ctx, getEndpointByExtension, arg.ID, arg.Extension) + var i GetEndpointByExtensionRow + err := row.Scan(&i.ID, &i.Callerid) + return i, err } const listEndpoints = `-- name: ListEndpoints :many SELECT - pe.id, pe.context, ee.extension + pe.id, pe.callerid, pe.context, ee.extension FROM ps_endpoints pe LEFT JOIN @@ -69,6 +81,7 @@ LIMIT $1 type ListEndpointsRow struct { ID string `json:"id"` + Callerid pgtype.Text `json:"callerid"` Context pgtype.Text `json:"context"` Extension pgtype.Text `json:"extension"` } @@ -82,7 +95,12 @@ func (q *Queries) ListEndpoints(ctx context.Context, limit int32) ([]ListEndpoin var items []ListEndpointsRow for rows.Next() { var i ListEndpointsRow - if err := rows.Scan(&i.ID, &i.Context, &i.Extension); err != nil { + if err := rows.Scan( + &i.ID, + &i.Callerid, + &i.Context, + &i.Extension, + ); err != nil { return nil, err } items = append(items, i) @@ -112,9 +130,9 @@ func (q *Queries) NewAOR(ctx context.Context, arg NewAORParams) error { const newEndpoint = `-- name: NewEndpoint :one INSERT INTO ps_endpoints - (id, transport, aors, auth, context, disallow, allow) + (id, transport, aors, auth, context, disallow, allow, callerid) VALUES - ($1, $2, $1, $1, $3, 'all', $4) + ($1, $2, $1, $1, $3, 'all', $4, $5) RETURNING sid ` @@ -123,6 +141,7 @@ type NewEndpointParams struct { Transport pgtype.Text `json:"transport"` Context pgtype.Text `json:"context"` Allow pgtype.Text `json:"allow"` + Callerid pgtype.Text `json:"callerid"` } func (q *Queries) NewEndpoint(ctx context.Context, arg NewEndpointParams) (int32, error) { @@ -131,6 +150,7 @@ func (q *Queries) NewEndpoint(ctx context.Context, arg NewEndpointParams) (int32 arg.Transport, arg.Context, arg.Allow, + arg.Callerid, ) var sid int32 err := row.Scan(&sid) diff --git a/queries.sql b/queries.sql index 128ebc1..387d23a 100644 --- a/queries.sql +++ b/queries.sql @@ -12,9 +12,9 @@ VALUES -- name: NewEndpoint :one INSERT INTO ps_endpoints - (id, transport, aors, auth, context, disallow, allow) + (id, transport, aors, auth, context, disallow, allow, callerid) VALUES - ($1, $2, $1, $1, $3, 'all', $4) + ($1, $2, $1, $1, $3, 'all', $4, $5) RETURNING sid; -- name: DeleteEndpoint :exec @@ -28,7 +28,7 @@ DELETE FROM ps_auths WHERE id = $1; -- name: ListEndpoints :many SELECT - pe.id, pe.context, ee.extension + pe.id, pe.callerid, pe.context, ee.extension FROM ps_endpoints pe LEFT JOIN @@ -44,10 +44,12 @@ VALUES -- name: GetEndpointByExtension :one SELECT - ps_endpoints.id + dest.id, src.callerid FROM - ps_endpoints + ps_endpoints dest INNER JOIN - ery_extension ee on ps_endpoints.sid = ee.endpoint_id + ery_extension ee ON dest.sid = ee.endpoint_id +INNER JOIN + ps_endpoints src ON src.id = $1 WHERE - ee.extension = $1; + ee.extension = $2;