From 5542f40ca55cf12d22a4b0911e4b6f92fedf10ec Mon Sep 17 00:00:00 2001 From: CrazyBolillo Date: Thu, 22 Aug 2024 17:04:50 -0600 Subject: [PATCH] feat(endpoint): support updates Endpoints may now be updated (patched). It is important to note that because Go's JSON marshalling can't detect between omitted fields and fields set to null, they are treated as the same (ommited). Once created only text fields may be set to NULL by updating the endpoint with an empty string. Realm was also removed from the API since there is no clear benefit at the moment. All use cases so far use the default value for it. --- docs/swagger.yaml | 42 +++++++- internal/handler/endpoint.go | 201 ++++++++++++++++++++++++++++++++++- internal/sqlc/queries.sql.go | 88 +++++++++++++++ queries.sql | 36 +++++++ 4 files changed, 360 insertions(+), 7 deletions(-) diff --git a/docs/swagger.yaml b/docs/swagger.yaml index b6dc4bc..7df916f 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -33,8 +33,6 @@ definitions: type: integer password: type: string - realm: - type: string transport: type: string type: object @@ -83,6 +81,25 @@ definitions: total: type: integer type: object + handler.updateEndpointRequest: + properties: + codecs: + items: + type: string + type: array + context: + type: string + displayName: + type: string + extension: + type: string + maxContacts: + type: integer + password: + type: string + transport: + type: string + type: object host: localhost:8080 info: contact: {} @@ -204,4 +221,25 @@ paths: summary: Get information from a specific endpoint. tags: - endpoints + patch: + parameters: + - description: Sid of the endpoint to be updated + in: path + name: sid + required: true + type: integer + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handler.updateEndpointRequest' + "400": + description: Bad Request + "404": + description: Not Found + "500": + description: Internal Server Error + summary: Update the specified endpoint. Omitted or null fields will remain unchanged. + tags: + - endpoints swagger: "2.0" diff --git a/internal/handler/endpoint.go b/internal/handler/endpoint.go index 3e286f6..44b7089 100644 --- a/internal/handler/endpoint.go +++ b/internal/handler/endpoint.go @@ -17,6 +17,8 @@ import ( "strings" ) +const defaultRealm = "asterisk" + type Endpoint struct { *pgx.Conn } @@ -24,7 +26,6 @@ type Endpoint struct { type createEndpointRequest struct { ID string `json:"id"` Password string `json:"password"` - Realm string `json:"realm,omitempty"` Transport string `json:"transport,omitempty"` Context string `json:"context"` Codecs []string `json:"codecs"` @@ -58,12 +59,23 @@ type getEndpointResponse struct { Extension string `json:"extension"` } +type updateEndpointRequest struct { + Password *string `json:"password,omitempty"` + DisplayName *string `json:"displayName,omitempty"` + Transport *string `json:"transport,omitempty"` + Context *string `json:"context,omitempty"` + Codecs []string `json:"codecs,omitempty"` + MaxContacts *int32 `json:"maxContacts,omitempty"` + Extension *string `json:"extension,omitempty"` +} + func (e *Endpoint) Router() chi.Router { r := chi.NewRouter() r.Post("/", e.create) r.Get("/", e.list) r.Get("/{sid}", e.get) r.Delete("/{sid}", e.delete) + r.Patch("/{sid}", e.update) return r } @@ -89,6 +101,11 @@ func displayNameFromClid(callerID string) string { return callerID[1:end] } +func hashPassword(user, password, realm string) string { + hash := md5.Sum([]byte(user + ":" + realm + ":" + password)) + return hex.EncodeToString(hash[:]) +} + // @Summary Get information from a specific endpoint. // @Param sid path int true "Requested endpoint's sid" // @Produce json @@ -235,7 +252,6 @@ func (e *Endpoint) list(w http.ResponseWriter, r *http.Request) { func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) payload := createEndpointRequest{ - Realm: "asterisk", MaxContacts: 1, } @@ -254,12 +270,11 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { queries := sqlc.New(tx) - hash := md5.Sum([]byte(payload.ID + ":" + payload.Realm + ":" + payload.Password)) err = queries.NewMD5Auth(r.Context(), sqlc.NewMD5AuthParams{ ID: payload.ID, Username: db.Text(payload.ID), - Realm: db.Text(payload.Realm), - Md5Cred: db.Text(hex.EncodeToString(hash[:])), + Realm: db.Text(defaultRealm), + Md5Cred: db.Text(hashPassword(payload.ID, payload.Password, defaultRealm)), }) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -397,3 +412,179 @@ func (e *Endpoint) delete(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } + +// @Summary Update the specified endpoint. Omitted or null fields will remain unchanged. +// @Param sid path int true "Sid of the endpoint to be updated" +// @Success 200 {object} updateEndpointRequest +// @Failure 400 +// @Failure 404 +// @Failure 500 +// @Tags endpoints +// @Router /endpoints/{sid} [patch] +func (e *Endpoint) update(w http.ResponseWriter, r *http.Request) { + decoder := json.NewDecoder(r.Body) + var payload updateEndpointRequest + + err := decoder.Decode(&payload) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + urlSid := chi.URLParam(r, "sid") + sid, err := strconv.Atoi(urlSid) + if err != nil || sid <= 0 { + w.WriteHeader(http.StatusBadRequest) + return + } + + tx, err := e.Begin(r.Context()) + if err != nil { + slog.Error("Failed to start transaction", slog.String("reason", err.Error()), slog.String("path", r.URL.Path)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + queries := sqlc.New(tx) + endpoint, err := queries.GetEndpointByID(r.Context(), int32(sid)) + if errors.Is(err, pgx.ErrNoRows) { + w.WriteHeader(http.StatusNotFound) + return + } else if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to retrieve endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + // Sorry for the incoming boilerplate but no dynamic SQL yet + var patchedEndpoint = sqlc.UpdateEndpointBySidParams{Sid: int32(sid)} + if payload.DisplayName != nil { + if *payload.DisplayName == "" { + patchedEndpoint.Callerid = db.Text("") + } else { + patchedEndpoint.Callerid = db.Text(fmt.Sprintf(`"%s" <%s>`, *payload.DisplayName, endpoint.ID)) + } + } else { + patchedEndpoint.Callerid = endpoint.Callerid + } + if payload.Context != nil { + patchedEndpoint.Context = db.Text(*payload.Context) + } else { + patchedEndpoint.Context = endpoint.Context + } + if payload.Transport != nil { + patchedEndpoint.Transport = db.Text(*payload.Transport) + } else { + patchedEndpoint.Transport = endpoint.Transport + } + if payload.Codecs != nil { + patchedEndpoint.Allow = db.Text(strings.Join(payload.Codecs, ",")) + } else { + patchedEndpoint.Allow = endpoint.Allow + } + err = queries.UpdateEndpointBySid(r.Context(), patchedEndpoint) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + if payload.MaxContacts != nil { + err = queries.UpdateAORById( + r.Context(), + sqlc.UpdateAORByIdParams{ + ID: endpoint.ID, + MaxContacts: db.Int4(*payload.MaxContacts), + }, + ) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update AOR", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + if payload.Extension != nil { + err = queries.UpdateExtensionByEndpointId( + r.Context(), + sqlc.UpdateExtensionByEndpointIdParams{ + EndpointID: int32(sid), + Extension: db.Text(*payload.Extension), + }, + ) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update extension", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + if payload.Password != nil { + if len(*payload.Password) < 12 { + w.WriteHeader(http.StatusBadRequest) + slog.Info("Invalid password provided", slog.String("path", r.URL.Path)) + return + } + err = queries.UpdateMD5AuthById( + r.Context(), + sqlc.UpdateMD5AuthByIdParams{ + ID: endpoint.ID, + Md5Cred: db.Text(hashPassword(endpoint.ID, *payload.Password, defaultRealm)), + }, + ) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update password", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + err = tx.Commit(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to commit update", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + // TODO: Duplicate code, same as when fetching endpoint. Probably should put this into a service layer. + tx, err = e.Begin(r.Context()) + queries = sqlc.New(tx) + if err != nil { + slog.Error("Failed to create new transaction", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + w.WriteHeader(http.StatusInternalServerError) + return + } + res, err := queries.GetEndpointByID(r.Context(), int32(sid)) + if err != nil { + slog.Error( + "Failed to retrieve created endpoint", + slog.String("path", r.URL.Path), slog.String("reason", err.Error()), slog.Int("sid", int(sid)), + ) + w.WriteHeader(http.StatusInternalServerError) + return + } + + result := getEndpointResponse{ + Sid: int32(sid), + ID: res.ID, + Transport: res.Transport.String, + Context: res.Context.String, + Codecs: strings.Split(res.Allow.String, ","), + MaxContacts: res.MaxContacts.Int32, + Extension: res.Extension.String, + DisplayName: displayNameFromClid(res.Callerid.String), + } + content, err := json.Marshal(result) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to marshall response", slog.String("path", r.URL.Path)) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(content) + if err != nil { + slog.Error("Failed to write response", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/sqlc/queries.sql.go b/internal/sqlc/queries.sql.go index d5ccc69..b0750f5 100644 --- a/internal/sqlc/queries.sql.go +++ b/internal/sqlc/queries.sql.go @@ -255,3 +255,91 @@ func (q *Queries) NewMD5Auth(ctx context.Context, arg NewMD5AuthParams) error { ) return err } + +const updateAORById = `-- name: UpdateAORById :exec +UPDATE + ps_aors +SET + max_contacts = $1 +WHERE + id = $2 +` + +type UpdateAORByIdParams struct { + MaxContacts pgtype.Int4 `json:"max_contacts"` + ID string `json:"id"` +} + +func (q *Queries) UpdateAORById(ctx context.Context, arg UpdateAORByIdParams) error { + _, err := q.db.Exec(ctx, updateAORById, arg.MaxContacts, arg.ID) + return err +} + +const updateEndpointBySid = `-- name: UpdateEndpointBySid :exec +UPDATE + ps_endpoints +SET + callerid = $1, + context = $2, + transport = $3, + allow = $4 +WHERE + sid = $5 +` + +type UpdateEndpointBySidParams struct { + Callerid pgtype.Text `json:"callerid"` + Context pgtype.Text `json:"context"` + Transport pgtype.Text `json:"transport"` + Allow pgtype.Text `json:"allow"` + Sid int32 `json:"sid"` +} + +func (q *Queries) UpdateEndpointBySid(ctx context.Context, arg UpdateEndpointBySidParams) error { + _, err := q.db.Exec(ctx, updateEndpointBySid, + arg.Callerid, + arg.Context, + arg.Transport, + arg.Allow, + arg.Sid, + ) + return err +} + +const updateExtensionByEndpointId = `-- name: UpdateExtensionByEndpointId :exec +UPDATE + ery_extension +SET + extension = $1 +WHERE + endpoint_id = $2 +` + +type UpdateExtensionByEndpointIdParams struct { + Extension pgtype.Text `json:"extension"` + EndpointID int32 `json:"endpoint_id"` +} + +func (q *Queries) UpdateExtensionByEndpointId(ctx context.Context, arg UpdateExtensionByEndpointIdParams) error { + _, err := q.db.Exec(ctx, updateExtensionByEndpointId, arg.Extension, arg.EndpointID) + return err +} + +const updateMD5AuthById = `-- name: UpdateMD5AuthById :exec +UPDATE + ps_auths +SET + md5_cred = $1 +WHERE + id = $2 +` + +type UpdateMD5AuthByIdParams struct { + Md5Cred pgtype.Text `json:"md5_cred"` + ID string `json:"id"` +} + +func (q *Queries) UpdateMD5AuthById(ctx context.Context, arg UpdateMD5AuthByIdParams) error { + _, err := q.db.Exec(ctx, updateMD5AuthById, arg.Md5Cred, arg.ID) + return err +} diff --git a/queries.sql b/queries.sql index 6f48293..b76dbc6 100644 --- a/queries.sql +++ b/queries.sql @@ -68,3 +68,39 @@ WHERE -- name: CountEndpoints :one SELECT COUNT(*) FROM ps_endpoints; + +-- name: UpdateEndpointBySid :exec +UPDATE + ps_endpoints +SET + callerid = $1, + context = $2, + transport = $3, + allow = $4 +WHERE + sid = $5; + +-- name: UpdateExtensionByEndpointId :exec +UPDATE + ery_extension +SET + extension = $1 +WHERE + endpoint_id = $2; + +-- name: UpdateAORById :exec +UPDATE + ps_aors +SET + max_contacts = $1 +WHERE + id = $2; + +-- name: UpdateMD5AuthById :exec +UPDATE + ps_auths +SET + md5_cred = $1 +WHERE + id = $2; +