From e53d35d2b66a1d8ee0874d7074d74a16376503a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacques=20Verr=C3=A9?= Date: Mon, 4 Nov 2024 15:39:42 +0100 Subject: [PATCH 1/3] Introduce Opik.search_spans (#545) * Added Opik.search_spans method --- .../documentation/docs/tracing/export_data.md | 168 ++++++++++++++++++ .../docs/tracing/export_traces.md | 111 ------------ .../documentation/sidebars.ts | 2 +- .../examples/search_traces_and_spans.py | 10 ++ .../src/opik/api_objects/opik_client.py | 42 ++++- sdks/python/tests/e2e/test_tracing.py | 55 ++++++ 6 files changed, 275 insertions(+), 113 deletions(-) create mode 100644 apps/opik-documentation/documentation/docs/tracing/export_data.md delete mode 100644 apps/opik-documentation/documentation/docs/tracing/export_traces.md create mode 100644 sdks/python/examples/search_traces_and_spans.py diff --git a/apps/opik-documentation/documentation/docs/tracing/export_data.md b/apps/opik-documentation/documentation/docs/tracing/export_data.md new file mode 100644 index 000000000..665c4b9a5 --- /dev/null +++ b/apps/opik-documentation/documentation/docs/tracing/export_data.md @@ -0,0 +1,168 @@ +--- +sidebar_label: Export Traces and Spans +toc_max_heading_level: 4 +--- + +# Exporting Traces and Spans + +When working with Opik, it is important to be able to export traces and spans so that you can use them to fine-tune your models or run deeper analysis. + +You can export the traces you have logged to the Opik platform using: + +1. Using the Opik SDK: You can use the [`Opik.search_traces`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_traces) and [`Opik.search_spans`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_spans) methods to export traces and spans. +2. Using the Opik REST API: You can use the [`/traces`](/reference/rest_api/get-traces-by-project.api.mdx) and [`/spans`](/reference/rest_api/get-spans-by-project.api.mdx) endpoints to export traces and spans. +3. Using the UI: Once you have selected the traces or spans you want to export, you can click on the `Export CSV` button in the `Actions` dropdown. + +:::tip +The recommended way to export traces is to use the [`Opik.search_traces`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_traces) and [`Opik.search_spans`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_spans) methods in the Opik SDK. +::: + +## Using the Opik SDK + +### Exporting traces + +The [`Opik.search_traces`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_traces) method allows you to both export all the traces in a project or search for specific traces and export them. + +#### Exporting all traces + +To export all traces, you will need to specify a `max_results` value that is higher than the total number of traces in your project: + +```python +import opik + +client = opik.Opik() + +traces = client.search_traces(project_name="Default project", max_results=1000000) +``` + +#### Search for specific traces + +You can use the `filter_string` parameter to search for specific traces: + +```python +import opik + +client = opik.Opik() + +traces = client.search_traces( + project_name="Default project", + filter_string='input contains "Opik"' +) + +# Convert to Dict if required +traces = [trace.dict() for trace in traces] +``` + +The `filter_string` parameter should follow the format ` ` with: + +1. ``: The column to filter on, these can be: + - `name` + - `input` + - `output` + - `start_time` + - `end_time` + - `metadata` + - `feedback_score` + - `tags` + - `usage.total_tokens` + - `usage.prompt_tokens` + - `usage.completion_tokens`. +2. ``: The operator to use for the filter, this can be `=`, `!=`, `>`, `>=`, `<`, `<=`, `contains`, `not_contains`. Not that not all operators are supported for all columns. +3. ``: The value to filter on. If you are filtering on a string, you will need to wrap it in double quotes. + +Here are some additional examples of valid `filter_string` values: + +```python +import opik + +client = opik.Opik( + project_name="Default project" +) + +# Search for traces where the input contains text +traces = client.search_traces( + filter_string='input contains "Opik"' +) + +# Search for traces that were logged after a specific date +traces = client.search_traces(filter_string='start_time >= "2024-01-01T00:00:00Z"') + +# Search for traces that have a specific tag +traces = client.search_traces(filter_string='tags contains "production"') + +# Search for traces based on the number of tokens used +traces = client.search_traces(filter_string='usage.total_tokens > 1000') + +# Search for traces based on the model used +traces = client.search_traces(filter_string='metadata.model = "gpt-4o"') +``` + +### Exporting spans + +You can export spans using the [`Opik.search_spans`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_spans) method. This methods allows you to search for spans based on `trace_id` or based on a filter string. + +#### Exporting spans based on `trace_id` + +To export all the spans associated with a specific trace, you can use the `trace_id` parameter: + +```python +import opik + +client = opik.Opik() + +spans = client.search_spans( + project_name="Default project", + trace_id="067092dc-e639-73ff-8000-e1c40172450f" +) +``` + +#### Search for specific spans + +You can use the `filter_string` parameter to search for specific spans: + +```python +import opik + +client = opik.Opik() + +spans = client.search_spans( + project_name="Default project", + filter_string='input contains "Opik"' +) +``` + +:::tip +The `filter_string` parameter should follow the same format as the `filter_string` parameter in the `Opik.search_traces` method as [defined above](#search-for-specific-traces). +::: + +## Using the Opik REST API + +To export traces using the Opik REST API, you can use the [`/traces`](/reference/rest_api/get-traces-by-project.api.mdx) endpoint and the [`/spans`](/reference/rest_api/get-spans-by-project.api.mdx) endpoint. These endpoints are paginated so you will need to make multiple requests to retrieve all the traces or spans you want. + +To search for specific traces or spans, you can use the `filter` parameter. While this is a string parameter, it does not follow the same format as the `filter_string` parameter in the Opik SDK. Instead it is a list of json objects with the following format: + +```json +[ + { + "field": "name", + "type": "string", + "operator": "=", + "value": "Opik" + } +] +``` + +:::warning +The `filter` parameter was designed to be used with the Opik UI and has therefore limited flexibility. If you need more flexibility, +please raise an issue on [GitHub](https://github.com/comet-ml/opik/issues) so we can help. +::: + +## Using the UI + +To export traces as a CSV file from the UI, you can simply select the traces or spans you wish to export and click on `Export CSV` in the `Actions` dropdown: + +![Export CSV](/img/tracing/download_traces.png) + +:::tip +The UI only allows you to export up to 100 traces or spans at a time as it is linked to the page size of the traces table. If you need to export more traces or spans, we recommend using the Opik SDK. +::: diff --git a/apps/opik-documentation/documentation/docs/tracing/export_traces.md b/apps/opik-documentation/documentation/docs/tracing/export_traces.md deleted file mode 100644 index 3fb5259b8..000000000 --- a/apps/opik-documentation/documentation/docs/tracing/export_traces.md +++ /dev/null @@ -1,111 +0,0 @@ ---- -sidebar_label: Export Traces ---- - -# Exporting Traces - -You can export the traces you have logged to the Opik platform using: - -1. Using the Opik SDK: You can use the [`Opik.search_traces`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_traces) method to export traces. -2. Using the Opik REST API: You can use the [`/traces`](/reference/rest_api/get-traces-by-project.api.mdx) endpoint to export traces. -3. Using the UI: Once you have selected the traces you want to export, you can click on the `Export CSV` button in the `Actions` dropdown. - -:::tip -The recommended way to export traces is to use the [`Opik.search_traces`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_traces) method in the Opik SDK. -::: - -## Using the Opik SDK - -The [`Opik.search_traces`](https://www.comet.com/docs/opik/python-sdk-reference/Opik.html#opik.Opik.search_traces) method allows you to both export all the traces in a project or search for specific traces and export them. - -### Exporting all traces - -To export all traces, you will need to specify a `max_results` value that is higher than the total number of traces in your project: - -```python -import opik - -client = opik.Opik() - -traces = client.search_traces(project_name="Default project", max_results=1000000) -``` - -### Search for specific traces - -You can use the `filter_string` parameter to search for specific traces: - -```python -import opik - -client = opik.Opik() - -traces = client.search_traces(project_name="Default project", filter_string='input contains "Opik"') - -# Convert to Dict if required -traces = [trace.dict() for trace in traces] -``` - -The `filter_string` parameter should follow the format ` ` with: - -1. ``: The column to filter on, these can be: - - `name` - - `input` - - `output` - - `start_time` - - `end_time` - - `metadata` - - `feedback_score` - - `tags` - - `usage.total_tokens` - - `usage.prompt_tokens` - - `usage.completion_tokens`. -2. ``: The operator to use for the filter, this can be `=`, `!=`, `>`, `>=`, `<`, `<=`, `contains`, `not_contains`. Not that not all operators are supported for all columns. -3. ``: The value to filter on. If you are filtering on a string, you will need to wrap it in double quotes. - -Here are some additional examples of valid `filter_string` values: - -```python -import opik - -client = opik.Opik( - project_name="Default project" -) - -traces = client.search_traces(filter_string='input contains "Opik"') -traces = client.search_traces(filter_string='start_time >= "2024-01-01T00:00:00Z"') -traces = client.search_traces(filter_string='tags contains "production"') -traces = client.search_traces(filter_string='usage.total_tokens > 1000') -traces = client.search_traces(filter_string='metadata.model = "gpt-4o"') -``` - -## Using the Opik REST API - -To export traces using the Opik REST API, you can use the [`/traces`](/reference/rest_api/get-traces-by-project.api.mdx) endpoint. This endpoint is paginated so you will need to make multiple requests to retrieve all the traces you want. - -To search for specific traces, you can use the `filter` parameter. While this is a string parameter, it does not follow the same format as the `filter_string` parameter in the Opik SDK. Instead it is a list of json objects with the following format: - -```json -[ - { - "field": "name", - "type": "string", - "operator": "=", - "value": "Opik" - } -] -``` - -:::warning -The `filter` parameter was designed to be used with the Opik UI and is therefore not very flexible. If you need more flexibility, -please raise an issue on [GitHub](https://github.com/comet-ml/opik/issues) so we can help. -::: - -## Using the UI - -To export traces as a CSV file from the UI, you can simply select the traces you wish to export and click on `Export CSV` in the `Actions` dropdown: - -![Export CSV](/img/tracing/download_traces.png) - -:::tip -The UI only allows you to export up to 100 traces at a time as it is linked to the page size of the traces table. If you need to export more traces, we recommend using the Opik SDK. -::: diff --git a/apps/opik-documentation/documentation/sidebars.ts b/apps/opik-documentation/documentation/sidebars.ts index 30e2a77b4..95715c516 100644 --- a/apps/opik-documentation/documentation/sidebars.ts +++ b/apps/opik-documentation/documentation/sidebars.ts @@ -31,7 +31,7 @@ const sidebars: SidebarsConfig = { "tracing/log_distributed_traces", "tracing/annotate_traces", "tracing/sdk_configuration", - "tracing/export_traces", + "tracing/export_data", { type: "category", label: "Integrations", diff --git a/sdks/python/examples/search_traces_and_spans.py b/sdks/python/examples/search_traces_and_spans.py new file mode 100644 index 000000000..c98896825 --- /dev/null +++ b/sdks/python/examples/search_traces_and_spans.py @@ -0,0 +1,10 @@ +import opik + +opik_client = opik.Opik() + +spans = opik_client.search_spans( + project_name="Demo Project", + filter_string='input contains "How many unique albums"', +) + +print(spans) diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py index 941ab4929..7a65ed355 100644 --- a/sdks/python/src/opik/api_objects/opik_client.py +++ b/sdks/python/src/opik/api_objects/opik_client.py @@ -505,7 +505,7 @@ def search_traces( Search for traces in the given project. Args: - project_name: The name of the project to search traces in. If not provided the project name configured when the Client was created will be used. + project_name: The name of the project to search traces in. If not provided, will search across the project name configured when the Client was created which defaults to the `Default Project`. filter_string: A filter string to narrow down the search. If not provided, all traces in the project will be returned up to the limit. max_results: The maximum number of traces to return. """ @@ -532,6 +532,46 @@ def search_traces( return traces[:max_results] + def search_spans( + self, + project_name: Optional[str] = None, + trace_id: Optional[str] = None, + filter_string: Optional[str] = None, + max_results: int = 1000, + ) -> List[span_public.SpanPublic]: + """ + Search for spans in the given trace. This allows you to search spans based on the span input, output, + metadata, tags, etc or based on the trace ID. + + Args: + project_name: The name of the project to search spans in. If not provided, will search across the project name configured when the Client was created which defaults to the `Default Project`. + trace_id: The ID of the trace to search spans in. If provided, the search will be limited to the spans in the given trace. + filter_string: A filter string to narrow down the search. + max_results: The maximum number of spans to return. + """ + page_size = 200 + spans: List[span_public.SpanPublic] = [] + + filters = opik_query_language.OpikQueryLanguage(filter_string).parsed_filters + + page = 1 + while len(spans) < max_results: + page_spans = self._rest_client.spans.get_spans_by_project( + project_name=project_name or self._project_name, + trace_id=trace_id, + filters=filters, + page=page, + size=page_size, + ) + + if len(page_spans.content) == 0: + break + + spans.extend(page_spans.content) + page += 1 + + return spans[:max_results] + def get_trace_content(self, id: str) -> trace_public.TracePublic: """ Args: diff --git a/sdks/python/tests/e2e/test_tracing.py b/sdks/python/tests/e2e/test_tracing.py index a2e9687cd..d258f98cc 100644 --- a/sdks/python/tests/e2e/test_tracing.py +++ b/sdks/python/tests/e2e/test_tracing.py @@ -3,6 +3,7 @@ import opik from opik import opik_context +from opik.api_objects import helpers from . import verifiers from .conftest import OPIK_E2E_TESTS_PROJECT_NAME @@ -259,3 +260,57 @@ def test_search_traces__happyflow(opik_client): output={"output": "trace-output"}, project_name=OPIK_E2E_TESTS_PROJECT_NAME, ) + + +def test_search_spans__happyflow(opik_client): + # In order to define a unique search query, we will create a unique identifier that will be part of the input of the trace + trace_id = helpers.generate_id() + unique_identifier = str(uuid.uuid4())[-6:] + + filter_string = f'input contains "{unique_identifier}"' + + # Send a trace that matches the input filter + trace = opik_client.trace( + id=trace_id, + name="trace-name", + input={"input": "Some random input"}, + output={"output": "trace-output"}, + project_name=OPIK_E2E_TESTS_PROJECT_NAME, + ) + matching_span = trace.span( + name="span-name", + input={"input": f"Some random input - {unique_identifier}"}, + output={"output": "span-output"}, + ) + trace.span( + name="span-name", + input={"input": "Some random input"}, + output={"output": "span-output"}, + ) + + # Send a trace that does not match the input filter + trace = opik_client.trace( + id=trace_id, + name="trace-name", + input={"input": "Some random input"}, + output={"output": "trace-output"}, + project_name=OPIK_E2E_TESTS_PROJECT_NAME, + ) + trace.span( + name="span-name", + input={"input": "Some random input"}, + output={"output": "span-output"}, + ) + + opik_client.flush() + + # Search for the traces - Note that we use a large max_results to ensure that we get all traces, if the project has more than 100000 matching traces it is possible + spans = opik_client.search_spans( + project_name=OPIK_E2E_TESTS_PROJECT_NAME, + trace_id=trace_id, + filter_string=filter_string, + ) + + # Verify that the matching trace is returned + assert len(spans) == 1, "Expected to find 1 matching span" + assert spans[0].id == matching_span.id, "Expected to find the matching span" From 203299d976c454acd0e1fd349bec4beda3b2f94c Mon Sep 17 00:00:00 2001 From: Thiago dos Santos Hora Date: Mon, 4 Nov 2024 16:36:14 +0100 Subject: [PATCH 2/3] [OPIK-309] Create prompt endpoint (#531) * [OPIK-309] Create prompt endpoint * Add logic to create first version when specified * Address PR review comments --- .../main/java/com/comet/opik/api/Prompt.java | 56 ++++ .../com/comet/opik/api/PromptVersion.java | 58 ++++ .../error/EntityAlreadyExistsException.java | 10 +- .../resources/v1/priv/ProjectsResource.java | 2 +- .../api/resources/v1/priv/PromptResource.java | 68 ++++ .../com/comet/opik/domain/CommitUtils.java | 14 + .../com/comet/opik/domain/DatasetDAO.java | 3 +- .../com/comet/opik/domain/DatasetService.java | 3 +- .../opik/domain/EntityConstraintHandler.java | 60 ++++ .../java/com/comet/opik/domain/PromptDAO.java | 25 ++ .../com/comet/opik/domain/PromptService.java | 118 +++++++ .../comet/opik/domain/PromptVersionDAO.java | 25 ++ ..._increate_prompt_version_commit_length.sql | 6 + .../v1/events/DatasetEventListenerTest.java | 1 - .../v1/priv/DatasetExperimentE2ETest.java | 2 +- .../v1/priv/DatasetsResourceTest.java | 3 +- .../resources/v1/priv/PromptResourceTest.java | 317 ++++++++++++++++++ .../domain/EntityConstraintHandlerTest.java | 88 +++++ 18 files changed, 852 insertions(+), 7 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java create mode 100644 apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java new file mode 100644 index 000000000..a726fb4f4 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -0,0 +1,56 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.Pattern; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.UUID; + +import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record Prompt( + @JsonView( { + Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + @JsonView({Prompt.View.Public.class, + Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + @JsonView({ + Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + + public static class View { + public static class Write { + } + + public static class Public { + } + } + + public record PromptPage( + @JsonView( { + Project.View.Public.class}) int page, + @JsonView({Project.View.Public.class}) int size, + @JsonView({Project.View.Public.class}) long total, + @JsonView({Project.View.Public.class}) List content) + implements + Page{ + + public static Prompt.PromptPage empty(int page) { + return new Prompt.PromptPage(page, 0, 0, List.of()); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java new file mode 100644 index 000000000..cef10c243 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -0,0 +1,58 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersion( + @JsonView( { + PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version unique identifier, generated if absent") UUID id, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID promptId, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit, + @JsonView({PromptVersion.View.Detail.class}) @NotNull String template, + @JsonView({ + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set variables, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy){ + + public static class View { + public static class Public { + } + + public static class Detail { + } + } + + @Builder + public record PromptVersionPage( + @JsonView( { + PromptVersion.View.Public.class}) int page, + @JsonView({PromptVersion.View.Public.class}) int size, + @JsonView({PromptVersion.View.Public.class}) long total, + @JsonView({PromptVersion.View.Public.class}) List content) + implements + Page{ + + public static PromptVersion.PromptVersionPage empty(int page) { + return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); + } + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java b/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java index df74e5e10..b8a16bddd 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java @@ -6,6 +6,14 @@ public class EntityAlreadyExistsException extends ClientErrorException { public EntityAlreadyExistsException(ErrorMessage response) { - super(Response.status(Response.Status.CONFLICT).entity(response).build()); + this((Object) response); + } + + public EntityAlreadyExistsException(io.dropwizard.jersey.errors.ErrorMessage response) { + this((Object) response); + } + + private EntityAlreadyExistsException(Object entity) { + super(Response.status(Response.Status.CONFLICT).entity(entity).build()); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index 186d52022..ac0a72229 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -95,7 +95,7 @@ public Response getById(@PathParam("id") UUID id) { } @POST - @Operation(operationId = "createProject", summary = "Create project", description = "Get project", responses = { + @Operation(operationId = "createProject", summary = "Create project", description = "Create project", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/projects/{projectId}", schema = @Schema(implementation = String.class))}), @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java new file mode 100644 index 000000000..249bcf623 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -0,0 +1,68 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.PromptService; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; +import com.fasterxml.jackson.annotation.JsonView; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +@Path("/v1/private/prompts") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +@Timed +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +public class PromptResource { + + private final @NonNull Provider requestContext; + private final @NonNull PromptService promptService; + + @POST + @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { + @ApiResponse(responseCode = "201", description = "Created", headers = { + @Header(name = "Location", required = true, example = "${basePath}/v1/private/prompts/{promptId}", schema = @Schema(implementation = String.class))}), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + + }) + @RateLimited + public Response createPrompt( + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt, + @Context UriInfo uriInfo) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt with name '{}', on workspace_id '{}'", prompt.name(), workspaceId); + prompt = promptService.create(prompt); + log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", prompt.id(), prompt.name(), + workspaceId); + + var resourceUri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(prompt.id())).build(); + + return Response.created(resourceUri).build(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java new file mode 100644 index 000000000..7253d3ee9 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java @@ -0,0 +1,14 @@ +package com.comet.opik.domain; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; + +import java.util.UUID; + +@UtilityClass +class CommitUtils { + + public String getCommit(@NonNull UUID id) { + return id.toString().substring(id.toString().length() - 8); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java index c687dc72a..d96b85b49 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java @@ -81,6 +81,7 @@ List find(@Bind("limit") int limit, Optional findByName(@Bind("workspace_id") String workspaceId, @Bind("name") String name); @SqlBatch("UPDATE datasets SET last_created_experiment_at = :experimentCreatedAt WHERE id = :datasetId AND workspace_id = :workspace_id") - int[] recordExperiments(@Bind("workspace_id") String workspaceId, @BindMethods Collection datasets); + int[] recordExperiments(@Bind("workspace_id") String workspaceId, + @BindMethods Collection datasets); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java index 4002a9c59..48df7716f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java @@ -278,7 +278,8 @@ private List enrichDatasetWithAdditionalInformation(List datas return datasets.stream() .map(dataset -> { var resume = experimentSummary.computeIfAbsent(dataset.id(), ExperimentSummary::empty); - var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), DatasetItemSummary::empty); + var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), + DatasetItemSummary::empty); return dataset.toBuilder() .experimentCount(resume.experimentCount()) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java new file mode 100644 index 000000000..a0919e854 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java @@ -0,0 +1,60 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.google.common.base.Preconditions; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.function.Supplier; + +interface EntityConstraintHandler { + + Logger log = LoggerFactory.getLogger(EntityConstraintHandler.class); + + static EntityConstraintHandler handle(EntityConstraintAction entityAction) { + return () -> entityAction; + } + + interface EntityConstraintAction { + T execute(); + } + + EntityConstraintAction wrappedAction(); + + default T withError(Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + throw errorProvider.get(); + } else { + throw e; + } + } + } + + default T withRetry(int times, Supplier errorProvider) { + Preconditions.checkArgument(times > 0, "Retry times must be greater than 0"); + + return internalRetry(times, errorProvider); + } + + private T internalRetry(int times, Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + if (times > 0) { + log.warn("Retrying due to constraint violation, remaining attempts: {}", times); + return internalRetry(times - 1, errorProvider); + } + throw errorProvider.get(); + } else { + throw e; + } + } + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java new file mode 100644 index 000000000..6bfaee9d6 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -0,0 +1,25 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.Prompt; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.UUID; + +@RegisterConstructorMapper(Prompt.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +interface PromptDAO { + + @SqlUpdate("INSERT INTO prompts (id, name, description, created_by, last_updated_by, workspace_id) " + + "VALUES (:bean.id, :bean.name, :bean.description, :bean.createdBy, :bean.lastUpdatedBy, :workspaceId)") + void save(@Bind("workspaceId") String workspaceId, @BindMethods("bean") Prompt prompt); + + @SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspaceId") + Prompt findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId); + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java new file mode 100644 index 000000000..9207435bf --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -0,0 +1,118 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.google.inject.ImplementedBy; +import io.dropwizard.jersey.errors.ErrorMessage; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.inject.Singleton; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; + +import java.util.UUID; + +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; + +@ImplementedBy(PromptServiceImpl.class) +public interface PromptService { + Prompt create(Prompt prompt); + +} + +@Singleton +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +class PromptServiceImpl implements PromptService { + + private static final String ALREADY_EXISTS = "Prompt id or name already exists"; + private static final String VERSION_ALREADY_EXISTS = "Prompt version already exists"; + private final @NonNull Provider requestContext; + private final @NonNull IdGenerator idGenerator; + private final @NonNull TransactionTemplate transactionTemplate; + + @Override + public Prompt create(Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + + var newPrompt = prompt.toBuilder() + .id(prompt.id() == null ? idGenerator.generateId() : prompt.id()) + .createdBy(userName) + .lastUpdatedBy(userName) + .build(); + + IdGenerator.validateVersion(prompt.id(), "prompt"); + + var createdPrompt = EntityConstraintHandler + .handle(() -> savePrompt(workspaceId, newPrompt)) + .withError(this::newPromptConflict); + + log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", createdPrompt.id(), + createdPrompt.name(), + workspaceId); + + if (!StringUtils.isEmpty(prompt.template())) { + EntityConstraintHandler + .handle(() -> createPromptVersionFromPromptRequest(prompt, createdPrompt, workspaceId)) + .withRetry(3, this::newVersionConflict); + } + + return createdPrompt; + } + + private PromptVersion createPromptVersionFromPromptRequest(Prompt prompt, Prompt createdPrompt, + String workspaceId) { + log.info("Creating prompt version for prompt id '{}'", createdPrompt.id()); + + var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { + PromptVersionDAO promptVersionDAO = handle.attach(PromptVersionDAO.class); + + UUID versionId = idGenerator.generateId(); + PromptVersion promptVersion = PromptVersion.builder() + .id(versionId) + .promptId(createdPrompt.id()) + .commit(CommitUtils.getCommit(versionId)) + .template(prompt.template()) + .createdBy(createdPrompt.createdBy()) + .build(); + + promptVersionDAO.save(workspaceId, promptVersion); + + return promptVersionDAO.findById(versionId, workspaceId); + }); + + log.info("Created Prompt version for prompt id '{}'", createdPrompt.id()); + + return createdVersion; + } + + private Prompt savePrompt(String workspaceId, Prompt newPrompt) { + return transactionTemplate.inTransaction(WRITE, handle -> { + PromptDAO promptDAO = handle.attach(PromptDAO.class); + + promptDAO.save(workspaceId, newPrompt); + + return promptDAO.findById(newPrompt.id(), workspaceId); + }); + } + + private EntityAlreadyExistsException newConflict(String alreadyExists) { + log.info(alreadyExists); + return new EntityAlreadyExistsException(new ErrorMessage(alreadyExists)); + } + + private EntityAlreadyExistsException newVersionConflict() { + return newConflict(VERSION_ALREADY_EXISTS); + } + + private EntityAlreadyExistsException newPromptConflict() { + return newConflict(ALREADY_EXISTS); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java new file mode 100644 index 000000000..f1f46fb00 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java @@ -0,0 +1,25 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.PromptVersion; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.UUID; + +@RegisterConstructorMapper(PromptVersion.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +interface PromptVersionDAO { + + @SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, created_by, workspace_id) " + + "VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.createdBy, :workspace_id)") + void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt); + + @SqlQuery("SELECT * FROM prompt_versions WHERE id = :id AND workspace_id = :workspace_id") + PromptVersion findById(@Bind("id") UUID id, @Bind("workspace_id") String workspaceId); + +} diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql new file mode 100644 index 000000000..021459eb5 --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql @@ -0,0 +1,6 @@ +--liquibase formatted sql +--changeset thiagohora:increate_prompt_version_commit_length + +ALTER TABLE prompt_versions MODIFY COLUMN commit VARCHAR(8); + +--rollback ALTER TABLE prompt_versions MODIFY COLUMN commit VARCHAR(7); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java index ca55c88d1..73b585a62 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java @@ -51,7 +51,6 @@ class DatasetEventListenerTest { private static final String BASE_RESOURCE_URI = "%s/v1/private/datasets"; private static final String EXPERIMENT_RESOURCE_URI = "%s/v1/private/experiments"; - private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); private static final String WORKSPACE_ID = UUID.randomUUID().toString(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java index 8f122ef48..37e907ff6 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java @@ -47,7 +47,7 @@ import static org.assertj.core.api.Assertions.within; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -@DisplayName("Dataset Event Listener") +@DisplayName("Dataset Experiments E2E Test") class DatasetExperimentE2ETest { private static final String BASE_RESOURCE_URI = "%s/v1/private/datasets"; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index e5c2b9224..b3f0910bb 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -136,7 +136,8 @@ class DatasetsResourceTest { public static final String[] IGNORED_FIELDS_DATA_ITEM = {"createdAt", "lastUpdatedAt", "experimentItems", "createdBy", "lastUpdatedBy"}; public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", - "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "lastCreatedExperimentAt", "datasetItemsCount"}; + "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "lastCreatedExperimentAt", + "datasetItemsCount"}; public static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java new file mode 100644 index 000000000..ec87ae95c --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -0,0 +1,317 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.podam.PodamFactoryUtils; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.redis.testcontainers.RedisContainer; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.testcontainers.clickhouse.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.lifecycle.Startables; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import uk.co.jemos.podam.api.PodamFactory; + +import java.sql.SQLException; +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; + +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.infrastructure.auth.RequestContext.SESSION_COOKIE; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static com.comet.opik.infrastructure.auth.TestHttpClientUtils.UNAUTHORIZED_RESPONSE; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +@Testcontainers(parallel = true) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Prompt Resource Test") +class PromptResourceTest { + + private static final String RESOURCE_PATH = "%s/v1/private/prompts"; + + private static final String API_KEY = UUID.randomUUID().toString(); + private static final String USER = UUID.randomUUID().toString(); + private static final String WORKSPACE_ID = UUID.randomUUID().toString(); + private static final String TEST_WORKSPACE = UUID.randomUUID().toString(); + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final ClickHouseContainer CLICKHOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + + private static final WireMockUtils.WireMockRuntime wireMock; + + static { + Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); + wireMock = WireMockUtils.startWireMock(); + + DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils + .newDatabaseAnalyticsFactory(CLICKHOUSE_CONTAINER, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); + } + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + private String baseURI; + private ClientSupport client; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi) throws SQLException { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE_CONTAINER.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + + ClientSupportUtils.config(client); + + mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + @Nested + @DisplayName("Api Key Authentication:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class ApiKey { + + private final String fakeApikey = UUID.randomUUID().toString(); + private final String okApikey = UUID.randomUUID().toString(); + + Stream credentials() { + return Stream.of( + arguments(okApikey, true), + arguments(fakeApikey, false), + arguments("", false)); + } + + @BeforeEach + void setUp() { + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth")) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo(fakeApikey)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth")) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo("")) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("create prompt: when api key is present, then return proper response") + void createPrompt__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { + + var prompt = factory.manufacturePojo(Prompt.class); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(okApikey, workspaceName, WORKSPACE_ID); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.entity(prompt, MediaType.APPLICATION_JSON_TYPE))) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); + assertThat(actualResponse.hasEntity()).isFalse(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + + } + + @Nested + @DisplayName("Session Token Authentication:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class SessionTokenCookie { + + private final String sessionToken = UUID.randomUUID().toString(); + private final String fakeSessionToken = UUID.randomUUID().toString(); + + Stream credentials() { + return Stream.of( + arguments(sessionToken, true, "OK_" + UUID.randomUUID()), + arguments(fakeSessionToken, false, UUID.randomUUID().toString())); + } + + @BeforeAll + void setUp() { + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth-session")) + .withCookie(SESSION_COOKIE, equalTo(sessionToken)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching("OK_.+"))) + .willReturn(okJson(AuthTestUtils.newWorkspaceAuthResponse(USER, WORKSPACE_ID)))); + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth-session")) + .withCookie(SESSION_COOKIE, equalTo(fakeSessionToken)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("create prompt: when session token is present, then return proper response") + void createPrompt__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, + String workspaceName) { + var prompt = factory.manufacturePojo(Prompt.class); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)).request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.entity(prompt, MediaType.APPLICATION_JSON_TYPE))) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); + assertThat(actualResponse.hasEntity()).isFalse(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + } + + private UUID createPrompt(Prompt prompt, String apiKey, String workspaceName) { + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(prompt))) { + + assertThat(response.getStatus()).isEqualTo(201); + + return TestUtils.getIdFromLocation(response.getLocation()); + } + } + + @Nested + @DisplayName("Create Prompt") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class CreatePrompt { + + @Test + @DisplayName("Should create prompt") + void shouldCreatePrompt() { + + var prompt = factory.manufacturePojo(Prompt.class); + + var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + assertThat(promptId).isNotNull(); + } + + @ParameterizedTest + @MethodSource + @DisplayName("when prompt state is invalid, then return conflict") + void when__promptIsInvalid__thenReturnError(Prompt prompt, int expectedStatusCode, Object expectedBody, + Class expectedResponseClass) { + + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(RequestContext.WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(prompt))) { + + assertThat(response.getStatus()).isEqualTo(expectedStatusCode); + + var actualBody = response.readEntity(expectedResponseClass); + + assertThat(actualBody).isEqualTo(expectedBody); + } + } + + Stream when__promptIsInvalid__thenReturnError() { + Prompt prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .id(UUID.randomUUID()) + .build(); + + Prompt duplicatedPrompt = factory.manufacturePojo(Prompt.class); + createPrompt(duplicatedPrompt, API_KEY, TEST_WORKSPACE); + + return Stream.of( + Arguments.of(prompt, 400, + new ErrorMessage(List.of("prompt id must be a version 7 UUID")), + ErrorMessage.class), + Arguments.of(duplicatedPrompt.toBuilder().name(UUID.randomUUID().toString()).build(), 409, + new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + io.dropwizard.jersey.errors.ErrorMessage.class), + Arguments.of(duplicatedPrompt.toBuilder().id(factory.manufacturePojo(UUID.class)).build(), 409, + new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + io.dropwizard.jersey.errors.ErrorMessage.class), + Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().description("").build(), 422, + new ErrorMessage(List.of("description must not be blank")), + ErrorMessage.class), + Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().name("").build(), 422, + new ErrorMessage(List.of("name must not be blank")), ErrorMessage.class)); + } + } + +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java new file mode 100644 index 000000000..5a19c88d7 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java @@ -0,0 +1,88 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.error.EntityAlreadyExistsException; +import io.dropwizard.jersey.errors.ErrorMessage; +import org.jdbi.v3.core.statement.StatementContext; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class EntityConstraintHandlerTest { + + private static final Supplier ENTITY_ALREADY_EXISTS = () -> new EntityAlreadyExistsException( + new ErrorMessage(409, "Entity already exists")); + + @Test + void testWithError() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + throwDuplicateEntryException(); + return null; + }); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withError(ENTITY_ALREADY_EXISTS)); + } + + private static void throwDuplicateEntryException() { + throw new UnableToExecuteStatementException(new SQLIntegrityConstraintViolationException( + "Duplicate entry '1' for key 'PRIMARY'"), Mockito.mock(StatementContext.class)); + } + + @Test + void testWithRetrySuccess() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> "Success"); + + assertEquals("Success", handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + } + + @Test + void testWithRetryFailure() { + EntityConstraintHandler.EntityConstraintAction action = Mockito + .spy(new EntityConstraintHandler.EntityConstraintAction() { + @Override + public String execute() { + throwDuplicateEntryException(); + return ""; + } + }); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(action); + + final int NUM_OF_RETRIES = 3; + + assertThrows(EntityAlreadyExistsException.class, + () -> handler.withRetry(NUM_OF_RETRIES, ENTITY_ALREADY_EXISTS)); + Mockito.verify(action, Mockito.times(NUM_OF_RETRIES + 1)).execute(); + } + + @Test + void testWithRetryExhausted() { + EntityConstraintHandler.EntityConstraintAction action = Mockito + .spy(new EntityConstraintHandler.EntityConstraintAction() { + @Override + public String execute() { + throwDuplicateEntryException(); + return ""; + } + }); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(action); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withRetry(1, ENTITY_ALREADY_EXISTS)); + Mockito.verify(action, Mockito.times(2)).execute(); + } + + @Test + void testWithRetryNonConstraintViolation() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + throw new UnableToExecuteStatementException(new RuntimeException(), Mockito.mock(StatementContext.class)); + }); + + assertThrows(UnableToExecuteStatementException.class, () -> handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + } +} From 39361a8174a7428cf195942134b6a6cfb7d0b027 Mon Sep 17 00:00:00 2001 From: Thiago dos Santos Hora Date: Tue, 5 Nov 2024 10:39:47 +0100 Subject: [PATCH 3/3] [OPIK-309] Expose Prompt library API contracts (#532) * [OPIK-309] Create prompt endpoint * [OPIK-309] Expose API contracts * [OPIK-309] Expose API contracts * Add logic to create first version when specified --- .../comet/opik/api/CreatePromptVersion.java | 17 ++ .../main/java/com/comet/opik/api/Prompt.java | 37 ++- .../com/comet/opik/api/PromptVersion.java | 2 +- .../comet/opik/api/PromptVersionRetrieve.java | 13 + .../api/resources/v1/priv/PromptResource.java | 260 ++++++++++++++++++ .../comet/opik/domain/CommitGenerator.java | 14 + .../com/comet/opik/domain/PromptService.java | 1 - .../resources/v1/priv/PromptResourceTest.java | 1 + 8 files changed, 331 insertions(+), 14 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java new file mode 100644 index 000000000..04064916c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java @@ -0,0 +1,17 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record CreatePromptVersion(@JsonView( { + PromptVersion.View.Detail.class}) @NotBlank String name, + @JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){ +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java index a726fb4f4..f598ffa08 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -21,16 +21,25 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record Prompt( @JsonView( { - Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, - @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) @NotBlank String name, @JsonView({Prompt.View.Public.class, - Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + Prompt.View.Write.class, + Prompt.View.Detail.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, @JsonView({ Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Long versionCount, + @JsonView({ + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PromptVersion latestVersion){ public static class View { public static class Write { @@ -38,14 +47,18 @@ public static class Write { public static class Public { } - } + public static class Detail { + } + } + + @Builder public record PromptPage( @JsonView( { - Project.View.Public.class}) int page, - @JsonView({Project.View.Public.class}) int size, - @JsonView({Project.View.Public.class}) long total, - @JsonView({Project.View.Public.class}) List content) + Prompt.View.Public.class}) int page, + @JsonView({Prompt.View.Public.class}) int size, + @JsonView({Prompt.View.Public.class}) long total, + @JsonView({Prompt.View.Public.class}) List content) implements Page{ diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java index cef10c243..daecd65ef 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -55,4 +55,4 @@ public static PromptVersion.PromptVersionPage empty(int page) { return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); } } -} \ No newline at end of file +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java new file mode 100644 index 000000000..60ae9a086 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java @@ -0,0 +1,13 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersionRetrieve(@NotBlank String name, String commit) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java index 249bcf623..d1103a77d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -1,8 +1,14 @@ package com.comet.opik.api.resources.v1.priv; import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; +import com.comet.opik.api.PromptVersionRetrieve; import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.IdGenerator; import com.comet.opik.domain.PromptService; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.ratelimit.RateLimited; @@ -13,6 +19,21 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Min; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.PUT; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; import jakarta.inject.Inject; import jakarta.inject.Provider; import jakarta.validation.Valid; @@ -28,16 +49,23 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import java.time.Instant; +import java.util.Set; +import java.util.UUID; +import java.util.stream.IntStream; + @Path("/v1/private/prompts") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) @Timed @Slf4j @RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "Prompts", description = "Prompt resources") public class PromptResource { private final @NonNull Provider requestContext; private final @NonNull PromptService promptService; + private final @NonNull IdGenerator idGenerator; @POST @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { @@ -65,4 +93,236 @@ public Response createPrompt( return Response.created(resourceUri).build(); } + @GET + @Operation(operationId = "getPrompts", summary = "Get prompts", description = "Get prompts", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = Prompt.PromptPage.class))), + }) + @JsonView({Prompt.View.Public.class}) + public Response getPrompts( + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size, + @QueryParam("name") String name) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompts by name '{}' on workspace_id '{}'", name, workspaceId); + var promptPage = Prompt.PromptPage.builder() + .page(page) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePrompt()).toList()) + .build(); + log.info("Got prompts by name '{}', count '{}' on workspace_id '{}'", name, promptPage.size(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptPage).build(); + } + + private Prompt generatePrompt() { + return Prompt.builder() + .id(idGenerator.generateId()) + .name("Prompt 1") + .description("Description 1") + .createdAt(Instant.now()) + .createdBy("User 1") + .lastUpdatedAt(Instant.now()) + .lastUpdatedBy("User 1") + .latestVersion(generatePromptVersion()) + .build(); + } + + @GET + @Path("{id}") + @Operation(operationId = "getPromptById", summary = "Get prompt by id", description = "Get prompt by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt resource", content = @Content(schema = @Schema(implementation = Prompt.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({Prompt.View.Detail.class}) + public Response getPromptById(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + Prompt prompt = generatePrompt(); + + log.info("Got prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(prompt).build(); + } + + @PUT + @Path("{id}") + @Operation(operationId = "updatePrompt", summary = "Update prompt", description = "Update prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content"), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @RateLimited + public Response updatePrompt( + @PathParam("id") UUID id, + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Updating prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Updated prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @DELETE + @Path("{id}") + @Operation(operationId = "deletePrompt", summary = "Delete prompt", description = "Delete prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content") + }) + public Response deletePrompt(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Deleting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Deleted prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @POST + @Path("/versions") + @Operation(operationId = "createPromptVersion", summary = "Create prompt version", description = "Create prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))) + }) + @RateLimited + @JsonView({PromptVersion.View.Detail.class}) + public Response createPromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = CreatePromptVersion.class))) @JsonView({ + PromptVersion.View.Detail.class}) @Valid CreatePromptVersion promptVersion) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(), + workspaceId); + + UUID id = idGenerator.generateId(); + log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'", + promptVersion.version().commit(), id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion(promptVersion, id)) + .build(); + } + + private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) { + return PromptVersion.builder() + .id(id) + .commit(promptVersion.version().commit() == null + ? id.toString().substring(id.toString().length() - 7) + : promptVersion.version().commit()) + .template(promptVersion.version().template()) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + + @GET + @Path("/{id}/versions") + @Operation(operationId = "getPromptVersions", summary = "Get prompt versions", description = "Get prompt versions", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.PromptVersionPage.class))), + }) + @JsonView({PromptVersion.View.Public.class}) + public Response getPromptVersions(@PathParam("id") UUID id, + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + PromptVersion.PromptVersionPage promptVersionPage = PromptVersion.PromptVersionPage.builder() + .page(1) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePromptVersion()).toList()) + .build(); + + log.info("Got prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersionPage).build(); + } + + @GET + @Path("/{id}/versions/{versionId}") + @Operation(operationId = "getPromptVersionById", summary = "Get prompt version by id", description = "Get prompt version by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt version resource", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response getPromptVersionById(@PathParam("id") UUID id, @PathParam("versionId") UUID versionId) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + PromptVersion promptVersion = generatePromptVersion().toBuilder() + .id(versionId) + .commit(versionId.toString().substring(versionId.toString().length() - 7)) + .build(); + + log.info("Got prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersion).build(); + } + + @POST + @Path("/prompts/versions/retrieve") + @Operation(operationId = "retrievePromptVersion", summary = "Retrieve prompt version", description = "Retrieve prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response retrievePromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = PromptVersionRetrieve.class))) @Valid PromptVersionRetrieve retrieve) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Retrieving prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + UUID id = idGenerator.generateId(); + + log.info("Retrieved prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion().toBuilder() + .id(id) + .commit(retrieve.commit() == null + ? id.toString().substring(id.toString().length() - 7) + : retrieve.commit()) + .build()) + .build(); + } + + private PromptVersion generatePromptVersion() { + var id = idGenerator.generateId(); + return PromptVersion.builder() + .id(id) + .commit(id.toString().substring(id.toString().length() - 7)) + .template("Hello %s, My question is ${user_message}".formatted(id)) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java new file mode 100644 index 000000000..d738e967d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java @@ -0,0 +1,14 @@ +package com.comet.opik.domain; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; + +import java.util.UUID; + +@UtilityClass +class CommitGenerator { + + public String generateCommit(@NonNull UUID id) { + return id.toString().substring(id.toString().length() - 8); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index 9207435bf..9cec53df1 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -22,7 +22,6 @@ @ImplementedBy(PromptServiceImpl.class) public interface PromptService { Prompt create(Prompt prompt); - } @Singleton diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index ec87ae95c..db3df102c 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -159,6 +159,7 @@ void setUp() { void createPrompt__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { var prompt = factory.manufacturePojo(Prompt.class); + String workspaceName = UUID.randomUUID().toString(); mockTargetWorkspace(okApikey, workspaceName, WORKSPACE_ID);