From 39361a8174a7428cf195942134b6a6cfb7d0b027 Mon Sep 17 00:00:00 2001 From: Thiago dos Santos Hora Date: Tue, 5 Nov 2024 10:39:47 +0100 Subject: [PATCH] [OPIK-309] Expose Prompt library API contracts (#532) * [OPIK-309] Create prompt endpoint * [OPIK-309] Expose API contracts * [OPIK-309] Expose API contracts * Add logic to create first version when specified --- .../comet/opik/api/CreatePromptVersion.java | 17 ++ .../main/java/com/comet/opik/api/Prompt.java | 37 ++- .../com/comet/opik/api/PromptVersion.java | 2 +- .../comet/opik/api/PromptVersionRetrieve.java | 13 + .../api/resources/v1/priv/PromptResource.java | 260 ++++++++++++++++++ .../comet/opik/domain/CommitGenerator.java | 14 + .../com/comet/opik/domain/PromptService.java | 1 - .../resources/v1/priv/PromptResourceTest.java | 1 + 8 files changed, 331 insertions(+), 14 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java new file mode 100644 index 000000000..04064916c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java @@ -0,0 +1,17 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record CreatePromptVersion(@JsonView( { + PromptVersion.View.Detail.class}) @NotBlank String name, + @JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){ +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java index a726fb4f4..f598ffa08 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -21,16 +21,25 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record Prompt( @JsonView( { - Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, - @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) @NotBlank String name, @JsonView({Prompt.View.Public.class, - Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + Prompt.View.Write.class, + Prompt.View.Detail.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, @JsonView({ Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Long versionCount, + @JsonView({ + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PromptVersion latestVersion){ public static class View { public static class Write { @@ -38,14 +47,18 @@ public static class Write { public static class Public { } - } + public static class Detail { + } + } + + @Builder public record PromptPage( @JsonView( { - Project.View.Public.class}) int page, - @JsonView({Project.View.Public.class}) int size, - @JsonView({Project.View.Public.class}) long total, - @JsonView({Project.View.Public.class}) List content) + Prompt.View.Public.class}) int page, + @JsonView({Prompt.View.Public.class}) int size, + @JsonView({Prompt.View.Public.class}) long total, + @JsonView({Prompt.View.Public.class}) List content) implements Page{ diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java index cef10c243..daecd65ef 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -55,4 +55,4 @@ public static PromptVersion.PromptVersionPage empty(int page) { return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); } } -} \ No newline at end of file +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java new file mode 100644 index 000000000..60ae9a086 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java @@ -0,0 +1,13 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersionRetrieve(@NotBlank String name, String commit) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java index 249bcf623..d1103a77d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -1,8 +1,14 @@ package com.comet.opik.api.resources.v1.priv; import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; +import com.comet.opik.api.PromptVersionRetrieve; import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.IdGenerator; import com.comet.opik.domain.PromptService; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.ratelimit.RateLimited; @@ -13,6 +19,21 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Min; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.PUT; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; import jakarta.inject.Inject; import jakarta.inject.Provider; import jakarta.validation.Valid; @@ -28,16 +49,23 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import java.time.Instant; +import java.util.Set; +import java.util.UUID; +import java.util.stream.IntStream; + @Path("/v1/private/prompts") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) @Timed @Slf4j @RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "Prompts", description = "Prompt resources") public class PromptResource { private final @NonNull Provider requestContext; private final @NonNull PromptService promptService; + private final @NonNull IdGenerator idGenerator; @POST @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { @@ -65,4 +93,236 @@ public Response createPrompt( return Response.created(resourceUri).build(); } + @GET + @Operation(operationId = "getPrompts", summary = "Get prompts", description = "Get prompts", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = Prompt.PromptPage.class))), + }) + @JsonView({Prompt.View.Public.class}) + public Response getPrompts( + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size, + @QueryParam("name") String name) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompts by name '{}' on workspace_id '{}'", name, workspaceId); + var promptPage = Prompt.PromptPage.builder() + .page(page) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePrompt()).toList()) + .build(); + log.info("Got prompts by name '{}', count '{}' on workspace_id '{}'", name, promptPage.size(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptPage).build(); + } + + private Prompt generatePrompt() { + return Prompt.builder() + .id(idGenerator.generateId()) + .name("Prompt 1") + .description("Description 1") + .createdAt(Instant.now()) + .createdBy("User 1") + .lastUpdatedAt(Instant.now()) + .lastUpdatedBy("User 1") + .latestVersion(generatePromptVersion()) + .build(); + } + + @GET + @Path("{id}") + @Operation(operationId = "getPromptById", summary = "Get prompt by id", description = "Get prompt by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt resource", content = @Content(schema = @Schema(implementation = Prompt.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({Prompt.View.Detail.class}) + public Response getPromptById(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + Prompt prompt = generatePrompt(); + + log.info("Got prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(prompt).build(); + } + + @PUT + @Path("{id}") + @Operation(operationId = "updatePrompt", summary = "Update prompt", description = "Update prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content"), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @RateLimited + public Response updatePrompt( + @PathParam("id") UUID id, + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Updating prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Updated prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @DELETE + @Path("{id}") + @Operation(operationId = "deletePrompt", summary = "Delete prompt", description = "Delete prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content") + }) + public Response deletePrompt(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Deleting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Deleted prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @POST + @Path("/versions") + @Operation(operationId = "createPromptVersion", summary = "Create prompt version", description = "Create prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))) + }) + @RateLimited + @JsonView({PromptVersion.View.Detail.class}) + public Response createPromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = CreatePromptVersion.class))) @JsonView({ + PromptVersion.View.Detail.class}) @Valid CreatePromptVersion promptVersion) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(), + workspaceId); + + UUID id = idGenerator.generateId(); + log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'", + promptVersion.version().commit(), id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion(promptVersion, id)) + .build(); + } + + private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) { + return PromptVersion.builder() + .id(id) + .commit(promptVersion.version().commit() == null + ? id.toString().substring(id.toString().length() - 7) + : promptVersion.version().commit()) + .template(promptVersion.version().template()) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + + @GET + @Path("/{id}/versions") + @Operation(operationId = "getPromptVersions", summary = "Get prompt versions", description = "Get prompt versions", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.PromptVersionPage.class))), + }) + @JsonView({PromptVersion.View.Public.class}) + public Response getPromptVersions(@PathParam("id") UUID id, + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + PromptVersion.PromptVersionPage promptVersionPage = PromptVersion.PromptVersionPage.builder() + .page(1) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePromptVersion()).toList()) + .build(); + + log.info("Got prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersionPage).build(); + } + + @GET + @Path("/{id}/versions/{versionId}") + @Operation(operationId = "getPromptVersionById", summary = "Get prompt version by id", description = "Get prompt version by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt version resource", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response getPromptVersionById(@PathParam("id") UUID id, @PathParam("versionId") UUID versionId) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + PromptVersion promptVersion = generatePromptVersion().toBuilder() + .id(versionId) + .commit(versionId.toString().substring(versionId.toString().length() - 7)) + .build(); + + log.info("Got prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersion).build(); + } + + @POST + @Path("/prompts/versions/retrieve") + @Operation(operationId = "retrievePromptVersion", summary = "Retrieve prompt version", description = "Retrieve prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response retrievePromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = PromptVersionRetrieve.class))) @Valid PromptVersionRetrieve retrieve) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Retrieving prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + UUID id = idGenerator.generateId(); + + log.info("Retrieved prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion().toBuilder() + .id(id) + .commit(retrieve.commit() == null + ? id.toString().substring(id.toString().length() - 7) + : retrieve.commit()) + .build()) + .build(); + } + + private PromptVersion generatePromptVersion() { + var id = idGenerator.generateId(); + return PromptVersion.builder() + .id(id) + .commit(id.toString().substring(id.toString().length() - 7)) + .template("Hello %s, My question is ${user_message}".formatted(id)) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java new file mode 100644 index 000000000..d738e967d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java @@ -0,0 +1,14 @@ +package com.comet.opik.domain; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; + +import java.util.UUID; + +@UtilityClass +class CommitGenerator { + + public String generateCommit(@NonNull UUID id) { + return id.toString().substring(id.toString().length() - 8); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index 9207435bf..9cec53df1 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -22,7 +22,6 @@ @ImplementedBy(PromptServiceImpl.class) public interface PromptService { Prompt create(Prompt prompt); - } @Singleton diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index ec87ae95c..db3df102c 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -159,6 +159,7 @@ void setUp() { void createPrompt__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { var prompt = factory.manufacturePojo(Prompt.class); + String workspaceName = UUID.randomUUID().toString(); mockTargetWorkspace(okApikey, workspaceName, WORKSPACE_ID);