Skip to content

Commit

Permalink
[8.11] [ML] Fix empty requests being sent to nodes with the model all…
Browse files Browse the repository at this point in the history
…ocations (elastic#100388) (elastic#100395)

* [ML] Fix empty requests being sent to nodes with the model allocations (elastic#100388)

Fix for inference requests being sent to every node with a model allocation. 
If there are more nodes than items in the original request then empty 
requests were sent.

* Fix changelog
  • Loading branch information
davidkyle authored Oct 6, 2023
1 parent 47b89f3 commit 92c4a27
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 14 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/100388.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 100388
summary: Fix for inference requests being sent to every node with a model allocation. If there are more nodes than items in the original request then empty requests were sent.
area: Machine Learning
type: bug
issues:
- 100180
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocations

var nodeCounts = new ArrayList<Tuple<String, Integer>>();
for (int i = 0; i < counts.length; i++) {
nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
if (counts[i] > 0) {
nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
}
}
return nodeCounts;
}
Expand All @@ -232,7 +234,10 @@ public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocations

var nodeCounts = new ArrayList<Tuple<String, Integer>>();
for (int i = 0; i < counts.length; i++) {
nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
// filter out zero counts
if (counts[i] > 0) {
nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
}
}
return nodeCounts;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand All @@ -22,7 +21,6 @@
public class ErrorInferenceResults implements InferenceResults {

public static final String NAME = "error";
public static final ParseField WARNING = new ParseField("error");

private final Exception exception;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,17 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSin
assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1)));
}

public void testSingleRequestWith2Nodes() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5));
builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
TrainedModelAssignment assignment = builder.build();

var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1);
assertThat(nodes, hasSize(1));
assertEquals(nodes.get(0).v2(), Integer.valueOf(1));
}

public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodes() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6));
builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,17 +354,19 @@ private void sendResponse() {
} else {
for (int i = 0; i < results.length(); i++) {
var resultList = results.get(i);
if (resultList != null) {
for (var result : resultList) {
if (result instanceof ErrorInferenceResults errorResult) {
// Any failure fails all requests
// TODO is this the correct behaviour for batched requests?
finalListener.onFailure(errorResult.getException());
return;
}
if (resultList == null) {
continue;
}

for (var result : resultList) {
if (result instanceof ErrorInferenceResults errorResult) {
// Any failure fails all requests
// TODO is this the correct behaviour for batched requests?
finalListener.onFailure(errorResult.getException());
return;
}
responseBuilder.addInferenceResults(resultList);
}
responseBuilder.addInferenceResults(resultList);
}
finalListener.onResponse(responseBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ public void removeLogging() throws IOException {
client().performRequest(request);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/100180")
public void testTrainedModelDeployment() throws Exception {
assumeTrue("NLP model deployments added in 8.0", UPGRADE_FROM_VERSION.onOrAfter(Version.V_8_0_0));

Expand Down

0 comments on commit 92c4a27

Please sign in to comment.