From 3a2ee7f07ebc91b86cb49d8f9be4c8b40cf4effb Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Sat, 4 Nov 2023 19:25:58 -0400 Subject: [PATCH] Speed up vectorutil float scalar methods, unroll properly, use fma where possible (#12737) Co-authored-by: Uwe Schindler --- .../DefaultVectorUtilSupport.java | 188 +++++++++--------- .../org/apache/lucene/util/Constants.java | 77 ++++++- .../org/apache/lucene/util/VectorUtil.java | 26 ++- .../PanamaVectorUtilSupport.java | 24 ++- .../PanamaVectorizationProvider.java | 2 +- 5 files changed, 208 insertions(+), 109 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index de546c9269a5..750e0ee136ae 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -17,72 +17,46 @@ package org.apache.lucene.internal.vectorization; +import org.apache.lucene.util.Constants; +import org.apache.lucene.util.SuppressForbidden; + final class DefaultVectorUtilSupport implements VectorUtilSupport { DefaultVectorUtilSupport() {} + // the way FMA should work! if available use it, otherwise fall back to mul/add + @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained") + private static float fma(float a, float b, float c) { + if (Constants.HAS_FAST_SCALAR_FMA) { + return Math.fma(a, b, c); + } else { + return a * b + c; + } + } + @Override public float dotProduct(float[] a, float[] b) { float res = 0f; - /* - * If length of vector is larger than 8, we use unrolled dot product to accelerate the - * calculation. - */ - int i; - for (i = 0; i < a.length % 8; i++) { - res += b[i] * a[i]; - } - if (a.length < 8) { - return res; - } - for (; i + 31 < a.length; i += 32) { - res += - b[i + 0] * a[i + 0] - + b[i + 1] * a[i + 1] - + b[i + 2] * a[i + 2] - + b[i + 3] * a[i + 3] - + b[i + 4] * a[i + 4] - + b[i + 5] * a[i + 5] - + b[i + 6] * a[i + 6] - + b[i + 7] * a[i + 7]; - res += - b[i + 8] * a[i + 8] - + b[i + 9] * a[i + 9] - + b[i + 10] * a[i + 10] - + b[i + 11] * a[i + 11] - + b[i + 12] * a[i + 12] - + b[i + 13] * a[i + 13] - + b[i + 14] * a[i + 14] - + b[i + 15] * a[i + 15]; - res += - b[i + 16] * a[i + 16] - + b[i + 17] * a[i + 17] - + b[i + 18] * a[i + 18] - + b[i + 19] * a[i + 19] - + b[i + 20] * a[i + 20] - + b[i + 21] * a[i + 21] - + b[i + 22] * a[i + 22] - + b[i + 23] * a[i + 23]; - res += - b[i + 24] * a[i + 24] - + b[i + 25] * a[i + 25] - + b[i + 26] * a[i + 26] - + b[i + 27] * a[i + 27] - + b[i + 28] * a[i + 28] - + b[i + 29] * a[i + 29] - + b[i + 30] * a[i + 30] - + b[i + 31] * a[i + 31]; + int i = 0; + + // if the array is big, unroll it + if (a.length > 32) { + float acc1 = 0; + float acc2 = 0; + float acc3 = 0; + float acc4 = 0; + int upperBound = a.length & ~(4 - 1); + for (; i < upperBound; i += 4) { + acc1 = fma(a[i], b[i], acc1); + acc2 = fma(a[i + 1], b[i + 1], acc2); + acc3 = fma(a[i + 2], b[i + 2], acc3); + acc4 = fma(a[i + 3], b[i + 3], acc4); + } + res += acc1 + acc2 + acc3 + acc4; } - for (; i + 7 < a.length; i += 8) { - res += - b[i + 0] * a[i + 0] - + b[i + 1] * a[i + 1] - + b[i + 2] * a[i + 2] - + b[i + 3] * a[i + 3] - + b[i + 4] * a[i + 4] - + b[i + 5] * a[i + 5] - + b[i + 6] * a[i + 6] - + b[i + 7] * a[i + 7]; + + for (; i < a.length; i++) { + res = fma(a[i], b[i], res); } return res; } @@ -92,50 +66,80 @@ public float cosine(float[] a, float[] b) { float sum = 0.0f; float norm1 = 0.0f; float norm2 = 0.0f; - int dim = a.length; + int i = 0; - for (int i = 0; i < dim; i++) { - float elem1 = a[i]; - float elem2 = b[i]; - sum += elem1 * elem2; - norm1 += elem1 * elem1; - norm2 += elem2 * elem2; + // if the array is big, unroll it + if (a.length > 32) { + float sum1 = 0; + float sum2 = 0; + float norm1_1 = 0; + float norm1_2 = 0; + float norm2_1 = 0; + float norm2_2 = 0; + + int upperBound = a.length & ~(2 - 1); + for (; i < upperBound; i += 2) { + // one + sum1 = fma(a[i], b[i], sum1); + norm1_1 = fma(a[i], a[i], norm1_1); + norm2_1 = fma(b[i], b[i], norm2_1); + + // two + sum2 = fma(a[i + 1], b[i + 1], sum2); + norm1_2 = fma(a[i + 1], a[i + 1], norm1_2); + norm2_2 = fma(b[i + 1], b[i + 1], norm2_2); + } + sum += sum1 + sum2; + norm1 += norm1_1 + norm1_2; + norm2 += norm2_1 + norm2_2; + } + + for (; i < a.length; i++) { + sum = fma(a[i], b[i], sum); + norm1 = fma(a[i], a[i], norm1); + norm2 = fma(b[i], b[i], norm2); } return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override public float squareDistance(float[] a, float[] b) { - float squareSum = 0.0f; - int dim = a.length; - int i; - for (i = 0; i + 8 <= dim; i += 8) { - squareSum += squareDistanceUnrolled(a, b, i); + float res = 0; + int i = 0; + + // if the array is big, unroll it + if (a.length > 32) { + float acc1 = 0; + float acc2 = 0; + float acc3 = 0; + float acc4 = 0; + + int upperBound = a.length & ~(4 - 1); + for (; i < upperBound; i += 4) { + // one + float diff1 = a[i] - b[i]; + acc1 = fma(diff1, diff1, acc1); + + // two + float diff2 = a[i + 1] - b[i + 1]; + acc2 = fma(diff2, diff2, acc2); + + // three + float diff3 = a[i + 2] - b[i + 2]; + acc3 = fma(diff3, diff3, acc3); + + // four + float diff4 = a[i + 3] - b[i + 3]; + acc4 = fma(diff4, diff4, acc4); + } + res += acc1 + acc2 + acc3 + acc4; } - for (; i < dim; i++) { + + for (; i < a.length; i++) { float diff = a[i] - b[i]; - squareSum += diff * diff; + res = fma(diff, diff, res); } - return squareSum; - } - - private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { - float diff0 = v1[index + 0] - v2[index + 0]; - float diff1 = v1[index + 1] - v2[index + 1]; - float diff2 = v1[index + 2] - v2[index + 2]; - float diff3 = v1[index + 3] - v2[index + 3]; - float diff4 = v1[index + 4] - v2[index + 4]; - float diff5 = v1[index + 5] - v2[index + 5]; - float diff6 = v1[index + 6] - v2[index + 6]; - float diff7 = v1[index + 7] - v2[index + 7]; - return diff0 * diff0 - + diff1 * diff1 - + diff2 * diff2 - + diff3 * diff3 - + diff4 * diff4 - + diff5 * diff5 - + diff6 * diff6 - + diff7 * diff7; + return res; } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/Constants.java b/lucene/core/src/java/org/apache/lucene/util/Constants.java index b10b00b8904e..979f98fcbce3 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Constants.java +++ b/lucene/core/src/java/org/apache/lucene/util/Constants.java @@ -18,7 +18,6 @@ import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.Objects; import java.util.logging.Logger; /** Some useful constants. */ @@ -91,12 +90,6 @@ private Constants() {} // can't construct /** True iff running on a 64bit JVM */ public static final boolean JRE_IS_64BIT = is64Bit(); - /** true iff we know fast FMA is supported, to deliver less error */ - public static final boolean HAS_FAST_FMA = - (IS_CLIENT_VM == false) - && Objects.equals(OS_ARCH, "amd64") - && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false); - private static boolean is64Bit() { final String datamodel = getSysProp("sun.arch.data.model"); if (datamodel != null) { @@ -106,6 +99,76 @@ private static boolean is64Bit() { } } + /** true if FMA likely means a cpu instruction and not BigDecimal logic */ + private static final boolean HAS_FMA = + (IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false); + + /** maximum supported vectorsize */ + private static final int MAX_VECTOR_SIZE = + HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0); + + /** true for an AMD cpu with SSE4a instructions */ + private static final boolean HAS_SSE4A = + HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false); + + /** true iff we know VFMA has faster throughput than separate vmul/vadd */ + public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA(); + + /** true iff we know FMA has faster throughput than separate mul/add */ + public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA(); + + private static boolean hasFastVectorFMA() { + if (HAS_FMA) { + String value = getSysProp("lucene.useVectorFMA", "auto"); + if ("auto".equals(value)) { + // newer Neoverse cores have their act together + // the problem is just apple silicon (this is a practical heuristic) + if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) { + return true; + } + // zen cores or newer, its a wash, turn it on as it doesn't hurt + // starts to yield gains for vectors only at zen4+ + if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) { + return true; + } + // intel has their act together + if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) { + return true; + } + } else { + return Boolean.parseBoolean(value); + } + } + // everyone else is slow, until proven otherwise by benchmarks + return false; + } + + private static boolean hasFastScalarFMA() { + if (HAS_FMA) { + String value = getSysProp("lucene.useScalarFMA", "auto"); + if ("auto".equals(value)) { + // newer Neoverse cores have their act together + // the problem is just apple silicon (this is a practical heuristic) + if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) { + return true; + } + // latency becomes 4 for the Zen3 (0x19h), but still a wash + // until the Zen4 anyway, and big drop on previous zens: + if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) { + return true; + } + // intel has their act together + if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) { + return true; + } + } else { + return Boolean.parseBoolean(value); + } + } + // everyone else is slow, until proven otherwise by benchmarks + return false; + } + private static String getSysProp(String property) { try { return doPrivileged(() -> System.getProperty(property)); diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index ef52e605dbc9..4a792c182441 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -20,7 +20,31 @@ import org.apache.lucene.internal.vectorization.VectorUtilSupport; import org.apache.lucene.internal.vectorization.VectorizationProvider; -/** Utilities for computations with numeric arrays */ +/** + * Utilities for computations with numeric arrays, especially algebraic operations like vector dot + * products. This class uses SIMD vectorization if the corresponding Java module is available and + * enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's + * command line. + * + *

It will use CPU's FMA + * instructions if it is known to perform faster than separate multiply+add. This requires at + * least Hotspot C2 enabled, which is the default for OpenJDK based JVMs. + * + *

To explicitly disable or enable FMA usage, pass the following system properties: + * + *

+ * + *

The default is {@code auto}, which enables this for known CPU types and JVM settings. If + * Hotspot C2 is disabled, FMA and vectorization are not used. + * + *

Vectorization and FMA is only supported for Hotspot-based JVMs; it won't work on OpenJ9-based + * JVMs unless they provide {@link com.sun.management.HotSpotDiagnosticMXBean}. Please also make + * sure that you have the {@code jdk.management} module enabled in modularized applications. + */ public final class VectorUtil { private static final VectorUtilSupport IMPL = diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index d4e8a50ef8f4..ccd838cb8dd0 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -29,6 +29,7 @@ import jdk.incubator.vector.VectorShape; import jdk.incubator.vector.VectorSpecies; import org.apache.lucene.util.Constants; +import org.apache.lucene.util.SuppressForbidden; /** * VectorUtil methods implemented with Panama incubating vector API. @@ -79,13 +80,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // the way FMA should work! if available use it, otherwise fall back to mul/add private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) { - if (Constants.HAS_FAST_FMA) { + if (Constants.HAS_FAST_VECTOR_FMA) { return a.fma(b, c); } else { return a.mul(b).add(c); } } + @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained") + private static float fma(float a, float b, float c) { + if (Constants.HAS_FAST_SCALAR_FMA) { + return Math.fma(a, b, c); + } else { + return a * b + c; + } + } + @Override public float dotProduct(float[] a, float[] b) { int i = 0; @@ -99,7 +109,7 @@ public float dotProduct(float[] a, float[] b) { // scalar tail for (; i < a.length; i++) { - res += b[i] * a[i]; + res = fma(a[i], b[i], res); } return res; } @@ -165,11 +175,9 @@ public float cosine(float[] a, float[] b) { // scalar tail for (; i < a.length; i++) { - float elem1 = a[i]; - float elem2 = b[i]; - sum += elem1 * elem2; - norm1 += elem1 * elem1; - norm2 += elem2 * elem2; + sum = fma(a[i], b[i], sum); + norm1 = fma(a[i], a[i], norm1); + norm2 = fma(b[i], b[i], norm2); } return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @@ -230,7 +238,7 @@ public float squareDistance(float[] a, float[] b) { // scalar tail for (; i < a.length; i++) { float diff = a[i] - b[i]; - res += diff * diff; + res = fma(diff, diff, res); } return res; } diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index ffd18df1a270..11901d74f424 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -63,7 +63,7 @@ private static T doPrivileged(PrivilegedAction action) { Locale.ENGLISH, "Java vector incubator API enabled; uses preferredBitSize=%d%s%s", PanamaVectorUtilSupport.VECTOR_BITSIZE, - Constants.HAS_FAST_FMA ? "; FMA enabled" : "", + Constants.HAS_FAST_VECTOR_FMA ? "; FMA enabled" : "", PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS ? "" : "; floating-point vectors only"));