diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 0c3aa40c5c28..fb91b89a2a4a 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -137,6 +137,8 @@ New Features
 * GITHUB#13597: Align doc value skipper interval boundaries when an interval contains a constant
   value. (Ignacio Vera)
 
+* GITHUB#13604: Add Kmeans clustering on vectors (Mayya Sharipova, Jim Ferenczi, Tom Veasey)
+
 Improvements
 ---------------------
 
diff --git a/lucene/sandbox/src/java/module-info.java b/lucene/sandbox/src/java/module-info.java
index c51a25691ef2..3daace50cee4 100644
--- a/lucene/sandbox/src/java/module-info.java
+++ b/lucene/sandbox/src/java/module-info.java
@@ -22,6 +22,7 @@
 
   exports org.apache.lucene.payloads;
   exports org.apache.lucene.sandbox.codecs.idversion;
+  exports org.apache.lucene.sandbox.codecs.quantization;
   exports org.apache.lucene.sandbox.document;
   exports org.apache.lucene.sandbox.queries;
   exports org.apache.lucene.sandbox.search;
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java
new file mode 100644
index 000000000000..bb9d3ca63df5
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.sandbox.codecs.quantization;
+
+import static org.apache.lucene.sandbox.codecs.quantization.SampleReader.createSampleReader;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.NeighborQueue;
+import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+
+/** KMeans clustering algorithm for vectors */
+public class KMeans {
+  public static final int MAX_NUM_CENTROIDS = Short.MAX_VALUE; // 32767
+  public static final int DEFAULT_RESTARTS = 5;
+  public static final int DEFAULT_ITRS = 10;
+  public static final int DEFAULT_SAMPLE_SIZE = 100_000;
+
+  private final RandomAccessVectorValues.Floats vectors;
+  private final int numVectors;
+  private final int numCentroids;
+  private final Random random;
+  private final KmeansInitializationMethod initializationMethod;
+  private final int restarts;
+  private final int iters;
+
+  /**
+   * Cluster vectors into a given number of clusters
+   *
+   * @param vectors float vectors
+   * @param similarityFunction vector similarity function. For COSINE similarity, vectors must be
+   *     normalized.
+   * @param numClusters number of clusters to cluster vectors into
+   * @return results of clustering: produced centroids and for each vector its centroid; {@code
+   *     null} if {@code vectors} is empty
+   * @throws IOException if there is an error accessing vectors
+   */
+  public static Results cluster(
+      RandomAccessVectorValues.Floats vectors,
+      VectorSimilarityFunction similarityFunction,
+      int numClusters)
+      throws IOException {
+    return cluster(
+        vectors,
+        numClusters,
+        true,
+        42L,
+        KmeansInitializationMethod.PLUS_PLUS,
+        similarityFunction == VectorSimilarityFunction.COSINE,
+        DEFAULT_RESTARTS,
+        DEFAULT_ITRS,
+        DEFAULT_SAMPLE_SIZE);
+  }
+
+  /**
+   * Expert: Cluster vectors into a given number of clusters
+   *
+   * @param vectors float vectors
+   * @param numClusters number of clusters to cluster vectors into; it is reduced when there are
+   *     fewer than 100 vectors per cluster
+   * @param assignCentroidsToVectors if {@code true} assign centroids for all vectors. Centroids are
+   *     computed on a sample of vectors. If this parameter is {@code true}, in results also return
+   *     for all vectors what centroids they belong to.
+   * @param seed random seed
+   * @param initializationMethod Kmeans initialization method
+   * @param normalizeCenters for cosine distance, set to true, to use spherical k-means where
+   *     centers are normalized
+   * @param restarts how many times to run Kmeans algorithm
+   * @param iters how many iterations to do within a single run
+   * @param sampleSize sample size to select from all vectors on which to run Kmeans algorithm;
+   *     raised to at least {@code 100 * numClusters} and capped at the number of vectors
+   * @return results of clustering: produced centroids and if {@code assignCentroidsToVectors ==
+   *     true} also for each vector its centroid; {@code null} if {@code vectors} is empty
+   * @throws IllegalArgumentException if {@code numClusters} is not between 1 and {@link
+   *     #MAX_NUM_CENTROIDS}
+   * @throws IOException if there is an error accessing vectors
+   */
+  public static Results cluster(
+      RandomAccessVectorValues.Floats vectors,
+      int numClusters,
+      boolean assignCentroidsToVectors,
+      long seed,
+      KmeansInitializationMethod initializationMethod,
+      boolean normalizeCenters,
+      int restarts,
+      int iters,
+      int sampleSize)
+      throws IOException {
+    if (vectors.size() == 0) {
+      return null;
+    }
+    if (numClusters < 1 || numClusters > MAX_NUM_CENTROIDS) {
+      throw new IllegalArgumentException(
+          "[numClusters] must be between [1] and [" + MAX_NUM_CENTROIDS + "]");
+    }
+    // adjust sampleSize and numClusters
+    sampleSize = Math.max(sampleSize, 100 * numClusters);
+    if (sampleSize > vectors.size()) {
+      sampleSize = vectors.size();
+      // Decrease the number of clusters if needed
+      int maxNumClusters = Math.max(1, sampleSize / 100);
+      numClusters = Math.min(numClusters, maxNumClusters);
+    }
+
+    Random random = new Random(seed);
+    float[][] centroids;
+    if (numClusters == 1) {
+      // The single centroid starts as the zero vector; it only becomes the mean of the vectors
+      // in the assignment step below, i.e. when assignCentroidsToVectors is true.
+      centroids = new float[1][vectors.dimension()];
+    } else {
+      RandomAccessVectorValues.Floats sampleVectors =
+          vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed);
+      KMeans kmeans =
+          new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters);
+      centroids = kmeans.computeCentroids(normalizeCenters);
+    }
+
+    short[] vectorCentroids = null;
+    // Assign each vector to the nearest centroid and update the centres
+    if (assignCentroidsToVectors) {
+      vectorCentroids = new short[vectors.size()];
+      // Use kahan summation to get more precise results
+      KMeans.runKMeansStep(vectors, centroids, vectorCentroids, true, normalizeCenters);
+    }
+    return new Results(centroids, vectorCentroids);
+  }
+
+  private KMeans(
+      RandomAccessVectorValues.Floats vectors,
+      int numCentroids,
+      Random random,
+      KmeansInitializationMethod initializationMethod,
+      int restarts,
+      int iters) {
+    this.vectors = vectors;
+    this.numVectors = vectors.size();
+    this.numCentroids = numCentroids;
+    this.random = random;
+    this.initializationMethod = initializationMethod;
+    this.restarts = restarts;
+    this.iters = iters;
+  }
+
+  /**
+   * Run Kmeans {@code restarts} times, each time from fresh initial centroids and for at most
+   * {@code iters} steps, keeping the centroids of the run whose last step reported the smallest
+   * sum of squared distances.
+   *
+   * @param normalizeCenters if centroids should be normalized after every step
+   * @return centroids of the best run
+   * @throws IOException if there is an error accessing vectors
+   */
+  private float[][] computeCentroids(boolean normalizeCenters) throws IOException {
+    short[] vectorCentroids = new short[numVectors];
+    double minSquaredDist = Double.MAX_VALUE;
+    double squaredDist = 0;
+    float[][] bestCentroids = null;
+
+    for (int restart = 0; restart < restarts; restart++) {
+      float[][] centroids =
+          switch (initializationMethod) {
+            case FORGY -> initializeForgy();
+            case RESERVOIR_SAMPLING -> initializeReservoirSampling();
+            case PLUS_PLUS -> initializePlusPlus();
+          };
+      double prevSquaredDist = Double.MAX_VALUE;
+      for (int iter = 0; iter < iters; iter++) {
+        squaredDist = runKMeansStep(vectors, centroids, vectorCentroids, false, normalizeCenters);
+        // Check for convergence
+        if (prevSquaredDist <= (squaredDist + 1e-6)) {
+          break;
+        }
+        prevSquaredDist = squaredDist;
+      }
+      if (squaredDist < minSquaredDist) {
+        minSquaredDist = squaredDist;
+        bestCentroids = centroids;
+      }
+    }
+    return bestCentroids;
+  }
+
+  /**
+   * Initialize centroids using Forgy method: randomly select numCentroids vectors for initial
+   * centroids
+   */
+  private float[][] initializeForgy() throws IOException {
+    Set<Integer> selection = new HashSet<>();
+    while (selection.size() < numCentroids) {
+      selection.add(random.nextInt(numVectors));
+    }
+    float[][] initialCentroids = new float[numCentroids][];
+    int i = 0;
+    for (Integer selectedIdx : selection) {
+      float[] vector = vectors.vectorValue(selectedIdx);
+      initialCentroids[i++] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
+    }
+    return initialCentroids;
+  }
+
+  /** Initialize centroids using a reservoir sampling method */
+  private float[][] initializeReservoirSampling() throws IOException {
+    float[][] initialCentroids = new float[numCentroids][];
+    for (int index = 0; index < numVectors; index++) {
+      float[] vector = vectors.vectorValue(index);
+      if (index < numCentroids) {
+        initialCentroids[index] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
+      } else if (random.nextDouble() < numCentroids * (1.0 / index)) {
+        int c = random.nextInt(numCentroids);
+        initialCentroids[c] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
+      }
+    }
+    return initialCentroids;
+  }
+
+  /** Initialize centroids using Kmeans++ method */
+  private float[][] initializePlusPlus() throws IOException {
+    float[][] initialCentroids = new float[numCentroids][];
+    // Choose the first centroid uniformly at random
+    int firstIndex = random.nextInt(numVectors);
+    float[] value = vectors.vectorValue(firstIndex);
+    initialCentroids[0] = ArrayUtil.copyOfSubArray(value, 0, value.length);
+
+    // Store distances of each point to the nearest centroid
+    float[] minDistances = new float[numVectors];
+    Arrays.fill(minDistances, Float.MAX_VALUE);
+
+    // Step 2 and 3: Select remaining centroids
+    for (int i = 1; i < numCentroids; i++) {
+      // Update distances with the new centroid
+      double totalSum = 0;
+      for (int j = 0; j < numVectors; j++) {
+        // TODO: replace with RandomVectorScorer::score possible on quantized vectors
+        float dist = VectorUtil.squareDistance(vectors.vectorValue(j), initialCentroids[i - 1]);
+        if (dist < minDistances[j]) {
+          minDistances[j] = dist;
+        }
+        totalSum += minDistances[j];
+      }
+
+      // Randomly select next centroid
+      double r = totalSum * random.nextDouble();
+      double cumulativeSum = 0;
+      // Fall back to the first vector if every vector coincides with an already chosen centroid
+      // (all distances are 0), otherwise no index would be selected below
+      int nextCentroidIndex = 0;
+      for (int j = 0; j < numVectors; j++) {
+        cumulativeSum += minDistances[j];
+        if (cumulativeSum >= r && minDistances[j] > 0) {
+          nextCentroidIndex = j;
+          break;
+        }
+      }
+      // Update centroid
+      value = vectors.vectorValue(nextCentroidIndex);
+      initialCentroids[i] = ArrayUtil.copyOfSubArray(value, 0, value.length);
+    }
+    return initialCentroids;
+  }
+
+  /**
+   * Run kmeans step
+   *
+   * @param vectors float vectors
+   * @param centroids centroids, new calculated centroids are written here
+   * @param docCentroids for each document which centroid it belongs to, results will be written
+   *     here
+   * @param useKahanSummation for large datasets use Kahan summation to calculate centroids, since
+   *     we can easily reach the limits of float precision
+   * @param normalizeCentroids if centroids should be normalized; used for cosine similarity only
+   * @return sum of squared distances of the vectors to their nearest centroid, measured against
+   *     the centroids passed in; {@code 0} if there is a single centroid
+   * @throws IOException if there is an error accessing vector values
+   */
+  private static double runKMeansStep(
+      RandomAccessVectorValues.Floats vectors,
+      float[][] centroids,
+      short[] docCentroids,
+      boolean useKahanSummation,
+      boolean normalizeCentroids)
+      throws IOException {
+    short numCentroids = (short) centroids.length;
+
+    float[][] newCentroids = new float[numCentroids][centroids[0].length];
+    int[] newCentroidSize = new int[numCentroids];
+    float[][] compensations = null;
+    if (useKahanSummation) {
+      compensations = new float[numCentroids][centroids[0].length];
+    }
+
+    double sumSquaredDist = 0;
+    for (int docID = 0; docID < vectors.size(); docID++) {
+      float[] vector = vectors.vectorValue(docID);
+      short bestCentroid = 0;
+      if (numCentroids > 1) {
+        float minSquaredDist = Float.MAX_VALUE;
+        for (short c = 0; c < numCentroids; c++) {
+          // TODO: replace with RandomVectorScorer::score possible on quantized vectors
+          float squareDist = VectorUtil.squareDistance(centroids[c], vector);
+          if (squareDist < minSquaredDist) {
+            bestCentroid = c;
+            minSquaredDist = squareDist;
+          }
+        }
+        sumSquaredDist += minSquaredDist;
+      }
+
+      newCentroidSize[bestCentroid] += 1;
+      for (int dim = 0; dim < vector.length; dim++) {
+        if (useKahanSummation) {
+          float y = vector[dim] - compensations[bestCentroid][dim];
+          float t = newCentroids[bestCentroid][dim] + y;
+          compensations[bestCentroid][dim] = (t - newCentroids[bestCentroid][dim]) - y;
+          newCentroids[bestCentroid][dim] = t;
+        } else {
+          newCentroids[bestCentroid][dim] += vector[dim];
+        }
+      }
+      docCentroids[docID] = bestCentroid;
+    }
+
+    List<Integer> unassignedCentroids = new ArrayList<>();
+    for (int c = 0; c < numCentroids; c++) {
+      if (newCentroidSize[c] > 0) {
+        for (int dim = 0; dim < newCentroids[c].length; dim++) {
+          centroids[c][dim] = newCentroids[c][dim] / newCentroidSize[c];
+        }
+      } else {
+        unassignedCentroids.add(c);
+      }
+    }
+    if (unassignedCentroids.size() > 0) {
+      assignCentroids(vectors, centroids, unassignedCentroids);
+    }
+    if (normalizeCentroids) {
+      for (int c = 0; c < centroids.length; c++) {
+        VectorUtil.l2normalize(centroids[c], false);
+      }
+    }
+    return sumSquaredDist;
+  }
+
+  /**
+   * For centroids that did not get any points, assign outlying points to them. Points are chosen
+   * by descending distance to the current centroid set, where the distance of a point to the set
+   * is the distance to its nearest assigned centroid.
+   *
+   * @param vectors float vectors
+   * @param centroids centroids; unassigned ones are overwritten with copies of chosen vectors
+   * @param unassignedCentroidsIdxs indexes of the centroids that did not get any points
+   * @throws IOException if there is an error accessing vector values
+   */
+  static void assignCentroids(
+      RandomAccessVectorValues.Floats vectors,
+      float[][] centroids,
+      List<Integer> unassignedCentroidsIdxs)
+      throws IOException {
+    int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()];
+    int assignedIndex = 0;
+    for (int i = 0; i < centroids.length; i++) {
+      if (unassignedCentroidsIdxs.contains(i) == false) {
+        assignedCentroidsIdxs[assignedIndex++] = i;
+      }
+    }
+    NeighborQueue queue = new NeighborQueue(unassignedCentroidsIdxs.size(), false);
+    for (int i = 0; i < vectors.size(); i++) {
+      float[] vector = vectors.vectorValue(i);
+      // Insert each vector only once, scored by the distance to its nearest assigned centroid,
+      // so that the same vector cannot be picked for more than one unassigned centroid
+      float minSquareDist = Float.MAX_VALUE;
+      for (int j = 0; j < assignedCentroidsIdxs.length; j++) {
+        float squareDist = VectorUtil.squareDistance(centroids[assignedCentroidsIdxs[j]], vector);
+        minSquareDist = Math.min(minSquareDist, squareDist);
+      }
+      queue.insertWithOverflow(i, minSquareDist);
+    }
+    for (int i = 0; i < unassignedCentroidsIdxs.size(); i++) {
+      float[] vector = vectors.vectorValue(queue.topNode());
+      int unassignedCentroidIdx = unassignedCentroidsIdxs.get(i);
+      centroids[unassignedCentroidIdx] = ArrayUtil.copyArray(vector);
+      queue.pop();
+    }
+  }
+
+  /** Kmeans initialization methods */
+  public enum KmeansInitializationMethod {
+    FORGY,
+    RESERVOIR_SAMPLING,
+    PLUS_PLUS
+  }
+
+  /**
+   * Results of KMeans clustering
+   *
+   * @param centroids the produced centroids
+   * @param vectorCentroids for each vector which centroid it belongs to, we use short type, as we
+   *     expect less than {@code MAX_NUM_CENTROIDS} which is equal to 32767 centroids. Can be {@code
+   *     null} if they were not computed.
+   */
+  public record Results(float[][] centroids, short[] vectorCentroids) {}
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java
new file mode 100644
index 000000000000..9a718c811017
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.sandbox.codecs.quantization;
+
+import java.io.IOException;
+import java.util.Random;
+import java.util.function.IntUnaryOperator;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+
+/** A reader of vector values that samples a subset of the vectors. */
+public class SampleReader implements RandomAccessVectorValues.Floats {
+  private final RandomAccessVectorValues.Floats origin;
+  private final int sampleSize;
+  private final IntUnaryOperator sampleFunction;
+
+  SampleReader(
+      RandomAccessVectorValues.Floats origin, int sampleSize, IntUnaryOperator sampleFunction) {
+    this.origin = origin;
+    this.sampleSize = sampleSize;
+    this.sampleFunction = sampleFunction;
+  }
+
+  @Override
+  public int size() {
+    return sampleSize;
+  }
+
+  @Override
+  public int dimension() {
+    return origin.dimension();
+  }
+
+  @Override
+  public Floats copy() throws IOException {
+    throw new IllegalStateException("Not supported");
+  }
+
+  @Override
+  public IndexInput getSlice() {
+    return origin.getSlice();
+  }
+
+  @Override
+  public float[] vectorValue(int targetOrd) throws IOException {
+    return origin.vectorValue(sampleFunction.applyAsInt(targetOrd));
+  }
+
+  @Override
+  public int getVectorByteLength() {
+    return origin.getVectorByteLength();
+  }
+
+  @Override
+  public int ordToDoc(int ord) {
+    throw new IllegalStateException("Not supported");
+  }
+
+  @Override
+  public Bits getAcceptOrds(Bits acceptDocs) {
+    throw new IllegalStateException("Not supported");
+  }
+
+  /** Creates a reader over up to {@code k} vectors sampled from {@code origin}. */
+  public static SampleReader createSampleReader(
+      RandomAccessVectorValues.Floats origin, int k, long seed) {
+    int[] samples = reservoirSample(origin.size(), k, seed);
+    return new SampleReader(origin, samples.length, i -> samples[i]);
+  }
+
+  /**
+   * Sample k elements from n elements according to reservoir sampling algorithm.
+   *
+   * @param n number of elements
+   * @param k number of samples
+   * @param seed random seed
+   * @return array of {@code min(k, n)} sampled indexes, each in the range {@code [0, n)}
+   */
+  public static int[] reservoirSample(int n, int k, long seed) {
+    Random rnd = new Random(seed);
+    // Never keep more samples than there are elements, otherwise ordinals >= n would be returned
+    int[] reservoir = new int[Math.min(k, n)];
+    for (int i = 0; i < reservoir.length; i++) {
+      reservoir[i] = i;
+    }
+    for (int i = k; i < n; i++) {
+      int j = rnd.nextInt(i + 1);
+      if (j < k) {
+        reservoir[j] = i;
+      }
+    }
+    return reservoir;
+  }
+
+  /**
+   * Sample k elements from the origin array using reservoir sampling algorithm.
+   *
+   * @param origin original array
+   * @param k number of samples
+   * @param seed random seed
+   * @return array of k samples, or {@code origin} itself if it has no more than k elements
+   */
+  public static int[] reservoirSampleFromArray(int[] origin, int k, long seed) {
+    if (k >= origin.length) {
+      return origin;
+    }
+    Random rnd = new Random(seed);
+    int[] reservoir = new int[k];
+    System.arraycopy(origin, 0, reservoir, 0, k);
+    for (int i = k; i < origin.length; i++) {
+      int j = rnd.nextInt(i + 1);
+      if (j < k) {
+        reservoir[j] = origin[i];
+      }
+    }
+    return reservoir;
+  }
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/package-info.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/package-info.java
new file mode 100644
index 000000000000..477fa361de3e
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** This package implements the KMeans algorithm for clustering vectors */
+package org.apache.lucene.sandbox.codecs.quantization;
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java
new file mode 100644
index 000000000000..61c0e58c91ef
--- /dev/null
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.sandbox.codecs.quantization;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+
+public class TestKMeans extends LuceneTestCase {
+
+  public void testKMeansAPI() throws IOException {
+    int nClusters = random().nextInt(1, 10);
+    int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
+    int dims = random().nextInt(2, 20);
+    int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
+    VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
+    RandomAccessVectorValues.Floats vectors = generateData(nVectors, dims, nClusters);
+
+    // default case: every vector gets a centroid assigned
+    {
+      KMeans.Results results = KMeans.cluster(vectors, similarityFunction, nClusters);
+      assertEquals(nClusters, results.centroids().length);
+      assertEquals(nVectors, results.vectorCentroids().length);
+    }
+    // expert case: randomized initialization method, restarts, iterations and sample size
+    {
+      boolean assignCentroidsToVectors = random().nextBoolean();
+      int randIdx2 = random().nextInt(KMeans.KmeansInitializationMethod.values().length);
+      KMeans.KmeansInitializationMethod initializationMethod =
+          KMeans.KmeansInitializationMethod.values()[randIdx2];
+      int restarts = random().nextInt(1, 6);
+      int iters = random().nextInt(1, 10);
+      int sampleSize = random().nextInt(10, nVectors * 2);
+
+      KMeans.Results results =
+          KMeans.cluster(
+              vectors,
+              nClusters,
+              assignCentroidsToVectors,
+              random().nextLong(),
+              initializationMethod,
+              similarityFunction == VectorSimilarityFunction.COSINE,
+              restarts,
+              iters,
+              sampleSize);
+      assertEquals(nClusters, results.centroids().length);
+      if (assignCentroidsToVectors) {
+        assertEquals(nVectors, results.vectorCentroids().length);
+      } else {
+        assertNull(results.vectorCentroids());
+      }
+    }
+  }
+
+  public void testKMeansSpecialCases() throws IOException {
+    {
+      // nClusters > nVectors
+      int nClusters = 20;
+      int nVectors = 10;
+      RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
+      KMeans.Results results =
+          KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
+      // assert that we get 1 centroid, as nClusters will be adjusted
+      assertEquals(1, results.centroids().length);
+      assertEquals(nVectors, results.vectorCentroids().length);
+    }
+    {
+      // small sample size
+      int sampleSize = 2;
+      int nClusters = 2;
+      int nVectors = 300;
+      RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
+      KMeans.KmeansInitializationMethod initializationMethod =
+          KMeans.KmeansInitializationMethod.PLUS_PLUS;
+      KMeans.Results results =
+          KMeans.cluster(
+              vectors,
+              nClusters,
+              true,
+              random().nextLong(),
+              initializationMethod,
+              false,
+              1,
+              2,
+              sampleSize);
+      assertEquals(nClusters, results.centroids().length);
+      assertEquals(nVectors, results.vectorCentroids().length);
+    }
+    {
+      // test unassigned centroids
+      int nClusters = 4;
+      int nVectors = 400;
+      RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
+      KMeans.Results results =
+          KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
+      float[][] centroids = results.centroids();
+      List<Integer> unassignedIdxs = List.of(0, 3);
+      KMeans.assignCentroids(vectors, centroids, unassignedIdxs);
+      assertEquals(nClusters, centroids.length);
+    }
+  }
+
+  private static RandomAccessVectorValues.Floats generateData(
+      int nSamples, int nDims, int nClusters) {
+    List<float[]> vectors = new ArrayList<>(nSamples);
+    float[][] centroids = new float[nClusters][nDims];
+    // Generate random centroids
+    for (int i = 0; i < nClusters; i++) {
+      for (int j = 0; j < nDims; j++) {
+        centroids[i][j] = random().nextFloat() * 100;
+      }
+    }
+    // Generate data points around centroids
+    for (int i = 0; i < nSamples; i++) {
+      int cluster = random().nextInt(nClusters);
+      float[] vector = new float[nDims];
+      for (int j = 0; j < nDims; j++) {
+        vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
+      }
+      vectors.add(vector);
+    }
+    return RandomAccessVectorValues.fromFloats(vectors, nDims);
+  }
+}