From fab848844969ba173b592a4ef20bff7b53e018ae Mon Sep 17 00:00:00 2001 From: jonathan Date: Tue, 10 Sep 2024 09:25:33 +0800 Subject: [PATCH] fix: Fixed the bug of sequence error during merging files --- .../org/opensearch/ml/model/MLModelManager.java | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 070d0def34..17b43ee60d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -58,11 +58,14 @@ import java.util.Base64; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Queue; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -1687,7 +1690,7 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste Semaphore semaphore = new Semaphore(1); AtomicBoolean stopNow = new AtomicBoolean(false); String modelZip = mlEngine.getDeployModelZipPath(modelId, modelName); - ConcurrentLinkedDeque chunkFiles = new ConcurrentLinkedDeque(); + ConcurrentHashMap chunkFiles = new ConcurrentHashMap<>(); AtomicInteger retrievedChunks = new AtomicInteger(0); for (int i = 0; i < totalChunks; i++) { semaphore.tryAcquire(10, TimeUnit.SECONDS); @@ -1699,11 +1702,15 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste this.getModel(modelChunkId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(model -> { Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, currentChunk); FileUtils.write(Base64.getDecoder().decode(model.getContent()), chunkPath.toString()); - chunkFiles.add(new File(chunkPath.toUri())); + chunkFiles.put(currentChunk, new File(chunkPath.toUri())); retrievedChunks.getAndIncrement(); if (retrievedChunks.get() == totalChunks) { + Queue orderedChunkFiles = chunkFiles.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .map(Map.Entry::getValue) + .collect(Collectors.toCollection(LinkedList::new)); File modelZipFile = new File(modelZip); - FileUtils.mergeFiles(chunkFiles, modelZipFile); + FileUtils.mergeFiles(orderedChunkFiles, modelZipFile); listener.onResponse(modelZipFile); } semaphore.release();