From f8a479a6facca5506ea934c6feae0e3581cb689d Mon Sep 17 00:00:00 2001
From: Max Lepikhin <46848373+maxlepikhin@users.noreply.github.com>
Date: Thu, 12 Dec 2024 10:43:09 -0700
Subject: [PATCH] Fix: ml/engine/utils/FileUtils casts long file length to int incorrectly (#3198)

* Use longs when splitting model zip file

Signed-off-by: Max Lepikhin

* add test

Signed-off-by: Max Lepikhin

* spotless

Signed-off-by: Max Lepikhin

* clean up test

Signed-off-by: Max Lepikhin

---------

Signed-off-by: Max Lepikhin
Signed-off-by: tkykenmt
---
 .../opensearch/ml/engine/utils/FileUtils.java |  8 +--
 .../ml/engine/utils/FileUtilsTest.java        | 63 +++++++++++++++++++
 2 files changed, 67 insertions(+), 4 deletions(-)
 create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java

diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java
index 677ca1aa9d..82551042e1 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java
@@ -45,16 +45,16 @@ public class FileUtils {
      * @throws IOException
      */
     public static List<String> splitFileIntoChunks(File file, Path outputPath, int chunkSize) throws IOException {
-        int fileSize = (int) file.length();
+        long fileSize = file.length();
         ArrayList<String> nameList = new ArrayList<>();
         try (InputStream inStream = new BufferedInputStream(new FileInputStream(file))) {
             int numberOfChunk = 0;
-            int totalBytesRead = 0;
+            long totalBytesRead = 0;
             while (totalBytesRead < fileSize) {
                 String partName = numberOfChunk + "";
-                int bytesRemaining = fileSize - totalBytesRead;
+                long bytesRemaining = fileSize - totalBytesRead;
                 if (bytesRemaining < chunkSize) {
-                    chunkSize = bytesRemaining;
+                    chunkSize = (int) bytesRemaining;
                 }
                 byte[] temporary = new byte[chunkSize];
                 int bytesRead = inStream.read(temporary, 0, chunkSize);
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java
new file mode 100644
index 0000000000..390286aadd
--- /dev/null
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.engine.utils;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class FileUtilsTest {
+    private TemporaryFolder tempDir;
+
+    @Before
+    public void setUp() throws Exception {
+        tempDir = new TemporaryFolder();
+        tempDir.create();
+    }
+
+    @After
+    public void tearDown() {
+        if (tempDir != null) {
+            tempDir.delete();
+        }
+    }
+
+    @Test
+    public void testSplitFileIntoChunks() throws Exception {
+        // Write file.
+        Random random = new Random();
+        File file = tempDir.newFile("large_file");
+        byte[] data = new byte[1017];
+        random.nextBytes(data);
+        Files.write(file.toPath(), data);
+
+        // Split file into chunks.
+        int chunkSize = 325;
+        List<String> chunkPaths = FileUtils.splitFileIntoChunks(file, tempDir.newFolder().toPath(), chunkSize);
+
+        // Verify.
+        int currentPosition = 0;
+        for (String chunkPath : chunkPaths) {
+            byte[] chunk = Files.readAllBytes(Path.of(chunkPath));
+            assertTrue("Chunk size", currentPosition + chunk.length <= data.length);
+            Assert.assertArrayEquals(Arrays.copyOfRange(data, currentPosition, currentPosition + chunk.length), chunk);
+            currentPosition += chunk.length;
+        }
+        assertEquals(currentPosition, data.length);
+    }
+}