diff --git a/docs/swagger.yaml b/docs/swagger.yaml index b6dc4bc..7df916f 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -33,8 +33,6 @@ definitions: type: integer password: type: string - realm: - type: string transport: type: string type: object @@ -83,6 +81,25 @@ definitions: total: type: integer type: object + handler.updateEndpointRequest: + properties: + codecs: + items: + type: string + type: array + context: + type: string + displayName: + type: string + extension: + type: string + maxContacts: + type: integer + password: + type: string + transport: + type: string + type: object host: localhost:8080 info: contact: {} @@ -204,4 +221,25 @@ paths: summary: Get information from a specific endpoint. tags: - endpoints + patch: + parameters: + - description: Sid of the endpoint to be updated + in: path + name: sid + required: true + type: integer + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handler.updateEndpointRequest' + "400": + description: Bad Request + "404": + description: Not Found + "500": + description: Internal Server Error + summary: Update the specified endpoint. Omitted or null fields will remain unchanged. + tags: + - endpoints swagger: "2.0" diff --git a/internal/handler/endpoint.go b/internal/handler/endpoint.go index 3e286f6..44b7089 100644 --- a/internal/handler/endpoint.go +++ b/internal/handler/endpoint.go @@ -17,6 +17,8 @@ import ( "strings" ) +const defaultRealm = "asterisk" + type Endpoint struct { *pgx.Conn } @@ -24,7 +26,6 @@ type Endpoint struct { type createEndpointRequest struct { ID string `json:"id"` Password string `json:"password"` - Realm string `json:"realm,omitempty"` Transport string `json:"transport,omitempty"` Context string `json:"context"` Codecs []string `json:"codecs"` @@ -58,12 +59,23 @@ type getEndpointResponse struct { Extension string `json:"extension"` } +type updateEndpointRequest struct { + Password *string `json:"password,omitempty"` + DisplayName *string `json:"displayName,omitempty"` + Transport *string `json:"transport,omitempty"` + Context *string `json:"context,omitempty"` + Codecs []string `json:"codecs,omitempty"` + MaxContacts *int32 `json:"maxContacts,omitempty"` + Extension *string `json:"extension,omitempty"` +} + func (e *Endpoint) Router() chi.Router { r := chi.NewRouter() r.Post("/", e.create) r.Get("/", e.list) r.Get("/{sid}", e.get) r.Delete("/{sid}", e.delete) + r.Patch("/{sid}", e.update) return r } @@ -89,6 +101,11 @@ func displayNameFromClid(callerID string) string { return callerID[1:end] } +func hashPassword(user, password, realm string) string { + hash := md5.Sum([]byte(user + ":" + realm + ":" + password)) + return hex.EncodeToString(hash[:]) +} + // @Summary Get information from a specific endpoint. // @Param sid path int true "Requested endpoint's sid" // @Produce json @@ -235,7 +252,6 @@ func (e *Endpoint) list(w http.ResponseWriter, r *http.Request) { func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) payload := createEndpointRequest{ - Realm: "asterisk", MaxContacts: 1, } @@ -254,12 +270,11 @@ func (e *Endpoint) create(w http.ResponseWriter, r *http.Request) { queries := sqlc.New(tx) - hash := md5.Sum([]byte(payload.ID + ":" + payload.Realm + ":" + payload.Password)) err = queries.NewMD5Auth(r.Context(), sqlc.NewMD5AuthParams{ ID: payload.ID, Username: db.Text(payload.ID), - Realm: db.Text(payload.Realm), - Md5Cred: db.Text(hex.EncodeToString(hash[:])), + Realm: db.Text(defaultRealm), + Md5Cred: db.Text(hashPassword(payload.ID, payload.Password, defaultRealm)), }) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -397,3 +412,179 @@ func (e *Endpoint) delete(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } + +// @Summary Update the specified endpoint. Omitted or null fields will remain unchanged. +// @Param sid path int true "Sid of the endpoint to be updated" +// @Success 200 {object} updateEndpointRequest +// @Failure 400 +// @Failure 404 +// @Failure 500 +// @Tags endpoints +// @Router /endpoints/{sid} [patch] +func (e *Endpoint) update(w http.ResponseWriter, r *http.Request) { + decoder := json.NewDecoder(r.Body) + var payload updateEndpointRequest + + err := decoder.Decode(&payload) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + urlSid := chi.URLParam(r, "sid") + sid, err := strconv.Atoi(urlSid) + if err != nil || sid <= 0 { + w.WriteHeader(http.StatusBadRequest) + return + } + + tx, err := e.Begin(r.Context()) + if err != nil { + slog.Error("Failed to start transaction", slog.String("reason", err.Error()), slog.String("path", r.URL.Path)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + queries := sqlc.New(tx) + endpoint, err := queries.GetEndpointByID(r.Context(), int32(sid)) + if errors.Is(err, pgx.ErrNoRows) { + w.WriteHeader(http.StatusNotFound) + return + } else if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to retrieve endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + // Sorry for the incoming boilerplate but no dynamic SQL yet + var patchedEndpoint = sqlc.UpdateEndpointBySidParams{Sid: int32(sid)} + if payload.DisplayName != nil { + if *payload.DisplayName == "" { + patchedEndpoint.Callerid = db.Text("") + } else { + patchedEndpoint.Callerid = db.Text(fmt.Sprintf(`"%s" <%s>`, *payload.DisplayName, endpoint.ID)) + } + } else { + patchedEndpoint.Callerid = endpoint.Callerid + } + if payload.Context != nil { + patchedEndpoint.Context = db.Text(*payload.Context) + } else { + patchedEndpoint.Context = endpoint.Context + } + if payload.Transport != nil { + patchedEndpoint.Transport = db.Text(*payload.Transport) + } else { + patchedEndpoint.Transport = endpoint.Transport + } + if payload.Codecs != nil { + patchedEndpoint.Allow = db.Text(strings.Join(payload.Codecs, ",")) + } else { + patchedEndpoint.Allow = endpoint.Allow + } + err = queries.UpdateEndpointBySid(r.Context(), patchedEndpoint) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update endpoint", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + if payload.MaxContacts != nil { + err = queries.UpdateAORById( + r.Context(), + sqlc.UpdateAORByIdParams{ + ID: endpoint.ID, + MaxContacts: db.Int4(*payload.MaxContacts), + }, + ) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update AOR", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + if payload.Extension != nil { + err = queries.UpdateExtensionByEndpointId( + r.Context(), + sqlc.UpdateExtensionByEndpointIdParams{ + EndpointID: int32(sid), + Extension: db.Text(*payload.Extension), + }, + ) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update extension", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + if payload.Password != nil { + if len(*payload.Password) < 12 { + w.WriteHeader(http.StatusBadRequest) + slog.Info("Invalid password provided", slog.String("path", r.URL.Path)) + return + } + err = queries.UpdateMD5AuthById( + r.Context(), + sqlc.UpdateMD5AuthByIdParams{ + ID: endpoint.ID, + Md5Cred: db.Text(hashPassword(endpoint.ID, *payload.Password, defaultRealm)), + }, + ) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to update password", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + err = tx.Commit(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to commit update", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + return + } + + // TODO: Duplicate code, same as when fetching endpoint. Probably should put this into a service layer. + tx, err = e.Begin(r.Context()) + queries = sqlc.New(tx) + if err != nil { + slog.Error("Failed to create new transaction", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + w.WriteHeader(http.StatusInternalServerError) + return + } + res, err := queries.GetEndpointByID(r.Context(), int32(sid)) + if err != nil { + slog.Error( + "Failed to retrieve created endpoint", + slog.String("path", r.URL.Path), slog.String("reason", err.Error()), slog.Int("sid", int(sid)), + ) + w.WriteHeader(http.StatusInternalServerError) + return + } + + result := getEndpointResponse{ + Sid: int32(sid), + ID: res.ID, + Transport: res.Transport.String, + Context: res.Context.String, + Codecs: strings.Split(res.Allow.String, ","), + MaxContacts: res.MaxContacts.Int32, + Extension: res.Extension.String, + DisplayName: displayNameFromClid(res.Callerid.String), + } + content, err := json.Marshal(result) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Failed to marshall response", slog.String("path", r.URL.Path)) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(content) + if err != nil { + slog.Error("Failed to write response", slog.String("path", r.URL.Path), slog.String("reason", err.Error())) + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/sqlc/queries.sql.go b/internal/sqlc/queries.sql.go index d5ccc69..b0750f5 100644 --- a/internal/sqlc/queries.sql.go +++ b/internal/sqlc/queries.sql.go @@ -255,3 +255,91 @@ func (q *Queries) NewMD5Auth(ctx context.Context, arg NewMD5AuthParams) error { ) return err } + +const updateAORById = `-- name: UpdateAORById :exec +UPDATE + ps_aors +SET + max_contacts = $1 +WHERE + id = $2 +` + +type UpdateAORByIdParams struct { + MaxContacts pgtype.Int4 `json:"max_contacts"` + ID string `json:"id"` +} + +func (q *Queries) UpdateAORById(ctx context.Context, arg UpdateAORByIdParams) error { + _, err := q.db.Exec(ctx, updateAORById, arg.MaxContacts, arg.ID) + return err +} + +const updateEndpointBySid = `-- name: UpdateEndpointBySid :exec +UPDATE + ps_endpoints +SET + callerid = $1, + context = $2, + transport = $3, + allow = $4 +WHERE + sid = $5 +` + +type UpdateEndpointBySidParams struct { + Callerid pgtype.Text `json:"callerid"` + Context pgtype.Text `json:"context"` + Transport pgtype.Text `json:"transport"` + Allow pgtype.Text `json:"allow"` + Sid int32 `json:"sid"` +} + +func (q *Queries) UpdateEndpointBySid(ctx context.Context, arg UpdateEndpointBySidParams) error { + _, err := q.db.Exec(ctx, updateEndpointBySid, + arg.Callerid, + arg.Context, + arg.Transport, + arg.Allow, + arg.Sid, + ) + return err +} + +const updateExtensionByEndpointId = `-- name: UpdateExtensionByEndpointId :exec +UPDATE + ery_extension +SET + extension = $1 +WHERE + endpoint_id = $2 +` + +type UpdateExtensionByEndpointIdParams struct { + Extension pgtype.Text `json:"extension"` + EndpointID int32 `json:"endpoint_id"` +} + +func (q *Queries) UpdateExtensionByEndpointId(ctx context.Context, arg UpdateExtensionByEndpointIdParams) error { + _, err := q.db.Exec(ctx, updateExtensionByEndpointId, arg.Extension, arg.EndpointID) + return err +} + +const updateMD5AuthById = `-- name: UpdateMD5AuthById :exec +UPDATE + ps_auths +SET + md5_cred = $1 +WHERE + id = $2 +` + +type UpdateMD5AuthByIdParams struct { + Md5Cred pgtype.Text `json:"md5_cred"` + ID string `json:"id"` +} + +func (q *Queries) UpdateMD5AuthById(ctx context.Context, arg UpdateMD5AuthByIdParams) error { + _, err := q.db.Exec(ctx, updateMD5AuthById, arg.Md5Cred, arg.ID) + return err +} diff --git a/queries.sql b/queries.sql index 6f48293..b76dbc6 100644 --- a/queries.sql +++ b/queries.sql @@ -68,3 +68,39 @@ WHERE -- name: CountEndpoints :one SELECT COUNT(*) FROM ps_endpoints; + +-- name: UpdateEndpointBySid :exec +UPDATE + ps_endpoints +SET + callerid = $1, + context = $2, + transport = $3, + allow = $4 +WHERE + sid = $5; + +-- name: UpdateExtensionByEndpointId :exec +UPDATE + ery_extension +SET + extension = $1 +WHERE + endpoint_id = $2; + +-- name: UpdateAORById :exec +UPDATE + ps_aors +SET + max_contacts = $1 +WHERE + id = $2; + +-- name: UpdateMD5AuthById :exec +UPDATE + ps_auths +SET + md5_cred = $1 +WHERE + id = $2; +