From 6e5bb09a11f4da4fdb7f19ad23724c69560c1b5d Mon Sep 17 00:00:00 2001 From: Thiago dos Santos Hora Date: Wed, 16 Oct 2024 11:22:39 +0200 Subject: [PATCH] [OPIK-218] Remove limitations in dataset items (#369) * [OPIK-218] Remove limitations in dataset items * Add tests * Address PR review * Fixing contract * Fix fixture generation * Deprecate metadata * Add PR review comments * Remove dead code --- .../java/com/comet/opik/api/DatasetItem.java | 21 +- .../resources/v1/priv/DatasetsResource.java | 3 +- .../validate/DatasetItemInputValidation.java | 22 ++ .../validate/DatasetItemInputValidator.java | 43 +++ .../com/comet/opik/domain/DatasetItemDAO.java | 169 +++++++++--- .../comet/opik/domain/DatasetItemService.java | 2 +- .../db/DatabaseAnalyticsModule.java | 1 - .../instrumentation/InstrumentAsyncUtils.java | 13 +- .../000004_add_input_data_to_dataset_item.sql | 7 + .../api/resources/utils/AuthTestUtils.java | 2 +- .../v1/priv/DatasetsResourceTest.java | 244 ++++++++++++++---- .../resources/v1/priv/TracesResourceTest.java | 3 +- ...goricalFeedbackDetailTypeManufacturer.java | 4 +- .../DatasetItemTypeManufacturer.java | 18 ++ ...mericalFeedbackDetailTypeManufacturer.java | 4 +- 15 files changed, 446 insertions(+), 110 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidation.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidator.java create mode 100644 apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000004_add_input_data_to_dataset_item.sql diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItem.java b/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItem.java index 67303953a..e42fa8a74 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItem.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItem.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.api.validate.DatasetItemInputValidation; import com.comet.opik.api.validate.SourceValidation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonView; @@ -12,21 +13,29 @@ import java.time.Instant; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.UUID; @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) @SourceValidation +@DatasetItemInputValidation public record DatasetItem( @JsonView( { DatasetItem.View.Public.class, DatasetItem.View.Write.class}) UUID id, - @JsonView({DatasetItem.View.Public.class, DatasetItem.View.Write.class}) @NotNull JsonNode input, - @JsonView({DatasetItem.View.Public.class, DatasetItem.View.Write.class}) JsonNode expectedOutput, - @JsonView({DatasetItem.View.Public.class, DatasetItem.View.Write.class}) JsonNode metadata, + @JsonView({DatasetItem.View.Public.class, + DatasetItem.View.Write.class}) @Schema(deprecated = true, description = "to be deprecated soon, please use data field") JsonNode input, + @JsonView({DatasetItem.View.Public.class, + DatasetItem.View.Write.class}) @Schema(deprecated = true, description = "to be deprecated soon, please use data field") JsonNode expectedOutput, + @JsonView({DatasetItem.View.Public.class, + DatasetItem.View.Write.class}) @Schema(deprecated = true, description = "to be deprecated soon, please use data field") JsonNode metadata, @JsonView({DatasetItem.View.Public.class, 
DatasetItem.View.Write.class}) UUID traceId, @JsonView({DatasetItem.View.Public.class, DatasetItem.View.Write.class}) UUID spanId, @JsonView({DatasetItem.View.Public.class, DatasetItem.View.Write.class}) @NotNull DatasetItemSource source, + @JsonView({DatasetItem.View.Public.class, + DatasetItem.View.Write.class}) Map data, @JsonView({ DatasetItem.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List experimentItems, @JsonView({DatasetItem.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, @@ -42,7 +51,11 @@ public record DatasetItemPage( DatasetItem.View.Public.class}) List content, @JsonView({DatasetItem.View.Public.class}) int page, @JsonView({DatasetItem.View.Public.class}) int size, - @JsonView({DatasetItem.View.Public.class}) long total) implements Page{ + @JsonView({DatasetItem.View.Public.class}) long total, + @JsonView({DatasetItem.View.Public.class}) Set columns) implements Page{ + + public record Column(String name, String type) { + } } public static class View { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index 377a7d26b..e1164545b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -301,7 +301,8 @@ public Response createDatasetItems( return item.toBuilder().id(idGenerator.generateId()).build(); } return item; - }).toList(); + }) + .toList(); String workspaceId = requestContext.get().getWorkspaceId(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidation.java b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidation.java new file mode 100644 index 000000000..f8cded414 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidation.java @@ -0,0 +1,22 @@ +package com.comet.opik.api.validate; + +import jakarta.validation.Constraint; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.FIELD, ElementType.ANNOTATION_TYPE, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@Constraint(validatedBy = {DatasetItemInputValidator.class}) +@Documented +public @interface DatasetItemInputValidation { + + String message() default "must provide either input or data field"; + + Class[] groups() default {}; + + Class[] payload() default {}; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidator.java b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidator.java new file mode 100644 index 000000000..637c94c61 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/DatasetItemInputValidator.java @@ -0,0 +1,43 @@ +package com.comet.opik.api.validate; + +import com.comet.opik.api.DatasetItem; +import com.fasterxml.jackson.databind.JsonNode; +import jakarta.validation.ConstraintValidator; +import jakarta.validation.ConstraintValidatorContext; +import org.apache.commons.collections4.MapUtils; + +import java.util.Map; +import java.util.Optional; + +public class DatasetItemInputValidator implements ConstraintValidator { + + @Override + public 
boolean isValid(DatasetItem datasetItem, ConstraintValidatorContext context) {
+        boolean result = datasetItem.input() != null || MapUtils.isNotEmpty(datasetItem.data());
+
+        if (!result) {
+            context.disableDefaultConstraintViolation();
+            context.buildConstraintViolationWithTemplate("must provide either input or data field")
+                    .addPropertyNode("input")
+                    .addConstraintViolation();
+        }
+
+        if (result && datasetItem.data() != null) {
+            Optional<Map.Entry<String, JsonNode>> error = datasetItem.data().entrySet()
+                    .stream()
+                    .filter(entry -> entry.getValue() == null)
+                    .findAny();
+
+            if (error.isPresent()) {
+                context.disableDefaultConstraintViolation();
+                context.buildConstraintViolationWithTemplate("field must not contain null key or value")
+                        .addPropertyNode("data")
+                        .addConstraintViolation();
+            }
+
+            result = error.isEmpty();
+        }
+
+        return result;
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java
index 20317ad3c..a6a11f849 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java
@@ -11,10 +11,12 @@
 import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
 import com.comet.opik.utils.JsonUtils;
 import com.fasterxml.jackson.databind.JsonNode;
+import com.google.common.collect.Sets;
 import com.google.inject.ImplementedBy;
 import io.opentelemetry.instrumentation.annotations.WithSpan;
 import io.r2dbc.spi.Connection;
 import io.r2dbc.spi.Result;
+import io.r2dbc.spi.Row;
 import io.r2dbc.spi.Statement;
 import jakarta.inject.Inject;
 import jakarta.inject.Singleton;
@@ -32,12 +34,17 @@
 import java.math.BigDecimal;
 import java.time.Instant;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.UUID;
+import java.util.stream.Collectors;

 import static com.comet.opik.api.DatasetItem.DatasetItemPage;
+import static com.comet.opik.api.DatasetItem.DatasetItemPage.Column;
 import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
 import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.Segment;
 import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment;
@@ -47,6 +54,8 @@
 import static com.comet.opik.utils.TemplateUtils.QueryItem;
 import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder;
 import static com.comet.opik.utils.ValidationUtils.CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE;
+import static java.util.function.Predicate.not;
+import static java.util.stream.Collectors.toMap;

 @ImplementedBy(DatasetItemDAOImpl.class)
 public interface DatasetItemDAO {
@@ -78,6 +87,7 @@ INSERT INTO dataset_items (
             trace_id,
             span_id,
             input,
+            data,
             expected_output,
             metadata,
             created_at,
@@ -94,6 +104,7 @@ INSERT INTO dataset_items (
             :traceId,
             :spanId,
             :input,
+            :data,
             :expectedOutput,
             :metadata,
             now64(9),
@@ -156,10 +167,12 @@ INSERT INTO dataset_items (
     private static final String SELECT_DATASET_ITEMS_COUNT = """
                 SELECT
-                    count(id) AS count
+                    count(id) AS count,
+                    arrayDistinct(arrayFlatten(groupArray(arrayMap(key -> (key, JSONType(data[key])), mapKeys(data))))) AS columns
                 FROM (
                     SELECT
-                        id
+                        id,
+                        data
                     FROM dataset_items
                     WHERE dataset_id = :datasetId
                     AND workspace_id = :workspace_id
@@ -174,10 +187,12 @@
             */
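Note: the count query above now does double duty. Besides the row count, it derives the page's column descriptors: one (name, type) pair per key found in any item's data map, de-duplicated across the dataset. Below is a rough plain-Java rendering of what the arrayDistinct(arrayFlatten(groupArray(arrayMap(...)))) expression computes; deriveColumns is a hypothetical helper, not part of this patch, and the type-name mapping mimics the addDeprecatedFields test helper further down.

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

class ColumnsSketch {

    record Column(String name, String type) {
    }

    // One (key, type) pair per key of every item's data map, de-duplicated
    // across the whole page - groupArray + arrayFlatten + arrayDistinct in SQL terms.
    static Set<Column> deriveColumns(List<Map<String, JsonNode>> dataMaps) {
        return dataMaps.stream()
                .flatMap(data -> data.entrySet().stream())
                .map(entry -> new Column(
                        entry.getKey(),
                        // e.g. JsonNodeType.STRING -> "String", OBJECT -> "Object"
                        StringUtils.capitalize(entry.getValue().getNodeType().name().toLowerCase())))
                .collect(Collectors.toSet());
    }
}

    private static final String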
SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_COUNT = """ SELECT - COUNT(DISTINCT di.id) AS count + COUNT(DISTINCT di.id) AS count, + arrayDistinct(arrayFlatten(groupArray(arrayMap(key -> (key, JSONType(di.data[key])), mapKeys(di.data))))) AS columns FROM ( SELECT - id + id, + data FROM dataset_items WHERE dataset_id = :datasetId AND workspace_id = :workspace_id @@ -273,6 +288,7 @@ AND id in ( di.id AS id, di.dataset_id AS dataset_id, di.input AS input, + di.data AS data, di.expected_output AS expected_output, di.metadata AS metadata, di.trace_id AS trace_id, @@ -391,6 +407,7 @@ LEFT JOIN ( di.id, di.dataset_id, di.input, + di.data, di.expected_output, di.metadata, di.trace_id, @@ -438,12 +455,27 @@ private Mono mapAndInsert( int i = 0; for (DatasetItem item : items) { + Map data = new HashMap<>(Optional.ofNullable(item.data()).orElse(Map.of())); + + if (!data.containsKey("input") && item.input() != null) { + data.put("input", item.input()); + } + + if (!data.containsKey("expected_output") && item.expectedOutput() != null) { + data.put("expected_output", item.expectedOutput()); + } + + if (!data.containsKey("metadata") && item.metadata() != null) { + data.put("metadata", item.metadata()); + } + statement.bind("id" + i, item.id()); statement.bind("datasetId" + i, datasetId); statement.bind("source" + i, item.source().getValue()); statement.bind("traceId" + i, getOrDefault(item.traceId())); statement.bind("spanId" + i, getOrDefault(item.spanId())); statement.bind("input" + i, getOrDefault(item.input())); + statement.bind("data" + i, getOrDefault(data)); statement.bind("expectedOutput" + i, getOrDefault(item.expectedOutput())); statement.bind("metadata" + i, getOrDefault(item.metadata())); statement.bind("createdBy" + i, userName); @@ -464,37 +496,78 @@ private String getOrDefault(JsonNode jsonNode) { return Optional.ofNullable(jsonNode).map(JsonNode::toString).orElse(""); } + private Map getOrDefault(Map data) { + return Optional.ofNullable(data) + .filter(not(Map::isEmpty)) + .stream() + .map(Map::entrySet) + .flatMap(Collection::stream) + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toString())) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + private String getOrDefault(UUID value) { return Optional.ofNullable(value).map(UUID::toString).orElse(""); } private Publisher mapItem(Result results) { - return results.map((row, rowMetadata) -> DatasetItem.builder() - .id(row.get("id", UUID.class)) - .input(Optional.ofNullable(row.get("input", String.class)) - .filter(s -> !s.isBlank()) - .map(JsonUtils::getJsonNodeFromString).orElse(null)) - .expectedOutput(Optional.ofNullable(row.get("expected_output", String.class)) - .filter(s -> !s.isBlank()) - .map(JsonUtils::getJsonNodeFromString).orElse(null)) - .metadata(Optional.ofNullable(row.get("metadata", String.class)) - .filter(s -> !s.isBlank()) - .map(JsonUtils::getJsonNodeFromString).orElse(null)) - .source(DatasetItemSource.fromString(row.get("source", String.class))) - .traceId(Optional.ofNullable(row.get("trace_id", String.class)) - .filter(s -> !s.isBlank()) - .map(UUID::fromString) - .orElse(null)) - .spanId(Optional.ofNullable(row.get("span_id", String.class)) - .filter(s -> !s.isBlank()) - .map(UUID::fromString) - .orElse(null)) - .experimentItems(getExperimentItems(row.get("experiment_items_array", List[].class))) - .lastUpdatedAt(row.get("last_updated_at", Instant.class)) - .createdAt(row.get("created_at", Instant.class)) - .createdBy(row.get("created_by", String.class)) - 
.lastUpdatedBy(row.get("last_updated_by", String.class))
-                .build());
+        return results.map((row, rowMetadata) -> {
+
+            Map<String, JsonNode> data = getData(row);
+
+            JsonNode input = getJsonNode(row, data, "input");
+            JsonNode expectedOutput = getJsonNode(row, data, "expected_output");
+            JsonNode metadata = getJsonNode(row, data, "metadata");
+
+            return DatasetItem.builder()
+                    .id(row.get("id", UUID.class))
+                    .input(input)
+                    .data(data)
+                    .expectedOutput(expectedOutput)
+                    .metadata(metadata)
+                    .source(DatasetItemSource.fromString(row.get("source", String.class)))
+                    .traceId(Optional.ofNullable(row.get("trace_id", String.class))
+                            .filter(s -> !s.isBlank())
+                            .map(UUID::fromString)
+                            .orElse(null))
+                    .spanId(Optional.ofNullable(row.get("span_id", String.class))
+                            .filter(s -> !s.isBlank())
+                            .map(UUID::fromString)
+                            .orElse(null))
+                    .experimentItems(getExperimentItems(row.get("experiment_items_array", List[].class)))
+                    .lastUpdatedAt(row.get("last_updated_at", Instant.class))
+                    .createdAt(row.get("created_at", Instant.class))
+                    .createdBy(row.get("created_by", String.class))
+                    .lastUpdatedBy(row.get("last_updated_by", String.class))
+                    .build();
+        });
+    }
+
+    private Map<String, JsonNode> getData(Row row) {
+        return Optional.ofNullable(row.get("data", Map.class))
+                .filter(s -> !s.isEmpty())
+                .map(value -> (Map<String, String>) value)
+                .stream()
+                .map(Map::entrySet)
+                .flatMap(Collection::stream)
+                .map(entry -> Map.entry(entry.getKey(), JsonUtils.getJsonNodeFromString(entry.getValue())))
+                .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
+    private JsonNode getJsonNode(Row row, Map<String, JsonNode> data, String key) {
+        JsonNode json = null;
+
+        if (data.containsKey(key)) {
+            json = data.get(key);
+        }
+
+        if (json == null) {
+            json = Optional.ofNullable(row.get(key, String.class))
+                    .filter(s -> !s.isBlank())
+                    .map(JsonUtils::getJsonNodeFromString).orElse(null);
+        }
+
+        return json;
     }

     private List<ExperimentItem> getExperimentItems(List[] experimentItemsArrays) {
@@ -658,12 +731,16 @@ public Mono<DatasetItemPage> getItems(@NonNull UUID datasetId, int page, int siz
                         .bind("workspace_id", workspaceId)
                         .execute())
                 .doFinally(signalType -> endSegment(segmentCount))
-                .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class)))
-                .reduce(0L, Long::sum)
-                .flatMap(count -> {
+                .flatMap(this::mapCount)
+                .reduce((result1, result2) -> Map.entry(result1.getKey() + result2.getKey(),
+                        Sets.union(result1.getValue(), result2.getValue())))
+                .flatMap(result -> {

                     Segment segment = startSegment("dataset_items", "Clickhouse", "select_dataset_items_page");

+                    long total = result.getKey();
+                    Set<Column> columns = result.getValue();
+
                     return Flux.from(connection.createStatement(SELECT_DATASET_ITEMS)
                             .bind("workspace_id", workspaceId)
                             .bind("datasetId", datasetId)
@@ -672,11 +749,21 @@
                             .execute())
                             .flatMap(this::mapItem)
                             .collectList()
-                            .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count)))
+                            .flatMap(items -> Mono
+                                    .just(new DatasetItemPage(items, page, items.size(), total, columns)))
                             .doFinally(signalType -> endSegment(segment));
                 })));
     }

+    private Publisher<Map.Entry<Long, Set<Column>>> mapCount(Result result) {
+        return result.map((row, rowMetadata) -> Map.entry(
+                row.get(0, Long.class),
+                ((List<List<String>>) row.get(1, List.class))
+                        .stream()
+                        .map(columnArray -> new Column(columnArray.getFirst(), columnArray.get(1)))
+                        .collect(Collectors.toSet())));
+    }
+
     private ST newFindTemplate(String query, DatasetItemSearchCriteria datasetItemSearchCriteria) {
         var template = new ST(query);
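Note on the read path above: getJsonNode gives each deprecated field a simple precedence rule — the value in the data map wins, the legacy column is the fallback, and blank strings count as absent. The same rule as a standalone, hedged sketch; legacyColumnValue stands in for row.get(key, String.class) and the class name is illustrative only.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.Map;
import java.util.Optional;

class FallbackSketch {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    static JsonNode resolve(Map<String, JsonNode> data, String key, String legacyColumnValue) {
        if (data.containsKey(key) && data.get(key) != null) {
            return data.get(key); // new storage wins
        }
        // Legacy column fallback: blank means the field was never written.
        return Optional.ofNullable(legacyColumnValue)
                .filter(s -> !s.isBlank())
                .map(FallbackSketch::parse)
                .orElse(null);
    }

    private static JsonNode parse(String json) {
        try {
            return MAPPER.readTree(json);
        } catch (Exception e) {
            throw new IllegalStateException("stored column should contain valid JSON", e);
        }
    }
}

@@ -727,9 +814,11 @@ public Mono<DatasetItemPage> getItems(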
return makeFluxContextAware(bindWorkspaceIdToFlux(statement)) .doFinally(signalType -> endSegment(segmentCount)) - .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class))) - .reduce(0L, Long::sum) - .flatMap(count -> { + .flatMap(this::mapCount) + .reduce((result1, result2) -> Map.entry(result1.getKey() + result2.getKey(), + Sets.union(result1.getValue(), result2.getValue()))) + .flatMap(result -> { + Segment segment = startSegment("dataset_items", "Clickhouse", "select_dataset_items_filters"); ST selectTemplate = newFindTemplate(SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS, @@ -744,11 +833,15 @@ public Mono getItems( bindSearchCriteria(datasetItemSearchCriteria, selectStatement); + Long total = result.getKey(); + Set columns = result.getValue(); + return makeFluxContextAware(bindWorkspaceIdToFlux(selectStatement)) .doFinally(signalType -> endSegment(segment)) .flatMap(this::mapItem) .collectList() - .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count))); + .flatMap(items -> Mono + .just(new DatasetItemPage(items, page, items.size(), total, columns))); }); }); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java index d69a09c56..704888045 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java @@ -107,7 +107,7 @@ public Mono get(@NonNull UUID id) { return dao.get(id) .switchIfEmpty(Mono.defer(() -> Mono.error(failWithNotFound("Dataset item not found")))); } - + @WithSpan public Flux getItems(@NonNull String workspaceId, @NonNull DatasetItemStreamRequest request) { log.info("Getting dataset items by '{}' on workspaceId '{}'", request, workspaceId); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/DatabaseAnalyticsModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/DatabaseAnalyticsModule.java index 6edfae177..5afa85317 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/DatabaseAnalyticsModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/DatabaseAnalyticsModule.java @@ -4,7 +4,6 @@ import com.comet.opik.infrastructure.OpikConfiguration; import com.google.inject.Provides; import io.opentelemetry.api.GlobalOpenTelemetry; -import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.instrumentation.r2dbc.v1_0.R2dbcTelemetry; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryOptions; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java index c0d0007db..7d4a6d79f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java @@ -13,18 +13,19 @@ @UtilityClass public class InstrumentAsyncUtils { - public record Segment(Scope scope, Span span) {} + public record Segment(Scope scope, Span span) { + } public static Segment startSegment(String segmentName, String product, String operationName) { Tracer tracer = GlobalOpenTelemetry.get().getTracer("com.comet.opik"); Span span = tracer - .spanBuilder("custom-reactive-%s".formatted(segmentName)) - 
.setParent(Context.current().with(Span.current())) - .startSpan() - .setAttribute("product", product) - .setAttribute("operation", operationName); + .spanBuilder("custom-reactive-%s".formatted(segmentName)) + .setParent(Context.current().with(Span.current())) + .startSpan() + .setAttribute("product", product) + .setAttribute("operation", operationName); return new Segment(span.makeCurrent(), span); } diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000004_add_input_data_to_dataset_item.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000004_add_input_data_to_dataset_item.sql new file mode 100644 index 000000000..8c5300b8f --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000004_add_input_data_to_dataset_item.sql @@ -0,0 +1,7 @@ +--liquibase formatted sql +--changeset thiagohora:add_input_data_to_dataset_item + +ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.dataset_items + ADD COLUMN IF NOT EXISTS data Map(String, String) DEFAULT map(); + +--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.dataset_items DROP COLUMN data; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/AuthTestUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/AuthTestUtils.java index 80d04c816..60c1cc8db 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/AuthTestUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/AuthTestUtils.java @@ -6,8 +6,8 @@ import static com.comet.opik.infrastructure.auth.RequestContext.SESSION_COOKIE; import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; -import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index e1b259fb1..e537c835f 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -38,7 +38,6 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.impl.TimeBasedEpochGenerator; -import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.GenericType; @@ -47,6 +46,7 @@ import jakarta.ws.rs.core.Response; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.RandomStringUtils; +import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.client.ChunkedInput; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; @@ -73,9 +73,13 @@ import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import 
java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; @@ -84,6 +88,7 @@ import java.util.stream.Stream; import static com.comet.opik.api.DatasetItem.DatasetItemPage; +import static com.comet.opik.api.DatasetItem.DatasetItemPage.Column; import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; import static com.comet.opik.api.resources.utils.WireMockUtils.WireMockRuntime; @@ -95,7 +100,12 @@ import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.unauthorized; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.toMap; +import static java.util.stream.Collectors.toSet; +import static java.util.stream.Collectors.toUnmodifiableSet; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -114,8 +124,9 @@ class DatasetsResourceTest { public static final String[] IGNORED_FIELDS_LIST = {"feedbackScores", "createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy"}; public static final String[] IGNORED_FIELDS_DATA_ITEM = {"createdAt", "lastUpdatedAt", "experimentItems", - "createdBy", - "lastUpdatedBy"}; + "createdBy", "lastUpdatedBy"}; + public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", + "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "experimentCount"}; public static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); @@ -135,9 +146,6 @@ class DatasetsResourceTest { private static final WireMockRuntime wireMock; - public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", - "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "experimentCount"}; - static { MYSQL.start(); CLICKHOUSE.start(); @@ -264,13 +272,13 @@ void setUp() { post(urlPathEqualTo("/opik/auth")) .withHeader(HttpHeaders.AUTHORIZATION, equalTo(fakeApikey)) .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) - .willReturn(WireMock.unauthorized())); + .willReturn(unauthorized())); wireMock.server().stubFor( post(urlPathEqualTo("/opik/auth")) .withHeader(HttpHeaders.AUTHORIZATION, equalTo("")) .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) - .willReturn(WireMock.unauthorized())); + .willReturn(unauthorized())); } @ParameterizedTest @@ -735,7 +743,7 @@ void setUp() { post(urlPathEqualTo("/opik/auth-session")) .withCookie(SESSION_COOKIE, equalTo(fakeSessionToken)) .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) - .willReturn(WireMock.unauthorized())); + .willReturn(unauthorized())); } Stream credentials() { @@ -1457,7 +1465,7 @@ void getDatasetById__whenDatasetHasExperimentsLinkedToIt__thenReturnDatasetWithE .toList(); var traceIdToScoresMap = Stream.concat(scores1.stream(), scores2.stream()) - .collect(Collectors.groupingBy(FeedbackScoreBatchItem::id)); + .collect(groupingBy(FeedbackScoreBatchItem::id)); // When storing the scores in batch, adding some more unrelated random ones var 
feedbackScoreBatch = factory.manufacturePojo(FeedbackScoreBatch.class);
@@ -1486,7 +1494,7 @@ void getDatasetById__whenDatasetHasExperimentsLinkedToIt__thenReturnDatasetWithE
                     .experimentItems(Stream.concat(
                             experimentItemsBatch.experimentItems().stream(),
                             experimentItems.stream())
-                            .collect(Collectors.toUnmodifiableSet()))
+                            .collect(toUnmodifiableSet()))
                     .build();

             Instant beforeCreateExperimentItems = Instant.now();
@@ -1860,7 +1868,7 @@ void getDatasets__whenDatasetsHaveExperimentsLinkedToThem__thenReturnDatasetsWit
                     .toList();

             var traceIdToScoresMap = scores.stream()
-                    .collect(Collectors.groupingBy(FeedbackScoreBatchItem::id));
+                    .collect(groupingBy(FeedbackScoreBatchItem::id));

             // When storing the scores in batch, adding some more unrelated random ones
             var feedbackScoreBatch = factory.manufacturePojo(FeedbackScoreBatch.class);
@@ -1884,7 +1892,7 @@ void getDatasets__whenDatasetsHaveExperimentsLinkedToThem__thenReturnDatasetsWit
                             .map(FeedbackScoreMapper.INSTANCE::toFeedbackScore)
                             .toList())
                     .build())
-                    .collect(Collectors.toSet());
+                    .collect(toSet());

             var experimentItemsBatch = factory.manufacturePojo(ExperimentItemsBatch.class).toBuilder()
                     .experimentItems(experimentItems)
@@ -2209,7 +2217,7 @@ void createDatasetItem__whenDatasetItemIsNotValid__thenReturn422(DatasetItemBatc
                 .request()
                 .header(HttpHeaders.AUTHORIZATION, API_KEY)
                 .header(WORKSPACE_HEADER, TEST_WORKSPACE)
-                .put(Entity.entity(batch, MediaType.APPLICATION_JSON_TYPE))) {
+                .put(Entity.json(batch))) {

             assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422);
             assertThat(actualResponse.hasEntity()).isTrue();
@@ -2355,9 +2363,17 @@ public Stream<Arguments> invalidDatasetItems() {
                 arguments(factory.manufacturePojo(DatasetItemBatch.class).toBuilder()
                         .items(List.of(factory.manufacturePojo(DatasetItem.class).toBuilder()
                                 .input(null)
+                                .data(null)
+                                .build()))
+                        .build(),
+                        "items[0].input must provide either input or data field"),
+                arguments(factory.manufacturePojo(DatasetItemBatch.class).toBuilder()
+                        .items(List.of(factory.manufacturePojo(DatasetItem.class).toBuilder()
+                                .input(null)
+                                .data(Map.of())
                                 .build()))
                         .build(),
-                        "items[0].input must not be null"),
+                        "items[0].input must provide either input or data field"),
                 arguments(factory.manufacturePojo(DatasetItemBatch.class).toBuilder()
                         .items(List.of(factory.manufacturePojo(DatasetItem.class).toBuilder()
                                 .source(null)
@@ -2549,6 +2565,40 @@ void createDatasetItem__whenDatasetItemWorkspaceAndSpanWorkspaceDoesNotMatch__th
             }
         }

+        @Test
+        @DisplayName("when data is null, then accept the request")
+        void create__whenDataIsNull__thenAcceptTheRequest() {
+            var item = factory.manufacturePojo(DatasetItem.class).toBuilder()
+                    .data(null)
+                    .build();
+
+            var batch = factory.manufacturePojo(DatasetItemBatch.class).toBuilder()
+                    .items(List.of(item))
+                    .datasetId(null)
+                    .build();
+
+            putAndAssert(batch, TEST_WORKSPACE, API_KEY);
+
+            getItemAndAssert(item, TEST_WORKSPACE, API_KEY);
+        }
+
+        @Test
+        @DisplayName("when input is null but data is present, then accept the request")
+        void create__whenInputIsNullButDataIsPresent__thenAcceptTheRequest() {
+            var item = factory.manufacturePojo(DatasetItem.class).toBuilder()
+                    .input(null)
+                    .build();
+
+            var batch = factory.manufacturePojo(DatasetItemBatch.class).toBuilder()
+                    .items(List.of(item))
+                    .datasetId(null)
+                    .build();
+
+            putAndAssert(batch, TEST_WORKSPACE, API_KEY);
+
+            getItemAndAssert(item, TEST_WORKSPACE, API_KEY);
+        }
+
     }

     private UUID createTrace(Trace trace, String apiKey, String workspaceName) {
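Note: together with the invalidDatasetItems cases above, these two tests pin down the relaxed contract — input alone, data alone, or both are accepted; neither yields a 422. Below is a hedged sketch of that contract against a plain Jakarta validator; DatasetItemSource.MANUAL is assumed here to be a valid enum constant, and the expected violation counts are illustrative.

import com.comet.opik.api.DatasetItem;
import com.comet.opik.api.DatasetItemSource;
import com.fasterxml.jackson.databind.node.TextNode;
import jakarta.validation.Validation;
import jakarta.validation.Validator;

import java.util.Map;

class DatasetItemValidationSketch {

    public static void main(String[] args) {
        Validator validator = Validation.buildDefaultValidatorFactory().getValidator();

        // Passes: no legacy input field, but a non-empty data map.
        DatasetItem withData = DatasetItem.builder()
                .source(DatasetItemSource.MANUAL)
                .data(Map.of("input", TextNode.valueOf("What is Opik?")))
                .build();

        // Fails with "must provide either input or data field": neither field is set.
        DatasetItem withNeither = DatasetItem.builder()
                .source(DatasetItemSource.MANUAL)
                .build();

        System.out.println(validator.validate(withData).size());    // expected: 0
        System.out.println(validator.validate(withNeither).size()); // expected: 1
    }
}

@@ -2654,9 +2704,7 @@ void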
streamDataItems__whenStreamingDatasetItems__thenReturnItemsSortedByCreatedD List actualItems = getStreamedItems(response); - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(items.reversed()); + assertPage(items.reversed(), actualItems); } } @@ -2693,9 +2741,7 @@ void streamDataItems__whenStreamingDatasetItemsWithFilters__thenReturnItemsSorte List actualItems = getStreamedItems(response); - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(items.reversed().subList(2, 5)); + assertPage(items.reversed().subList(2, 5), actualItems); } } @@ -2737,9 +2783,7 @@ void streamDataItems__whenStreamingHasMaxSize__thenReturnItemsSortedByCreatedDat List actualItems = getStreamedItems(response); - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(expectedFirstPage); + assertPage(expectedFirstPage, actualItems); } streamRequest = DatasetItemStreamRequest.builder() @@ -2760,9 +2804,7 @@ void streamDataItems__whenStreamingHasMaxSize__thenReturnItemsSortedByCreatedDat List actualItems = getStreamedItems(response); - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(items.reversed().subList(500, 1000)); + assertPage(items.reversed().subList(500, 1000), actualItems); } } } @@ -2779,15 +2821,52 @@ private void getItemAndAssert(DatasetItem expectedDatasetItem, String workspaceN var actualEntity = actualResponse.readEntity(DatasetItem.class); assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + Map data = Optional.ofNullable(expectedDatasetItem.data()) + .orElse(Map.of()); + + expectedDatasetItem = mergeInputMap(expectedDatasetItem, data); + assertThat(actualEntity.id()).isEqualTo(expectedDatasetItem.id()); assertThat(actualEntity).usingRecursiveComparison() - .ignoringFields("createdAt", "lastUpdatedAt", "experimentItems", "createdBy", "lastUpdatedBy") + .ignoringFields(IGNORED_FIELDS_DATA_ITEM) .isEqualTo(expectedDatasetItem); assertThat(actualEntity.createdAt()).isInThePast(); assertThat(actualEntity.lastUpdatedAt()).isInThePast(); } + private static DatasetItem mergeInputMap(DatasetItem expectedDatasetItem, + Map data) { + + Map newMap = new HashMap<>(); + + if (expectedDatasetItem.expectedOutput() != null) { + newMap.put("expected_output", expectedDatasetItem.expectedOutput()); + } + + if (expectedDatasetItem.input() != null) { + newMap.put("input", expectedDatasetItem.input()); + } + + if (expectedDatasetItem.metadata() != null) { + newMap.put("metadata", expectedDatasetItem.metadata()); + } + + Map mergedMap = Stream + .concat(data.entrySet().stream(), newMap.entrySet().stream()) + .collect(toMap( + Map.Entry::getKey, + Map.Entry::getValue, + (v1, v2) -> v2 // In case of conflict, use the value from map2 + )); + + expectedDatasetItem = expectedDatasetItem.toBuilder() + .data(mergedMap) + .build(); + + return expectedDatasetItem; + } + @Nested @DisplayName("Delete items:") @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -2902,6 +2981,13 @@ void getDatasetItemsByDatasetId() { .datasetId(datasetId) .build(); + List> data = batch.items() + .stream() + .map(DatasetItem::data) + .toList(); + + Set columns = addDeprecatedFields(data); + putAndAssert(batch, TEST_WORKSPACE, API_KEY); try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -2920,12 +3006,9 @@ void 
getDatasetItemsByDatasetId() { assertThat(actualEntity.content()).hasSize(items.size()); assertThat(actualEntity.page()).isEqualTo(1); assertThat(actualEntity.total()).isEqualTo(items.size()); + assertThat(actualEntity.columns()).isEqualTo(columns); - var actualItems = actualEntity.content(); - - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(items.reversed()); + assertPage(items.reversed(), actualEntity.content()); } } @@ -2944,6 +3027,13 @@ void getDatasetItemsByDatasetId__whenDefiningPageSize__thenReturnPageWithLimitRe .datasetId(datasetId) .build(); + List> data = batch.items() + .stream() + .map(DatasetItem::data) + .toList(); + + Set columns = addDeprecatedFields(data); + putAndAssert(batch, TEST_WORKSPACE, API_KEY); try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -2963,12 +3053,9 @@ void getDatasetItemsByDatasetId__whenDefiningPageSize__thenReturnPageWithLimitRe assertThat(actualEntity.content()).hasSize(1); assertThat(actualEntity.page()).isEqualTo(1); assertThat(actualEntity.total()).isEqualTo(items.size()); + assertThat(actualEntity.columns()).isEqualTo(columns); - var actualItems = actualEntity.content(); - - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(List.of(items.reversed().getFirst())); + assertPage(List.of(items.reversed().getFirst()), actualEntity.content()); } } @@ -3000,6 +3087,13 @@ void getDatasetItemsByDatasetId__whenItemsWereUpdated__thenReturnCorrectItemsCou putAndAssert(updatedBatch, TEST_WORKSPACE, API_KEY); + List> data = updatedBatch.items() + .stream() + .map(DatasetItem::data) + .toList(); + + Set columns = addDeprecatedFields(data); + try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .path(datasetId.toString()) .path("items") @@ -3016,17 +3110,37 @@ void getDatasetItemsByDatasetId__whenItemsWereUpdated__thenReturnCorrectItemsCou assertThat(actualEntity.content()).hasSize(updatedItems.size()); assertThat(actualEntity.page()).isEqualTo(1); assertThat(actualEntity.total()).isEqualTo(updatedItems.size()); + assertThat(actualEntity.columns()).isEqualTo(columns); - var actualItems = actualEntity.content(); - - assertThat(actualItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(updatedItems.reversed()); + assertPage(updatedItems.reversed(), actualEntity.content()); } } } + private static void assertPage(List expectedItems, List actualItems) { + + List ignoredFields = new ArrayList<>(Arrays.asList(IGNORED_FIELDS_DATA_ITEM)); + ignoredFields.add("data"); + + assertThat(actualItems) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(ignoredFields.toArray(String[]::new)) + .isEqualTo(expectedItems); + + assertThat(actualItems).hasSize(expectedItems.size()); + for (int i = 0; i < actualItems.size(); i++) { + var actualDatasetItem = actualItems.get(i); + var expectedDatasetItem = expectedItems.get(i); + + Map data = Optional.ofNullable(expectedDatasetItem.data()) + .orElse(Map.of()); + + expectedDatasetItem = mergeInputMap(expectedDatasetItem, data); + + assertThat(actualDatasetItem.data()).isEqualTo(expectedDatasetItem.data()); + } + } + @Nested @TestInstance(TestInstance.Lifecycle.PER_CLASS) class FindDatasetItemsWithExperimentItems { @@ -3067,7 +3181,7 @@ void find() { .toList(); var traceIdToScoresMap = Stream.concat(scores1.stream(), scores2.stream()) - 
.collect(Collectors.groupingBy(FeedbackScoreBatchItem::id)); + .collect(groupingBy(FeedbackScoreBatchItem::id)); // When storing the scores in batch, adding some more unrelated random ones var feedbackScoreBatch = factory.manufacturePojo(FeedbackScoreBatch.class); @@ -3115,7 +3229,7 @@ void find() { .map(FeedbackScoreMapper.INSTANCE::toFeedbackScore) .toList()) .build())) - .collect(Collectors.groupingBy(ExperimentItem::datasetItemId)); + .collect(groupingBy(ExperimentItem::datasetItemId)); // Dataset item 2 covers the case of experiments items related to a trace without input, output and scores. // It also has 2 experiment items per each of the 5 experiments. @@ -3151,10 +3265,16 @@ void find() { experimentItemsBatch = experimentItemsBatch.toBuilder() .experimentItems(Stream.concat(experimentItemsBatch.experimentItems().stream(), datasetItemIdToExperimentItemMap.values().stream().flatMap(Collection::stream)) - .collect(Collectors.toUnmodifiableSet())) + .collect(toUnmodifiableSet())) .build(); createAndAssert(experimentItemsBatch, apiKey, workspaceName); + List> data = expectedDatasetItems.stream() + .map(DatasetItem::data) + .toList(); + + Set columns = addDeprecatedFields(data); + var page = 1; var pageSize = 5; // Filtering by experiments 1 and 3. @@ -3178,12 +3298,11 @@ void find() { assertThat(actualPage.page()).isEqualTo(page); assertThat(actualPage.size()).isEqualTo(expectedDatasetItems.size()); assertThat(actualPage.total()).isEqualTo(expectedDatasetItems.size()); + assertThat(actualPage.columns()).isEqualTo(columns); var actualDatasetItems = actualPage.content(); - assertThat(actualDatasetItems) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) - .containsExactlyElementsOf(expectedDatasetItems); + assertPage(expectedDatasetItems, actualPage.content()); for (var i = 0; i < actualDatasetItems.size(); i++) { var actualDatasetItem = actualDatasetItems.get(i); @@ -3292,6 +3411,8 @@ void find__whenFilteringBySupportedFields__thenReturnMatchingRows(Filter filter) apiKey, workspaceName); + Set columns = addDeprecatedFields(List.of(items.getFirst().data())); + List filters = List.of(filter); try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -3313,6 +3434,8 @@ void find__whenFilteringBySupportedFields__thenReturnMatchingRows(Filter filter) assertThat(actualPage.total()).isEqualTo(1); assertThat(actualPage.page()).isEqualTo(1); assertThat(actualPage.content()).hasSize(1); + assertThat(actualPage.columns()).isEqualTo(columns); + assertDatasetItemPage(actualPage, items, experimentItems); } } @@ -3556,12 +3679,27 @@ static Stream find__whenFilterInvalidOperatorForFieldType__thenReturn } } + private static Set addDeprecatedFields(List> data) { + + HashSet columns = data + .stream() + .map(Map::entrySet) + .flatMap(Collection::stream) + .map(entry -> new Column(entry.getKey(), + StringUtils.capitalize(entry.getValue().getNodeType().name().toLowerCase()))) + .collect(Collectors.toCollection(HashSet::new)); + + columns.add(new Column("input", "Object")); + columns.add(new Column("expected_output", "Object")); + columns.add(new Column("metadata", "Object")); + + return columns; + } + private void assertDatasetItemPage(DatasetItemPage actualPage, List items, List experimentItems) { - assertThat(actualPage.content().getFirst()) - .usingRecursiveComparison() - .ignoringFields(IGNORED_FIELDS_DATA_ITEM) - .isEqualTo(items.getFirst()); + + assertPage(List.of(items.getFirst()), List.of(actualPage.content().getFirst())); var 
actualExperimentItems = actualPage.content().getFirst().experimentItems(); assertThat(actualExperimentItems).hasSize(1); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 23450048d..1dd2fa3e2 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -3429,7 +3429,8 @@ void createAndGet__whenTraceInputIsBig__thenReturnSpan() { int size = 1000; Map jsonMap = IntStream.range(0, size) - .mapToObj(i -> Map.entry(RandomStringUtils.randomAlphabetic(10), RandomStringUtils.randomAscii(size))) + .mapToObj( + i -> Map.entry(RandomStringUtils.randomAlphabetic(10), RandomStringUtils.randomAscii(size))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); var expectedTrace = factory.manufacturePojo(Trace.class).toBuilder() diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/CategoricalFeedbackDetailTypeManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/CategoricalFeedbackDetailTypeManufacturer.java index 1d90dd72a..29ee7e7a9 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/CategoricalFeedbackDetailTypeManufacturer.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/CategoricalFeedbackDetailTypeManufacturer.java @@ -4,7 +4,7 @@ import uk.co.jemos.podam.api.DataProviderStrategy; import uk.co.jemos.podam.api.PodamUtils; import uk.co.jemos.podam.common.ManufacturingContext; -import uk.co.jemos.podam.typeManufacturers.TypeManufacturer; +import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer; import java.math.BigDecimal; import java.math.RoundingMode; @@ -17,7 +17,7 @@ import static com.comet.opik.api.FeedbackDefinition.CategoricalFeedbackDefinition.CategoricalFeedbackDetail; import static com.comet.opik.utils.ValidationUtils.SCALE; -public class CategoricalFeedbackDetailTypeManufacturer implements TypeManufacturer { +public class CategoricalFeedbackDetailTypeManufacturer extends AbstractTypeManufacturer { @Override public CategoricalFeedbackDetail getType(DataProviderStrategy dataProviderStrategy, diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/DatasetItemTypeManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/DatasetItemTypeManufacturer.java index 837407b53..e59d4e83b 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/DatasetItemTypeManufacturer.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/DatasetItemTypeManufacturer.java @@ -3,15 +3,20 @@ import com.comet.opik.api.DatasetItem; import com.comet.opik.api.DatasetItemSource; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.commons.lang3.RandomStringUtils; import uk.co.jemos.podam.api.AttributeMetadata; import uk.co.jemos.podam.api.DataProviderStrategy; import uk.co.jemos.podam.common.ManufacturingContext; import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer; import java.time.Instant; +import java.util.Map; import java.util.Random; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class DatasetItemTypeManufacturer extends 
AbstractTypeManufacturer { @@ -33,6 +38,18 @@ public DatasetItem getType(DataProviderStrategy strategy, AttributeMetadata meta ? strategy.getTypeValue(metadata, context, UUID.class) : null; + Map data = IntStream.range(0, 5) + .mapToObj(i -> { + if (i % 2 == 0) { + return Map.entry(RandomStringUtils.randomAlphanumeric(10), + TextNode.valueOf(RandomStringUtils.randomAlphanumeric(10))); + } + + return Map.entry(RandomStringUtils.randomAlphanumeric(10), + strategy.getTypeValue(metadata, context, JsonNode.class)); + }) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + return DatasetItem.builder() .source(source) .traceId(traceId) @@ -41,6 +58,7 @@ public DatasetItem getType(DataProviderStrategy strategy, AttributeMetadata meta .input(strategy.getTypeValue(metadata, context, JsonNode.class)) .expectedOutput(strategy.getTypeValue(metadata, context, JsonNode.class)) .metadata(strategy.getTypeValue(metadata, context, JsonNode.class)) + .data(data) .createdAt(Instant.now()) .lastUpdatedAt(Instant.now()) .build(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/NumericalFeedbackDetailTypeManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/NumericalFeedbackDetailTypeManufacturer.java index d95ab78fe..ef4ff789b 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/NumericalFeedbackDetailTypeManufacturer.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/NumericalFeedbackDetailTypeManufacturer.java @@ -4,7 +4,7 @@ import uk.co.jemos.podam.api.DataProviderStrategy; import uk.co.jemos.podam.api.PodamUtils; import uk.co.jemos.podam.common.ManufacturingContext; -import uk.co.jemos.podam.typeManufacturers.TypeManufacturer; +import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer; import java.math.BigDecimal; import java.math.RoundingMode; @@ -12,7 +12,7 @@ import static com.comet.opik.api.FeedbackDefinition.NumericalFeedbackDefinition.NumericalFeedbackDetail; import static com.comet.opik.utils.ValidationUtils.*; -public class NumericalFeedbackDetailTypeManufacturer implements TypeManufacturer { +public class NumericalFeedbackDetailTypeManufacturer extends AbstractTypeManufacturer { @Override public NumericalFeedbackDetail getType(DataProviderStrategy dataProviderStrategy,
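Note for readers without the Podam context: the DatasetItemTypeManufacturer above seeds every generated item with a five-entry data map that mixes plain text nodes with fully structured Podam-generated JSON, so the derived page columns exercise more than one JSONType. A hedged plain-Jackson equivalent; sampleData and its keys are illustrative only.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;

import java.util.Map;

class DataMapSketch {

    static Map<String, JsonNode> sampleData() {
        ObjectNode structured = JsonNodeFactory.instance.objectNode()
                .put("question", "2+2?")
                .put("answer", 4);

        // Mixes a String-typed and an Object-typed value, so the derived
        // page columns cover both JSONType buckets.
        return Map.of(
                "input", TextNode.valueOf("free-form text value"),
                "metadata", structured);
    }
}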