Skip to content

Commit

Permalink
fix: Fixed the bug of sequence error during merging files
Browse files Browse the repository at this point in the history
  • Loading branch information
jlibx committed Sep 10, 2024
1 parent 55d28e0 commit fab8488
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,14 @@
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -1687,7 +1690,7 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener<File> liste
Semaphore semaphore = new Semaphore(1);
AtomicBoolean stopNow = new AtomicBoolean(false);
String modelZip = mlEngine.getDeployModelZipPath(modelId, modelName);
ConcurrentLinkedDeque<File> chunkFiles = new ConcurrentLinkedDeque();
ConcurrentHashMap<Integer, File> chunkFiles = new ConcurrentHashMap<>();
AtomicInteger retrievedChunks = new AtomicInteger(0);
for (int i = 0; i < totalChunks; i++) {
semaphore.tryAcquire(10, TimeUnit.SECONDS);
Expand All @@ -1699,11 +1702,15 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener<File> liste
this.getModel(modelChunkId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(model -> {
Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, currentChunk);
FileUtils.write(Base64.getDecoder().decode(model.getContent()), chunkPath.toString());
chunkFiles.add(new File(chunkPath.toUri()));
chunkFiles.put(currentChunk, new File(chunkPath.toUri()));
retrievedChunks.getAndIncrement();
if (retrievedChunks.get() == totalChunks) {
Queue<File> orderedChunkFiles = chunkFiles.entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.map(Map.Entry::getValue)
.collect(Collectors.toCollection(LinkedList::new));
File modelZipFile = new File(modelZip);
FileUtils.mergeFiles(chunkFiles, modelZipFile);
FileUtils.mergeFiles(orderedChunkFiles, modelZipFile);
listener.onResponse(modelZipFile);
}
semaphore.release();
Expand Down

0 comments on commit fab8488

Please sign in to comment.