Skip to content

Commit

Permalink
[ML] Zone aware planner renaming & related refactoring (elastic#111522)
Browse files Browse the repository at this point in the history
* Renaming - code mentioned modelId but was actually deploymentId

* Documenting

* add a test case and more renaming

* Renaming & remove TODOs

* Update MlAutoscalingStats javadoc to match autoscaler comments

* precommit
  • Loading branch information
maxhniebergall authored Sep 25, 2024
1 parent 138e100 commit 43ec760
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,30 @@
* <p>
* The word "total" in an attribute name indicates that the attribute is a sum across all nodes.
*
* @param currentTotalNodes the count of nodes that are currently in the cluster
* @param currentPerNodeMemoryBytes the minimum size (memory) of all nodes in the cluster
* @param currentTotalModelMemoryBytes the sum of model memory over every assignment/deployment
* @param currentTotalProcessorsInUse the sum of processors used over every assignment/deployment
* @param currentPerNodeMemoryOverheadBytes always equal to MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD
* @param wantedMinNodes the minimum number of nodes that must be provided by the autoscaler
* @param wantedExtraPerNodeMemoryBytes the amount of additional memory that must be provided on every node
* (this value must be >0 to trigger a scale up based on memory)
* @param wantedExtraPerNodeNodeProcessors the number of additional processors that must be provided on every node
* (this value must be >0 to trigger a scale up based on processors)
* @param wantedExtraModelMemoryBytes the amount of additional model memory that is newly required
* (due to a new assignment/deployment)
* @param wantedExtraProcessors the number of additional processors that are required to be added to the cluster
* @param unwantedNodeMemoryBytesToRemove the amount of memory that should be removed from the cluster. If this is equal to the amount of
* memory provided by a node, a node will be removed.
@param currentTotalNodes The count of nodes that are currently in the cluster,
used to confirm that both sides have the same view of the current state.
@param currentPerNodeMemoryBytes The minimum size (memory) of all nodes in the cluster,
used to confirm that both sides have the same view of the current state.
* @param currentTotalModelMemoryBytes The sum of model memory over every assignment/deployment, used to calculate requirements
* @param currentTotalProcessorsInUse The sum of processors used over every assignment/deployment, not used by autoscaler
@param currentPerNodeMemoryOverheadBytes Always equal to MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.
* @param wantedMinNodes The minimum number of nodes that must be provided by the autoscaler
* @param wantedExtraPerNodeMemoryBytes If there are jobs or trained models that have been started but cannot be allocated on the
* ML nodes currently within the cluster then this will be the *max* of the ML native memory
* requirements of those jobs/trained models. The metric is in terms of ML native memory,
* not container memory.
* @param wantedExtraPerNodeNodeProcessors If there are trained model allocations that have been started but cannot be allocated on the
* ML nodes currently within the cluster then this will be the *max* of the vCPU requirements of
* those allocations. Zero otherwise.
* @param wantedExtraModelMemoryBytes If there are jobs or trained models that have been started but cannot be allocated on the ML
* nodes currently within the cluster then this will be the *sum* of the ML native memory
* requirements of those jobs/trained models. The metric is in terms of ML native memory,
* not container memory.
* @param wantedExtraProcessors If there are trained model allocations that have been started but cannot be allocated on the
* ML nodes currently within the cluster then this will be the *sum* of the vCPU requirements
* of those allocations. Zero otherwise.
* @param unwantedNodeMemoryBytesToRemove The size of the ML node to be removed, in GB rounded to the nearest GB,
* or zero if no nodes could be removed.
*/

public record MlAutoscalingStats(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ private static AssignmentPlan mergePlans(
nodesByZone.values().forEach(allNodes::addAll);

final List<AssignmentPlan.Deployment> allDeployments = new ArrayList<>();
allDeployments.addAll(planForNormalPriorityModels.models());
allDeployments.addAll(planForLowPriorityModels.models());
allDeployments.addAll(planForNormalPriorityModels.deployments());
allDeployments.addAll(planForLowPriorityModels.deployments());

final Map<String, AssignmentPlan.Node> originalNodeById = allNodes.stream()
.collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
Expand All @@ -139,7 +139,7 @@ private static void copyAssignments(
AssignmentPlan.Builder dest,
Map<String, AssignmentPlan.Node> originalNodeById
) {
for (AssignmentPlan.Deployment m : source.models()) {
for (AssignmentPlan.Deployment m : source.deployments()) {
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
Expand Down Expand Up @@ -328,14 +328,14 @@ private static long getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(

private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(AssignmentPlan assignmentPlan) {
TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.Builder.empty();
for (AssignmentPlan.Deployment deployment : assignmentPlan.models()) {
TrainedModelAssignment existingAssignment = currentMetadata.getDeploymentAssignment(deployment.id());
for (AssignmentPlan.Deployment deployment : assignmentPlan.deployments()) {
TrainedModelAssignment existingAssignment = currentMetadata.getDeploymentAssignment(deployment.deploymentId());

TrainedModelAssignment.Builder assignmentBuilder = existingAssignment == null && createAssignmentRequest.isPresent()
? TrainedModelAssignment.Builder.empty(createAssignmentRequest.get())
: TrainedModelAssignment.Builder.empty(
currentMetadata.getDeploymentAssignment(deployment.id()).getTaskParams(),
currentMetadata.getDeploymentAssignment(deployment.id()).getAdaptiveAllocationsSettings()
currentMetadata.getDeploymentAssignment(deployment.deploymentId()).getTaskParams(),
currentMetadata.getDeploymentAssignment(deployment.deploymentId()).getAdaptiveAllocationsSettings()
);
if (existingAssignment != null) {
assignmentBuilder.setStartTime(existingAssignment.getStartTime());
Expand Down Expand Up @@ -366,7 +366,7 @@ private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(Assignme
assignmentBuilder.calculateAndSetAssignmentState();

explainAssignments(assignmentPlan, nodeLoads, deployment).ifPresent(assignmentBuilder::setReason);
builder.addNewAssignment(deployment.id(), assignmentBuilder);
builder.addNewAssignment(deployment.deploymentId(), assignmentBuilder);
}
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) {
}

return new Deployment(
m.id(),
m.deploymentId(),
m.memoryBytes(),
m.allocations() - calculatePreservedAllocations(m),
m.threadsPerAllocation(),
Expand All @@ -71,11 +71,14 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
// they will not match the models/nodes members we have in this class.
// Therefore, we build a lookup table based on the ids, so we can merge the plan
// with its preserved allocations.
final Map<Tuple<String, String>, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>();
for (Deployment m : assignmentPlan.models()) {
Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignments.entrySet()) {
plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
final Map<Tuple<String, String>, Integer> plannedAssignmentsByDeploymentNodeIdPair = new HashMap<>();
for (Deployment d : assignmentPlan.deployments()) {
Map<Node, Integer> assignmentsOfDeployment = assignmentPlan.assignments(d).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignmentsOfDeployment.entrySet()) {
plannedAssignmentsByDeploymentNodeIdPair.put(
Tuple.tuple(d.deploymentId(), nodeAssignment.getKey().id()),
nodeAssignment.getValue()
);
}
}

Expand All @@ -93,8 +96,8 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
}
}
for (Deployment deploymentNewAllocations : deployments) {
int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault(
Tuple.tuple(deploymentNewAllocations.id(), n.id()),
int newAllocations = plannedAssignmentsByDeploymentNodeIdPair.getOrDefault(
Tuple.tuple(deploymentNewAllocations.deploymentId(), n.id()),
0
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,22 @@
*/
public class AssignmentPlan implements Comparable<AssignmentPlan> {

/**
*
* @param deploymentId
* @param memoryBytes
* @param allocations
* @param threadsPerAllocation
* @param currentAllocationsByNodeId
* @param maxAssignedAllocations this value is used by the ZoneAwareAssignmentPlan and AssignmentPlanner to keep track of the
* maximum number of allocations which have been assigned. It is mainly for assigning over AZs.
* @param adaptiveAllocationsSettings
* @param priority
* @param perDeploymentMemoryBytes
* @param perAllocationMemoryBytes
*/
public record Deployment(
String id,
String deploymentId,
long memoryBytes,
int allocations,
int threadsPerAllocation,
Expand All @@ -44,7 +58,7 @@ public record Deployment(
long perAllocationMemoryBytes
) {
public Deployment(
String id,
String deploymentId,
long modelBytes,
int allocations,
int threadsPerAllocation,
Expand All @@ -55,7 +69,7 @@ public Deployment(
long perAllocationMemoryBytes
) {
this(
id,
deploymentId,
modelBytes,
allocations,
threadsPerAllocation,
Expand All @@ -82,7 +96,7 @@ boolean hasEverBeenAllocated() {

public long estimateMemoryUsageBytes(int allocations) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
id,
deploymentId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
Expand All @@ -92,13 +106,13 @@ public long estimateMemoryUsageBytes(int allocations) {

long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
id,
deploymentId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
allocationsNew
) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
id,
deploymentId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
Expand All @@ -109,7 +123,7 @@ long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew)

long minimumMemoryRequiredBytes() {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
id,
deploymentId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
Expand All @@ -136,7 +150,7 @@ int findExcessAllocations(int maxAllocations, long availableMemoryBytes) {

@Override
public String toString() {
return id
return deploymentId
+ " (mem = "
+ ByteSizeValue.ofBytes(memoryBytes)
+ ") (allocations = "
Expand Down Expand Up @@ -186,7 +200,7 @@ private AssignmentPlan(
this.remainingModelAllocations = Objects.requireNonNull(remainingModelAllocations);
}

public Set<Deployment> models() {
public Set<Deployment> deployments() {
return assignments.keySet();
}

Expand All @@ -208,7 +222,7 @@ public int compareTo(AssignmentPlan o) {
}

public boolean satisfiesCurrentAssignments() {
return models().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
return deployments().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
}

private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
Expand All @@ -225,18 +239,18 @@ public boolean satisfiesAllocations(Deployment m) {
}

public boolean satisfiesAllModels() {
return models().stream().allMatch(this::satisfiesAllocations);
return deployments().stream().allMatch(this::satisfiesAllocations);
}

public boolean arePreviouslyAssignedModelsAssigned() {
return models().stream()
return deployments().stream()
.filter(Deployment::hasEverBeenAllocated)
.map(this::totalAllocations)
.allMatch(totalAllocations -> totalAllocations > 0);
}

public long countPreviouslyAssignedModelsThatAreStillAssigned() {
return models().stream()
return deployments().stream()
.filter(Deployment::hasEverBeenAllocated)
.map(this::totalAllocations)
.filter(totalAllocations -> totalAllocations > 0)
Expand Down Expand Up @@ -301,11 +315,11 @@ public String prettyPrint() {
msg.append(" ->");
for (Tuple<Deployment, Integer> modelAllocations : nodeToModel.get(n)
.stream()
.sorted(Comparator.comparing(x -> x.v1().id()))
.sorted(Comparator.comparing(x -> x.v1().deploymentId()))
.toList()) {
if (modelAllocations.v2() > 0) {
msg.append(" ");
msg.append(modelAllocations.v1().id());
msg.append(modelAllocations.v1().deploymentId());
msg.append(" (mem = ");
msg.append(ByteSizeValue.ofBytes(modelAllocations.v1().memoryBytes()));
msg.append(")");
Expand Down Expand Up @@ -415,7 +429,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
+ "] to assign ["
+ allocations
+ "] allocations to deployment ["
+ deployment.id()
+ deployment.deploymentId()
+ "]"
);
}
Expand All @@ -426,7 +440,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
+ "] to assign ["
+ allocations
+ "] allocations to deployment ["
+ deployment.id()
+ deployment.deploymentId()
+ "]; required threads per allocation ["
+ deployment.threadsPerAllocation()
+ "]"
Expand Down Expand Up @@ -464,7 +478,7 @@ public void accountMemory(Deployment m, Node n) {
private void accountMemory(Deployment m, Node n, long requiredMemory) {
remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory);
if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) {
throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]");
throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.deploymentId() + "]");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class AssignmentPlanner {

public AssignmentPlanner(List<Node> nodes, List<AssignmentPlan.Deployment> deployments) {
this.nodes = nodes.stream().sorted(Comparator.comparing(Node::id)).toList();
this.deployments = deployments.stream().sorted(Comparator.comparing(AssignmentPlan.Deployment::id)).toList();
this.deployments = deployments.stream().sorted(Comparator.comparing(AssignmentPlan.Deployment::deploymentId)).toList();
}

public AssignmentPlan computePlan() {
Expand Down Expand Up @@ -111,7 +111,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
.filter(m -> m.hasEverBeenAllocated())
.map(
m -> new AssignmentPlan.Deployment(
m.id(),
m.deploymentId(),
m.memoryBytes(),
1,
m.threadsPerAllocation(),
Expand All @@ -130,21 +130,21 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
).solvePlan(true);

Map<String, String> modelIdToNodeIdWithSingleAllocation = new HashMap<>();
for (AssignmentPlan.Deployment m : planWithSingleAllocationForPreviouslyAssignedModels.models()) {
for (AssignmentPlan.Deployment m : planWithSingleAllocationForPreviouslyAssignedModels.deployments()) {
Optional<Map<Node, Integer>> assignments = planWithSingleAllocationForPreviouslyAssignedModels.assignments(m);
Set<Node> nodes = assignments.orElse(Map.of()).keySet();
if (nodes.isEmpty() == false) {
assert nodes.size() == 1;
modelIdToNodeIdWithSingleAllocation.put(m.id(), nodes.iterator().next().id());
modelIdToNodeIdWithSingleAllocation.put(m.deploymentId(), nodes.iterator().next().id());
}
}

List<AssignmentPlan.Deployment> planDeployments = deployments.stream().map(m -> {
Map<String, Integer> currentAllocationsByNodeId = modelIdToNodeIdWithSingleAllocation.containsKey(m.id())
? Map.of(modelIdToNodeIdWithSingleAllocation.get(m.id()), 1)
Map<String, Integer> currentAllocationsByNodeId = modelIdToNodeIdWithSingleAllocation.containsKey(m.deploymentId())
? Map.of(modelIdToNodeIdWithSingleAllocation.get(m.deploymentId()), 1)
: Map.of();
return new AssignmentPlan.Deployment(
m.id(),
m.deploymentId(),
m.memoryBytes(),
m.allocations(),
m.threadsPerAllocation(),
Expand Down
Loading

0 comments on commit 43ec760

Please sign in to comment.