Skip to content

Commit

Permalink
[OPIK-309] Expose Prompt library API contracts (#532)
Browse files Browse the repository at this point in the history
* [OPIK-309] Create prompt endpoint

* [OPIK-309] Expose API contracts

* [OPIK-309] Expose API contracts

* Add logic to create first version when specified
  • Loading branch information
thiagohora authored Nov 5, 2024
1 parent 203299d commit 39361a8
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Builder;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record CreatePromptVersion(@JsonView( {
PromptVersion.View.Detail.class}) @NotBlank String name,
@JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){
}
37 changes: 25 additions & 12 deletions apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,44 @@
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record Prompt(
@JsonView( {
Prompt.View.Public.class, Prompt.View.Write.class}) UUID id,
@JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name,
Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) UUID id,
@JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) @NotBlank String name,
@JsonView({Prompt.View.Public.class,
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description,
Prompt.View.Write.class,
Prompt.View.Detail.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description,
@JsonView({
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Long versionCount,
@JsonView({
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PromptVersion latestVersion){

public static class View {
public static class Write {
}

public static class Public {
}
}

public static class Detail {
}
}

@Builder
public record PromptPage(
@JsonView( {
Project.View.Public.class}) int page,
@JsonView({Project.View.Public.class}) int size,
@JsonView({Project.View.Public.class}) long total,
@JsonView({Project.View.Public.class}) List<Prompt> content)
Prompt.View.Public.class}) int page,
@JsonView({Prompt.View.Public.class}) int size,
@JsonView({Prompt.View.Public.class}) long total,
@JsonView({Prompt.View.Public.class}) List<Prompt> content)
implements
Page<Prompt>{

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ public static PromptVersion.PromptVersionPage empty(int page) {
return new PromptVersion.PromptVersionPage(page, 0, 0, List.of());
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record PromptVersionRetrieve(@NotBlank String name, String commit) {
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package com.comet.opik.api.resources.v1.priv;

import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.CreatePromptVersion;
import com.comet.opik.api.Prompt;
import com.comet.opik.api.PromptVersion;
import com.comet.opik.api.PromptVersionRetrieve;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.Prompt;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.domain.IdGenerator;
import com.comet.opik.domain.PromptService;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.ratelimit.RateLimited;
Expand All @@ -13,6 +19,21 @@
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Min;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.DefaultValue;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
Expand All @@ -28,16 +49,23 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.time.Instant;
import java.util.Set;
import java.util.UUID;
import java.util.stream.IntStream;

@Path("/v1/private/prompts")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Timed
@Slf4j
@RequiredArgsConstructor(onConstructor_ = @Inject)
@Tag(name = "Prompts", description = "Prompt resources")
public class PromptResource {

private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull PromptService promptService;
private final @NonNull IdGenerator idGenerator;

@POST
@Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = {
Expand Down Expand Up @@ -65,4 +93,236 @@ public Response createPrompt(
return Response.created(resourceUri).build();
}

@GET
@Operation(operationId = "getPrompts", summary = "Get prompts", description = "Get prompts", responses = {
@ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = Prompt.PromptPage.class))),
})
@JsonView({Prompt.View.Public.class})
public Response getPrompts(
@QueryParam("page") @Min(1) @DefaultValue("1") int page,
@QueryParam("size") @Min(1) @DefaultValue("10") int size,
@QueryParam("name") String name) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Getting prompts by name '{}' on workspace_id '{}'", name, workspaceId);
var promptPage = Prompt.PromptPage.builder()
.page(page)
.size(5)
.total(5)
.content(IntStream.range(0, 5).mapToObj(i -> generatePrompt()).toList())
.build();
log.info("Got prompts by name '{}', count '{}' on workspace_id '{}'", name, promptPage.size(), workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptPage).build();
}

private Prompt generatePrompt() {
return Prompt.builder()
.id(idGenerator.generateId())
.name("Prompt 1")
.description("Description 1")
.createdAt(Instant.now())
.createdBy("User 1")
.lastUpdatedAt(Instant.now())
.lastUpdatedBy("User 1")
.latestVersion(generatePromptVersion())
.build();
}

@GET
@Path("{id}")
@Operation(operationId = "getPromptById", summary = "Get prompt by id", description = "Get prompt by id", responses = {
@ApiResponse(responseCode = "200", description = "Prompt resource", content = @Content(schema = @Schema(implementation = Prompt.class))),
@ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))),
})
@JsonView({Prompt.View.Detail.class})
public Response getPromptById(@PathParam("id") UUID id) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Getting prompt by id '{}' on workspace_id '{}'", id, workspaceId);

Prompt prompt = generatePrompt();

log.info("Got prompt by id '{}' on workspace_id '{}'", id, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).entity(prompt).build();
}

@PUT
@Path("{id}")
@Operation(operationId = "updatePrompt", summary = "Update prompt", description = "Update prompt", responses = {
@ApiResponse(responseCode = "204", description = "No content"),
@ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))),
@ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))),
})
@RateLimited
public Response updatePrompt(
@PathParam("id") UUID id,
@RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Updating prompt with id '{}' on workspace_id '{}'", id, workspaceId);

log.info("Updated prompt with id '{}' on workspace_id '{}'", id, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).build();
}

@DELETE
@Path("{id}")
@Operation(operationId = "deletePrompt", summary = "Delete prompt", description = "Delete prompt", responses = {
@ApiResponse(responseCode = "204", description = "No content")
})
public Response deletePrompt(@PathParam("id") UUID id) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Deleting prompt by id '{}' on workspace_id '{}'", id, workspaceId);

log.info("Deleted prompt by id '{}' on workspace_id '{}'", id, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).build();
}

@POST
@Path("/versions")
@Operation(operationId = "createPromptVersion", summary = "Create prompt version", description = "Create prompt version", responses = {
@ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))),
@ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class)))
})
@RateLimited
@JsonView({PromptVersion.View.Detail.class})
public Response createPromptVersion(
@RequestBody(content = @Content(schema = @Schema(implementation = CreatePromptVersion.class))) @JsonView({
PromptVersion.View.Detail.class}) @Valid CreatePromptVersion promptVersion) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(),
workspaceId);

UUID id = idGenerator.generateId();
log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'",
promptVersion.version().commit(), id, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED)
.entity(generatePromptVersion(promptVersion, id))
.build();
}

private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) {
return PromptVersion.builder()
.id(id)
.commit(promptVersion.version().commit() == null
? id.toString().substring(id.toString().length() - 7)
: promptVersion.version().commit())
.template(promptVersion.version().template())
.variables(
Set.of("user_message"))
.createdAt(Instant.now())
.createdBy("User 1")
.build();
}

@GET
@Path("/{id}/versions")
@Operation(operationId = "getPromptVersions", summary = "Get prompt versions", description = "Get prompt versions", responses = {
@ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.PromptVersionPage.class))),
})
@JsonView({PromptVersion.View.Public.class})
public Response getPromptVersions(@PathParam("id") UUID id,
@QueryParam("page") @Min(1) @DefaultValue("1") int page,
@QueryParam("size") @Min(1) @DefaultValue("10") int size) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Getting prompt versions by id '{}' on workspace_id '{}'", id, workspaceId);

PromptVersion.PromptVersionPage promptVersionPage = PromptVersion.PromptVersionPage.builder()
.page(1)
.size(5)
.total(5)
.content(IntStream.range(0, 5).mapToObj(i -> generatePromptVersion()).toList())
.build();

log.info("Got prompt versions by id '{}' on workspace_id '{}'", id, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersionPage).build();
}

@GET
@Path("/{id}/versions/{versionId}")
@Operation(operationId = "getPromptVersionById", summary = "Get prompt version by id", description = "Get prompt version by id", responses = {
@ApiResponse(responseCode = "200", description = "Prompt version resource", content = @Content(schema = @Schema(implementation = PromptVersion.class))),
@ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))),
})
@JsonView({PromptVersion.View.Detail.class})
public Response getPromptVersionById(@PathParam("id") UUID id, @PathParam("versionId") UUID versionId) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Getting prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId);

PromptVersion promptVersion = generatePromptVersion().toBuilder()
.id(versionId)
.commit(versionId.toString().substring(versionId.toString().length() - 7))
.build();

log.info("Got prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersion).build();
}

@POST
@Path("/prompts/versions/retrieve")
@Operation(operationId = "retrievePromptVersion", summary = "Retrieve prompt version", description = "Retrieve prompt version", responses = {
@ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))),
@ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))),
})
@JsonView({PromptVersion.View.Detail.class})
public Response retrievePromptVersion(
@RequestBody(content = @Content(schema = @Schema(implementation = PromptVersionRetrieve.class))) @Valid PromptVersionRetrieve retrieve) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Retrieving prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(),
retrieve.commit(), workspaceId);

UUID id = idGenerator.generateId();

log.info("Retrieved prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(),
retrieve.commit(), workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED)
.entity(generatePromptVersion().toBuilder()
.id(id)
.commit(retrieve.commit() == null
? id.toString().substring(id.toString().length() - 7)
: retrieve.commit())
.build())
.build();
}

private PromptVersion generatePromptVersion() {
var id = idGenerator.generateId();
return PromptVersion.builder()
.id(id)
.commit(id.toString().substring(id.toString().length() - 7))
.template("Hello %s, My question is ${user_message}".formatted(id))
.variables(
Set.of("user_message"))
.createdAt(Instant.now())
.createdBy("User 1")
.build();
}

}
Loading

0 comments on commit 39361a8

Please sign in to comment.