Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: uri.nudelman <[email protected]>
  • Loading branch information
uriofferup committed Aug 13, 2024
1 parent 94171f8 commit 4f2b98d
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.MissingRequiredPropertiesException;
import org.opensearch.client.util.ObjectBuilder;

@JsonpDeserializable
public class NeuralQuery extends QueryBase implements QueryVariant {

private final String field;
private final String queryText;
private final String queryImage;
private final int k;
@Nullable
private final String modelId;
Expand All @@ -34,7 +36,11 @@ private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
if (builder.queryText == null && builder.queryImage == null && !ApiTypeHelper.requiredPropertiesCheckDisabled()) {
throw new MissingRequiredPropertiesException(this, "queryText", "queryImage");
}
this.queryText = builder.queryText;
this.queryImage = builder.queryImage;
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
this.filter = builder.filter;
Expand Down Expand Up @@ -72,6 +78,15 @@ public final String queryText() {
return this.queryText;
}

/**
* Required - Search query image.
*
* @return Search query image.
*/
public final String queryImage() {
return this.queryImage;
}

/**
* Required - The number of neighbors to return.
*
Expand Down Expand Up @@ -112,7 +127,13 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);

generator.write("query_text", this.queryText);
if (this.queryText != null) {
generator.write("query_text", this.queryText);
}

if (this.queryImage != null) {
generator.write("query_image", this.queryImage);
}

if (this.modelId != null) {
generator.write("model_id", this.modelId);
Expand All @@ -129,7 +150,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId).filter(filter);
return new Builder().field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
}

/**
Expand All @@ -138,6 +159,7 @@ public Builder toBuilder() {
public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builder> implements ObjectBuilder<NeuralQuery> {
private String field;
private String queryText;
private String queryImage;
private Integer k;
@Nullable
private String modelId;
Expand Down Expand Up @@ -166,6 +188,17 @@ public NeuralQuery.Builder queryText(@Nullable String queryText) {
return this;
}

/**
* Required - Search query image.
*
* @param queryImage Search query image.
* @return This builder.
*/
public NeuralQuery.Builder queryImage(@Nullable String queryImage) {
this.queryImage = queryImage;
return this;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
Expand Down Expand Up @@ -227,6 +260,7 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
setupQueryBaseDeserializer(op);

op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Licensed to Elasticsearch B.V. under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch B.V. licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.client.util;

import java.util.StringJoiner;

/**
* Thrown by {@link ObjectBuilder#build()} when one of the required properties is missing.
* <p>
* If you think this is an error and that the reported property is actually optional, a workaround is
* available in {@link ApiTypeHelper} to disable checks. Use with caution.
*/
public class MissingRequiredPropertiesException extends RuntimeException {
private Class<?> clazz;
private String[] properties;

public MissingRequiredPropertiesException(Object obj, String... properties) {
super(
"Missing at least one required property between "
+ buildPropertiesMsg(properties)
+ " in '"
+ obj.getClass().getSimpleName()
+ "'"
);
this.clazz = obj.getClass();
this.properties = properties;
}

/**
* The class where the missing property was found
*/
public Class<?> getObjectClass() {
return clazz;
}

public String[] getPropertiesName() {
return properties;
}

private static String buildPropertiesMsg(String[] properties) {
final StringJoiner sj = new StringJoiner(",", "'", "'");
for (final String property : properties) {
sj.add(property);
}
return sj.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

import org.junit.Test;
import org.opensearch.client.opensearch.model.ModelTestCase;
import org.opensearch.client.util.MissingRequiredPropertiesException;

public class NeuralQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
public void toBuilder_queryText() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.k(1)
Expand All @@ -23,4 +24,37 @@ public void toBuilder() {

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_queryImage() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryImage("queryImage")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_both() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.queryImage("queryImage")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_missing_query() {
assertThrows(
MissingRequiredPropertiesException.class,
() -> new NeuralQuery.Builder().field("field").k(1).filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery()).build()
);
}
}

0 comments on commit 4f2b98d

Please sign in to comment.