diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java index 677ca1aa9d..82551042e1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java @@ -45,16 +45,16 @@ public class FileUtils { * @throws IOException */ public static List splitFileIntoChunks(File file, Path outputPath, int chunkSize) throws IOException { - int fileSize = (int) file.length(); + long fileSize = file.length(); ArrayList nameList = new ArrayList<>(); try (InputStream inStream = new BufferedInputStream(new FileInputStream(file))) { int numberOfChunk = 0; - int totalBytesRead = 0; + long totalBytesRead = 0; while (totalBytesRead < fileSize) { String partName = numberOfChunk + ""; - int bytesRemaining = fileSize - totalBytesRead; + long bytesRemaining = fileSize - totalBytesRead; if (bytesRemaining < chunkSize) { - chunkSize = bytesRemaining; + chunkSize = (int) bytesRemaining; } byte[] temporary = new byte[chunkSize]; int bytesRead = inStream.read(temporary, 0, chunkSize); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java new file mode 100644 index 0000000000..390286aadd --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/FileUtilsTest.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class FileUtilsTest { + private TemporaryFolder tempDir; + + @Before + public void setUp() throws Exception { + tempDir = new TemporaryFolder(); + tempDir.create(); + } + + @After + public void tearDown() { + if (tempDir != null) { + tempDir.delete(); + } + } + + @Test + public void testSplitFileIntoChunks() throws Exception { + // Write file. + Random random = new Random(); + File file = tempDir.newFile("large_file"); + byte[] data = new byte[1017]; + random.nextBytes(data); + Files.write(file.toPath(), data); + + // Split file into chunks. + int chunkSize = 325; + List chunkPaths = FileUtils.splitFileIntoChunks(file, tempDir.newFolder().toPath(), chunkSize); + + // Verify. + int currentPosition = 0; + for (String chunkPath : chunkPaths) { + byte[] chunk = Files.readAllBytes(Path.of(chunkPath)); + assertTrue("Chunk size", currentPosition + chunk.length <= data.length); + Assert.assertArrayEquals(Arrays.copyOfRange(data, currentPosition, currentPosition + chunk.length), chunk); + currentPosition += chunk.length; + } + assertEquals(currentPosition, data.length); + } +}