Skip to content

Commit

Permalink
Remove infer tests since they will be covered in the notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatsarkar056 committed Sep 17, 2024
1 parent 78e63d5 commit 00d7ea6
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 380 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,52 +76,6 @@ public void shutdown() throws IOException {
webServer.close();
}

public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var apiKey = "apiKey";
var model = "model";
var input = "input";
var projectId = "projectId";
URI uri = URI.create("https://abc.com");
var apiVersion = "apiVersion";

var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = senderFactory.createSender()) {
sender.start();

String responseJson = """
{
"results": [
{
"embedding": [
0.0123,
-0.0123
],
"input": "abc"
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var action = createAction(getUrl(webServer), apiKey, model, projectId, uri, apiVersion, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of(input)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);

assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
assertThat(webServer.requests(), hasSize(1));
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap, is(Map.of("project_id", "projectId", "inputs", List.of(input), "model_id", "model")));
}
}

public void testExecute_ThrowsElasticsearchException() {
var sender = mock(Sender.class);
var apiKey = "apiKey";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,340 +423,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept
verifyNoMoreInteractions(sender);
}

public void testInfer_SendsEmbeddingsRequest() throws IOException {
var input = "input";

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) {
String responseJson = """
{
"results": [
{
"embedding": [
0.0123,
-0.0123
],
"input": "input"
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = IbmWatsonxEmbeddingsModelTests.createModel(
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
getUrl(webServer)
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
model,
null,
List.of(input),
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);

assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
assertThat(webServer.requests(), hasSize(1));
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType()));

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap, Matchers.is(Map.of("project_id", projectId, "inputs", List.of(input), "model_id", modelId)));
}
}

public void testChunkedInfer_Batches() throws IOException {
var input = List.of("foo", "bar");

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) {
String responseJson = """
{
"results": [
{
"embedding": [
0.0123,
-0.0123
],
"input": "foo"
},
{
"embedding": [
0.0456,
-0.0456
],
"input": "bar"
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = IbmWatsonxEmbeddingsModelTests.createModel(
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
getUrl(webServer)
);
PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
service.chunkedInfer(
model,
null,
input,
new HashMap<>(),
InputType.INGEST,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var results = listener.actionGet(TIMEOUT);
assertThat(results, hasSize(2));

// first result
{
assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
}

// second result
{
assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
}

assertThat(webServer.requests(), hasSize(1));
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType()));

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("foo", "bar"), "model_id", modelId)));
}
}

public void testInfer_ResourceNotFound() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) {

String responseJson = """
{
"error": {
"message": "error"
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));

var model = IbmWatsonxEmbeddingsModelTests.createModel(
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
getUrl(webServer)
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
model,
null,
List.of("abc"),
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(error.getMessage(), containsString("Resource not found at "));
assertThat(error.getMessage(), containsString("Error message: [error]"));
assertThat(webServer.requests(), hasSize(1));
}
}

public void testCheckModelConfig_UpdatesDimensions() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
var similarityMeasure = SimilarityMeasure.DOT_PRODUCT;

try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) {
String responseJson = """
{
"results": [
{
"embedding": [
0.0123,
-0.0123
],
"input": "foo"
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = IbmWatsonxEmbeddingsModelTests.createModel(
getUrl(webServer),
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
1,
similarityMeasure
);

PlainActionFuture<Model> listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener);
var result = listener.actionGet(TIMEOUT);

// Updates dimensions to two as two embeddings were returned instead of one as specified before
assertThat(
result,
is(
IbmWatsonxEmbeddingsModelTests.createModel(
getUrl(webServer),
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
2,
similarityMeasure
)
)
);
}
}

public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
var twoDimension = 2;

try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) {
String responseJson = """
{
"results": [
{
"embedding": [
0.0123,
-0.0123
],
"input": "foo"
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = IbmWatsonxEmbeddingsModelTests.createModel(
getUrl(webServer),
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
twoDimension,
null
);

PlainActionFuture<Model> listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener);
var result = listener.actionGet(TIMEOUT);

assertThat(
result,
is(
IbmWatsonxEmbeddingsModelTests.createModel(
getUrl(webServer),
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
twoDimension,
SimilarityMeasure.DOT_PRODUCT
)
)
);
}
}

public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
var twoDimension = 2;

try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) {
String responseJson = """
{
"results": [
{
"embedding": [
0.0123,
-0.0123
],
"input": "foo"
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var model = IbmWatsonxEmbeddingsModelTests.createModel(
getUrl(webServer),
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
twoDimension,
SimilarityMeasure.COSINE
);

PlainActionFuture<Model> listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener);
var result = listener.actionGet(TIMEOUT);

assertThat(
result,
is(
IbmWatsonxEmbeddingsModelTests.createModel(
getUrl(webServer),
modelId,
projectId,
URI.create(url),
apiVersion,
apiKey,
twoDimension,
SimilarityMeasure.COSINE
)
)
);
}
}

private static ActionListener<Model> getModelListenerForException(Class<?> exceptionClass, String expectedMessage) {
return ActionListener.<Model>wrap((model) -> fail("Model parsing should have failed"), e -> {
assertThat(e, Matchers.instanceOf(exceptionClass));
Expand Down

0 comments on commit 00d7ea6

Please sign in to comment.