Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up vectorutil float scalar methods, unroll properly, use fma where possible #12737

Merged
merged 34 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b192064
Speed up vectorutil scalar methods, unroll properly, use fma where po…
rmuir Oct 31, 2023
93fed5f
clean up vector tails too, for consistency
rmuir Oct 31, 2023
f2be84f
detect AMD and don't use FMA there which causes slowdowns
rmuir Oct 31, 2023
990d27a
Merge branch 'stabilize_benchmark' into float_scalar_fma_unroll
Nov 2, 2023
d2635bf
update ARM fma logic, use it for serious ARM to get more out of it
Nov 2, 2023
cd5f8c7
Replace assert with IAE in StringsToAutomaton#build if data is not so…
shubhamvishu Oct 30, 2023
611c708
CHANGES entry for GITHUB#12427
gsmiller Oct 30, 2023
ff8cebc
Do not close merge threadpool in Lucene99HnswVectorsWriter
zhaih Oct 31, 2023
acfaf98
Fix test after #12549.
jpountz Oct 31, 2023
371812e
Fix file handle leak in Lucene99ScalarQuantizedVectorsWriter. (#12739)
jpountz Oct 31, 2023
108ad91
Fix NullPointerException in Monitor.getQuery when query is not presen…
daviscook477 Oct 31, 2023
4dac8eb
Specialize the 2nd clause of conjunctions. (#12713)
jpountz Oct 31, 2023
4cac882
Fix test failure.
jpountz Nov 1, 2023
68606f8
Fix javac task inputs so that they include modular dependencies #1274…
dweiss Nov 2, 2023
6e46124
Clean up UnCompiledNode.inputCount (#12735)
dungba88 Nov 2, 2023
bc17c88
LUCENE-10144:fix resource leak due to Files.list (#354)
lujiefsi Nov 2, 2023
7c41cfe
LUCENE-10100: configuration items of the alg file are adapted to the …
xiaoshi2013 Nov 2, 2023
7803482
ReleaseWizard - Upgrade 'consolemenu' dependency to v0.7.1 (#11855)
janhoy Nov 2, 2023
21be692
Remove unnecessary sort in writeFieldUpdates (#12273)
luyuncheng Nov 2, 2023
f77c2da
unify exception thrown by regexp & check repetition range (#12277)
tang-hi Nov 2, 2023
b152b8f
Speed up sorting on unique string fields. (#11903)
jpountz Nov 2, 2023
f46eb5a
move CSVUtil to common from analyzer nori and kuromoji (#12390)
twosom Nov 2, 2023
dacb9cb
Fix comment on decode method in PForUtil (#12495)
vsop-479 Nov 2, 2023
58bed62
Merge branch 'main' into float_scalar_fma_unroll
rmuir Nov 2, 2023
755717f
Merge branch 'main' into float_scalar_fma_unroll
uschindler Nov 3, 2023
0cd28c5
Fix code after merge to use new APIs
uschindler Nov 3, 2023
22e8377
Use Objects for null safety
uschindler Nov 3, 2023
4a9d074
fix policeman merge bug
rmuir Nov 3, 2023
4086abf
clean up logic
rmuir Nov 3, 2023
555841c
update logic for newer zen cores with lower latency FMA
Nov 3, 2023
b702bdf
add sysprop override: for uwe's use only
rmuir Nov 3, 2023
b65fb7a
tighten up AMD logic to perfection
rmuir Nov 4, 2023
503eebd
Add documentation
uschindler Nov 4, 2023
6432330
greatly improve Graviton2, the only issue here is apple silicon...
rmuir Nov 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,72 +17,46 @@

package org.apache.lucene.internal.vectorization;

import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;

final class DefaultVectorUtilSupport implements VectorUtilSupport {

DefaultVectorUtilSupport() {}

// the way FMA should work! if available use it, otherwise fall back to mul/add
@SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
private static float fma(float a, float b, float c) {
if (Constants.HAS_FAST_SCALAR_FMA) {
return Math.fma(a, b, c);
} else {
return a * b + c;
}
}

@Override
public float dotProduct(float[] a, float[] b) {
float res = 0f;
/*
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
* calculation.
*/
int i;
for (i = 0; i < a.length % 8; i++) {
res += b[i] * a[i];
}
if (a.length < 8) {
return res;
}
for (; i + 31 < a.length; i += 32) {
res +=
b[i + 0] * a[i + 0]
+ b[i + 1] * a[i + 1]
+ b[i + 2] * a[i + 2]
+ b[i + 3] * a[i + 3]
+ b[i + 4] * a[i + 4]
+ b[i + 5] * a[i + 5]
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
res +=
b[i + 8] * a[i + 8]
+ b[i + 9] * a[i + 9]
+ b[i + 10] * a[i + 10]
+ b[i + 11] * a[i + 11]
+ b[i + 12] * a[i + 12]
+ b[i + 13] * a[i + 13]
+ b[i + 14] * a[i + 14]
+ b[i + 15] * a[i + 15];
res +=
b[i + 16] * a[i + 16]
+ b[i + 17] * a[i + 17]
+ b[i + 18] * a[i + 18]
+ b[i + 19] * a[i + 19]
+ b[i + 20] * a[i + 20]
+ b[i + 21] * a[i + 21]
+ b[i + 22] * a[i + 22]
+ b[i + 23] * a[i + 23];
res +=
b[i + 24] * a[i + 24]
+ b[i + 25] * a[i + 25]
+ b[i + 26] * a[i + 26]
+ b[i + 27] * a[i + 27]
+ b[i + 28] * a[i + 28]
+ b[i + 29] * a[i + 29]
+ b[i + 30] * a[i + 30]
+ b[i + 31] * a[i + 31];
int i = 0;

// if the array is big, unroll it
if (a.length > 32) {
float acc1 = 0;
float acc2 = 0;
float acc3 = 0;
float acc4 = 0;
int upperBound = a.length & ~(4 - 1);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this logic is from the javadocs of Species.loopBound of vector api where width is a power of 2. I used it in these functions for consistency, and because i assume it means the compiler will do a good job. we could maybe put in a static method for other places doing crap like this (e.g. stringhelper's murmurhash) as a followup? I'm guessing any other places do it ad-hoc like what was here before. I wanted to keep this PR minimally invasive though.

for (; i < upperBound; i += 4) {
acc1 = fma(a[i], b[i], acc1);
acc2 = fma(a[i + 1], b[i + 1], acc2);
acc3 = fma(a[i + 2], b[i + 2], acc3);
acc4 = fma(a[i + 3], b[i + 3], acc4);
}
res += acc1 + acc2 + acc3 + acc4;
}
for (; i + 7 < a.length; i += 8) {
res +=
b[i + 0] * a[i + 0]
+ b[i + 1] * a[i + 1]
+ b[i + 2] * a[i + 2]
+ b[i + 3] * a[i + 3]
+ b[i + 4] * a[i + 4]
+ b[i + 5] * a[i + 5]
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];

for (; i < a.length; i++) {
res = fma(a[i], b[i], res);
}
return res;
}
Expand All @@ -92,50 +66,80 @@ public float cosine(float[] a, float[] b) {
float sum = 0.0f;
float norm1 = 0.0f;
float norm2 = 0.0f;
int dim = a.length;
int i = 0;

for (int i = 0; i < dim; i++) {
float elem1 = a[i];
float elem2 = b[i];
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
// if the array is big, unroll it
if (a.length > 32) {
float sum1 = 0;
float sum2 = 0;
float norm1_1 = 0;
float norm1_2 = 0;
float norm2_1 = 0;
float norm2_2 = 0;

int upperBound = a.length & ~(2 - 1);
for (; i < upperBound; i += 2) {
// one
sum1 = fma(a[i], b[i], sum1);
norm1_1 = fma(a[i], a[i], norm1_1);
norm2_1 = fma(b[i], b[i], norm2_1);

// two
sum2 = fma(a[i + 1], b[i + 1], sum2);
norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
}
sum += sum1 + sum2;
norm1 += norm1_1 + norm1_2;
norm2 += norm2_1 + norm2_2;
}

for (; i < a.length; i++) {
sum = fma(a[i], b[i], sum);
norm1 = fma(a[i], a[i], norm1);
norm2 = fma(b[i], b[i], norm2);
}
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}

@Override
public float squareDistance(float[] a, float[] b) {
float squareSum = 0.0f;
int dim = a.length;
int i;
for (i = 0; i + 8 <= dim; i += 8) {
squareSum += squareDistanceUnrolled(a, b, i);
float res = 0;
int i = 0;

// if the array is big, unroll it
if (a.length > 32) {
float acc1 = 0;
float acc2 = 0;
float acc3 = 0;
float acc4 = 0;

int upperBound = a.length & ~(4 - 1);
for (; i < upperBound; i += 4) {
// one
float diff1 = a[i] - b[i];
acc1 = fma(diff1, diff1, acc1);

// two
float diff2 = a[i + 1] - b[i + 1];
acc2 = fma(diff2, diff2, acc2);

// three
float diff3 = a[i + 2] - b[i + 2];
acc3 = fma(diff3, diff3, acc3);

// four
float diff4 = a[i + 3] - b[i + 3];
acc4 = fma(diff4, diff4, acc4);
}
res += acc1 + acc2 + acc3 + acc4;
}
for (; i < dim; i++) {

for (; i < a.length; i++) {
float diff = a[i] - b[i];
squareSum += diff * diff;
res = fma(diff, diff, res);
}
return squareSum;
}

private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1];
float diff2 = v1[index + 2] - v2[index + 2];
float diff3 = v1[index + 3] - v2[index + 3];
float diff4 = v1[index + 4] - v2[index + 4];
float diff5 = v1[index + 5] - v2[index + 5];
float diff6 = v1[index + 6] - v2[index + 6];
float diff7 = v1[index + 7] - v2[index + 7];
return diff0 * diff0
+ diff1 * diff1
+ diff2 * diff2
+ diff3 * diff3
+ diff4 * diff4
+ diff5 * diff5
+ diff6 * diff6
+ diff7 * diff7;
return res;
}

@Override
Expand Down
77 changes: 70 additions & 7 deletions lucene/core/src/java/org/apache/lucene/util/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Objects;
import java.util.logging.Logger;

/** Some useful constants. */
Expand Down Expand Up @@ -67,12 +66,6 @@ private Constants() {} // can't construct
/** True iff running on a 64bit JVM */
public static final boolean JRE_IS_64BIT = is64Bit();

/** true iff we know fast FMA is supported, to deliver less error */
public static final boolean HAS_FAST_FMA =
(IS_CLIENT_VM == false)
&& Objects.equals(OS_ARCH, "amd64")
&& HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);

private static boolean is64Bit() {
final String datamodel = getSysProp("sun.arch.data.model");
if (datamodel != null) {
Expand All @@ -82,6 +75,76 @@ private static boolean is64Bit() {
}
}

/** true if FMA likely means a cpu instruction and not BigDecimal logic */
private static final boolean HAS_FMA =
(IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);

/** maximum supported vectorsize */
private static final int MAX_VECTOR_SIZE =
HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0);

/** true for an AMD cpu with SSE4a instructions */
private static final boolean HAS_SSE4A =
HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false);

/** true iff we know VFMA has faster throughput than separate vmul/vadd */
public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();

/** true iff we know FMA has faster throughput than separate mul/add */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

private static boolean hasFastVectorFMA() {
if (HAS_FMA) {
String value = getSysProp("lucene.useVectorFMA", "auto");
if ("auto".equals(value)) {
// newer Neoverse cores have their act together
// the problem is just apple silicon (this is a practical heuristic)
if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
return true;
}
// zen cores or newer, its a wash, turn it on as it doesn't hurt
// starts to yield gains for vectors only at zen4+
if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) {
return true;
}
// intel has their act together
if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
return true;
}
} else {
return Boolean.parseBoolean(value);
}
}
// everyone else is slow, until proven otherwise by benchmarks
return false;
}

private static boolean hasFastScalarFMA() {
if (HAS_FMA) {
String value = getSysProp("lucene.useScalarFMA", "auto");
if ("auto".equals(value)) {
// newer Neoverse cores have their act together
// the problem is just apple silicon (this is a practical heuristic)
if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
return true;
}
// latency becomes 4 for the Zen3 (0x19h), but still a wash
// until the Zen4 anyway, and big drop on previous zens:
if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) {
return true;
}
// intel has their act together
if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
return true;
}
} else {
return Boolean.parseBoolean(value);
}
}
// everyone else is slow, until proven otherwise by benchmarks
return false;
}

private static String getSysProp(String property) {
try {
return doPrivileged(() -> System.getProperty(property));
Expand Down
26 changes: 25 additions & 1 deletion lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,31 @@
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

/** Utilities for computations with numeric arrays */
/**
* Utilities for computations with numeric arrays, especially algebraic operations like vector dot
* products. This class uses SIMD vectorization if the corresponding Java module is available and
* enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's
* command line.
*
* <p>It will use CPU's <a href="https://en.wikipedia.org/wiki/Fused_multiply%E2%80%93add">FMA
* instructions</a> if it is known to perform faster than separate multiply+add. This requires at
* least Hotspot C2 enabled, which is the default for OpenJDK based JVMs.
*
* <p>To explicitly disable or enable FMA usage, pass the following system properties:
*
* <ul>
* <li>{@code -Dlucene.useScalarFMA=(auto|true|false)} for scalar operations
* <li>{@code -Dlucene.useVectorFMA=(auto|true|false)} for vectorized operations (with vector
* incubator module)
* </ul>
*
* <p>The default is {@code auto}, which enables this for known CPU types and JVM settings. If
* Hotspot C2 is disabled, FMA and vectorization are <strong>not</strong> used.
*
* <p>Vectorization and FMA is only supported for Hotspot-based JVMs; it won't work on OpenJ9-based
* JVMs unless they provide {@link com.sun.management.HotSpotDiagnosticMXBean}. Please also make
* sure that you have the {@code jdk.management} module enabled in modularized applications.
*/
public final class VectorUtil {

private static final VectorUtilSupport IMPL =
Expand Down
Loading
Loading