diff --git a/bin/run.sh b/bin/run.sh
index 5edfa6c7c..79d6e7f24 100755
--- a/bin/run.sh
+++ b/bin/run.sh
@@ -1,3 +1,3 @@
 #!/bin/sh
 
-java -cp `ls target/*-fatjar.jar` -Xms512M -Xmx192G --add-modules jdk.incubator.vector $@
+java -cp `ls target/*-fatjar.jar` -Xms512M -Xmx512G --add-modules jdk.incubator.vector $@
diff --git a/src/main/java/io/anserini/collection/ParquetDenseVectorCollection.java b/src/main/java/io/anserini/collection/ParquetDenseVectorCollection.java
index 462d22c8f..b3192bf3a 100644
--- a/src/main/java/io/anserini/collection/ParquetDenseVectorCollection.java
+++ b/src/main/java/io/anserini/collection/ParquetDenseVectorCollection.java
@@ -137,17 +137,18 @@ private void initializeParquetReader(java.nio.file.Path path) throws IOException
       // Read each record from the Parquet file
       while ((record = reader.read()) != null) {
         // Extract the docid (String) from the record
-        String docid = record.getString("docid", 0);
+        String docid = record.getString("doc_id", 0);
         ids.add(docid);
 
         // Extract the vector (double[]) from the record
-        Group vectorGroup = record.getGroup("vector", 0); // Access the 'vector' field
+        Group vectorGroup = record.getGroup("embedding", 0); // Access the 'embedding' field
         int vectorSize = vectorGroup.getFieldRepetitionCount(0); // Get the number of elements in the vector
         double[] vector = new double[vectorSize];
         for (int i = 0; i < vectorSize; i++) {
           Group listGroup = vectorGroup.getGroup(0, i); // Access the 'list' group
-          vector[i] = listGroup.getDouble("element", 0); // Get the double value from the 'element' field
+          vector[i] = listGroup.getFloat("element", 0); // Get the float value from the 'element' field (widened to double)
         }
+        vector = normalizeVector(vector); // Rescale to unit L2 norm before storing
         vectors.add(vector);
       }
 
@@ -155,6 +156,39 @@ private void initializeParquetReader(java.nio.file.Path path) throws IOException
       currentIndex = 0;
     }
 
+    /**
+     * Computes the L2 norm (Euclidean norm) of a vector.
+     * @param vector the vector to compute the norm of
+     * @return the L2 norm of the vector
+     */
+    private static double computeL2Norm(double[] vector) {
+      double sumOfSquares = 0.0;
+      for (double v : vector) {
+        sumOfSquares += v * v;
+      }
+      return Math.sqrt(sumOfSquares);
+    }
+
+    /**
+     * Normalizes a vector to have a norm of 1.
+     * @param vector the vector to normalize
+     * @return a new vector that is the normalized version of the input vector
+     * @throws IllegalArgumentException if the L2 norm of the vector is zero
+     */
+    private static double[] normalizeVector(double[] vector) {
+      double norm = computeL2Norm(vector);
+      if (norm == 0) {
+        throw new IllegalArgumentException("Zero vector cannot be normalized.");
+      }
+
+      double[] normalizedVector = new double[vector.length];
+      for (int i = 0; i < vector.length; i++) {
+        normalizedVector[i] = vector[i] / norm;
+      }
+
+      return normalizedVector;
+    }
+
     /**
      * Reads the next document in the segment.
      *