Skip to content

Commit

Permalink
Implement relaxed simd operations on ARM-64
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
zherczeg authored and clover2123 committed Oct 22, 2024
1 parent 4ae176a commit 28ce56c
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 33 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ jobs:
with:
submodules: true
- name: Build in arm32 container
uses: uraimo/run-on-arch-action@v2.7.2
uses: uraimo/run-on-arch-action@v2.8.1
with:
arch: armv7
distro: ubuntu_latest
Expand Down Expand Up @@ -214,7 +214,7 @@ jobs:
with:
submodules: true
- name: Build in arm64 container
uses: uraimo/run-on-arch-action@v2.7.2
uses: uraimo/run-on-arch-action@v2.8.1
with:
arch: aarch64
distro: ubuntu22.04
Expand Down
4 changes: 2 additions & 2 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTPopcntV128 OTOp1V128
#define OTSwizzleV128 OTOp2V128
#define OTShiftV128Tmp OTShiftV128
#define OTOp3DotAddV128 OTOp2V128
#define OTOp3DotAddV128 OTOp3V128

#elif (defined SLJIT_CONFIG_ARM_32 && SLJIT_CONFIG_ARM_32)

Expand All @@ -314,7 +314,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTPMinMaxV128 OTOp2V128
#define OTPopcntV128 OTOp1V128
#define OTShiftV128Tmp OTShiftV128
#define OTOp3DotAddV128 OTOp2V128
#define OTOp3DotAddV128 OTOp3V128

#endif /* SLJIT_CONFIG_ARM */

Expand Down
162 changes: 133 additions & 29 deletions src/jit/SimdArm64Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ enum Type : uint32_t {
fdiv = 0x6e20fc00,
fmax = 0x4e20f400,
fmin = 0x4ea0f400,
fmla = 0x4e20cc00,
fmls = 0x4ea0cc00,
fmul = 0x6e20dc00,
fneg = 0x6ea0f800,
frintm = 0x4e219800, // floor
Expand Down Expand Up @@ -185,11 +187,15 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
break;
case ByteCode::I32X4TruncSatF32X4SOpcode:
case ByteCode::I32X4TruncSatF32X4UOpcode:
case ByteCode::I32X4RelaxedTruncF32X4SOpcode:
case ByteCode::I32X4RelaxedTruncF32X4UOpcode:
srcType = SLJIT_SIMD_ELEM_32 | SLJIT_SIMD_FLOAT;
dstType = SLJIT_SIMD_ELEM_32;
break;
case ByteCode::I32X4TruncSatF64X2SZeroOpcode:
case ByteCode::I32X4TruncSatF64X2UZeroOpcode:
case ByteCode::I32X4RelaxedTruncF64X2SZeroOpcode:
case ByteCode::I32X4RelaxedTruncF64X2UZeroOpcode:
srcType = SLJIT_SIMD_ELEM_64 | SLJIT_SIMD_FLOAT;
dstType = SLJIT_SIMD_ELEM_32;
break;
Expand Down Expand Up @@ -336,16 +342,20 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitOp(compiler, SimdOp::uxtl | (0x1 << 20) | (0x1 << 30), dst, args[0].arg, 0);
break;
case ByteCode::I32X4TruncSatF32X4SOpcode:
case ByteCode::I32X4RelaxedTruncF32X4SOpcode:
simdEmitOp(compiler, SimdOp::fcvtzs | SimdOp::FS4, dst, args[0].arg, 0);
break;
case ByteCode::I32X4TruncSatF32X4UOpcode:
case ByteCode::I32X4RelaxedTruncF32X4UOpcode:
simdEmitOp(compiler, SimdOp::fcvtzu | SimdOp::FS4, dst, args[0].arg, 0);
break;
case ByteCode::I32X4TruncSatF64X2SZeroOpcode:
case ByteCode::I32X4RelaxedTruncF64X2SZeroOpcode:
simdEmitOp(compiler, SimdOp::fcvtzs | SimdOp::FD2, dst, args[0].arg, 0);
simdEmitOp(compiler, SimdOp::sqxtn | SimdOp::S4, dst, dst, 0);
break;
case ByteCode::I32X4TruncSatF64X2UZeroOpcode:
case ByteCode::I32X4RelaxedTruncF64X2UZeroOpcode:
simdEmitOp(compiler, SimdOp::fcvtzu | SimdOp::FD2, dst, args[0].arg, 0);
simdEmitOp(compiler, SimdOp::uqxtn | SimdOp::S4, dst, dst, 0);
break;
Expand Down Expand Up @@ -587,28 +597,26 @@ static void simdEmitNarrowUnsigned(sljit_compiler* compiler, sljit_s32 rd, sljit
simdEmitOp(compiler, SimdOp::sqxtun | size | (0x1 << 30), rd, rm, 0);
}

static void simdEmitDot(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
static void simdEmitDot(sljit_compiler* compiler, uint32_t type, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
{
// The rd can be tmpReg1
#ifdef __ARM_FEATURE_DOTPROD
simdEmitOp(compiler, SimdOp::sdot | SimdOp::S4, rd, rn, rm);
simdEmitOp(compiler, SimdOp::sdot | type, rd, rn, rm);
#else
auto tmpReg1 = SLJIT_TMP_FR0;
auto tmpReg2 = SLJIT_TMP_FR1;
sljit_s32 tmpReg1 = SLJIT_TMP_FR0;
sljit_s32 tmpReg2 = SLJIT_TMP_FR1;
uint32_t lowType = type - (0x1 << SimdOp::sizeOffset);

// tmpReg1 = rn * rm lower
simdEmitOp(compiler, SimdOp::smull | SimdOp::H8, tmpReg1, rn, rm);
// tmpReg2 = rn * rm upper
simdEmitOp(compiler, SimdOp::smull | SimdOp::H8 | (0x1 << 30), tmpReg2, rn, rm);
// rd = tmpReg1[1], tmpReg2[1], tmpReg1[0], tmpReg2[0]
simdEmitOp(compiler, SimdOp::zip1 | SimdOp::S4, rd, tmpReg1, tmpReg2);
// tmpReg1 = tmpReg1[3], tmpReg2[3], tmpReg1[2], tmpReg2[2]
simdEmitOp(compiler, SimdOp::zip2 | SimdOp::S4, tmpReg1, tmpReg1, tmpReg2);
// rd = rd[3] + rd[2], rd[1] + rd[0]
simdEmitOp(compiler, SimdOp::saddlp | SimdOp::S4, rd, rd, 0);
// tmpReg1 = tmpReg1[3] + tmpReg1[2], tmpReg1[1] + tmpReg1[0]
simdEmitOp(compiler, SimdOp::saddlp | SimdOp::S4, tmpReg1, tmpReg1, 0);
// rd = rd[1], tmpReg1[1], rd[0], tmpReg1[0]
simdEmitOp(compiler, SimdOp::uzp1 | SimdOp::S4, rd, rd, tmpReg1);
simdEmitOp(compiler, SimdOp::smull | lowType | (0x1 << 30), tmpReg2, rn, rm);
// tmpReg1 = rn * rm lower
simdEmitOp(compiler, SimdOp::smull | lowType, tmpReg1, rn, rm);
// Widening result
simdEmitOp(compiler, SimdOp::saddlp | type, tmpReg2, tmpReg2, 0);
simdEmitOp(compiler, SimdOp::saddlp | type, tmpReg1, tmpReg1, 0);
// Combine + narrow
simdEmitOp(compiler, SimdOp::xtn | type, rd, tmpReg1, 0);
simdEmitOp(compiler, SimdOp::xtn | type | (0x1 << 30), rd, tmpReg2, 0);
#endif
}

Expand Down Expand Up @@ -643,6 +651,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I8X16MaxUOpcode:
case ByteCode::I8X16AvgrUOpcode:
case ByteCode::I8X16SwizzleOpcode:
case ByteCode::I8X16RelaxedSwizzleOpcode:
srcType = SLJIT_SIMD_ELEM_8;
dstType = SLJIT_SIMD_ELEM_8;
break;
Expand Down Expand Up @@ -674,9 +683,11 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I16X8MaxUOpcode:
case ByteCode::I16X8AvgrUOpcode:
case ByteCode::I16X8Q15mulrSatSOpcode:
case ByteCode::I16X8RelaxedQ15mulrSOpcode:
srcType = SLJIT_SIMD_ELEM_16;
dstType = SLJIT_SIMD_ELEM_16;
break;
case ByteCode::I16X8DotI8X16I7X16SOpcode:
case ByteCode::I16X8ExtmulLowI8X16SOpcode:
case ByteCode::I16X8ExtmulHighI8X16SOpcode:
case ByteCode::I16X8ExtmulLowI8X16UOpcode:
Expand Down Expand Up @@ -750,6 +761,8 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::F32X4PMaxOpcode:
case ByteCode::F32X4MaxOpcode:
case ByteCode::F32X4MinOpcode:
case ByteCode::F32X4RelaxedMaxOpcode:
case ByteCode::F32X4RelaxedMinOpcode:
srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
break;
Expand All @@ -768,6 +781,8 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::F64X2PMaxOpcode:
case ByteCode::F64X2MaxOpcode:
case ByteCode::F64X2MinOpcode:
case ByteCode::F64X2RelaxedMaxOpcode:
case ByteCode::F64X2RelaxedMinOpcode:
srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
break;
Expand Down Expand Up @@ -965,8 +980,12 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitNarrowUnsigned(compiler, dst, args[0].arg, args[1].arg, SimdOp::H8);
break;
case ByteCode::I16X8Q15mulrSatSOpcode:
case ByteCode::I16X8RelaxedQ15mulrSOpcode:
simdEmitOp(compiler, SimdOp::sqrdmulh | SimdOp::H8, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8DotI8X16I7X16SOpcode:
simdEmitDot(compiler, SimdOp::H8, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I32X4AddOpcode:
simdEmitOp(compiler, SimdOp::add | SimdOp::S4, dst, args[0].arg, args[1].arg);
break;
Expand Down Expand Up @@ -1032,7 +1051,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitOp(compiler, SimdOp::umull | SimdOp::H8 | (0x1 << 30), dst, args[0].arg, args[1].arg);
break;
case ByteCode::I32X4DotI16X8SOpcode:
simdEmitDot(compiler, dst, args[0].arg, args[1].arg);
simdEmitDot(compiler, SimdOp::S4, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I64X2AddOpcode:
simdEmitOp(compiler, SimdOp::add | SimdOp::D2, dst, args[0].arg, args[1].arg);
Expand Down Expand Up @@ -1069,9 +1088,11 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitOp(compiler, SimdOp::fdiv | SimdOp::FS4, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F32X4MaxOpcode:
case ByteCode::F32X4RelaxedMaxOpcode:
simdEmitOp(compiler, SimdOp::fmax | SimdOp::FS4, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F32X4MinOpcode:
case ByteCode::F32X4RelaxedMinOpcode:
simdEmitOp(compiler, SimdOp::fmin | SimdOp::FS4, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F32X4MulOpcode:
Expand All @@ -1093,9 +1114,11 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitOp(compiler, SimdOp::fdiv | SimdOp::FD2, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F64X2MaxOpcode:
case ByteCode::F64X2RelaxedMaxOpcode:
simdEmitOp(compiler, SimdOp::fmax | SimdOp::FD2, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F64X2MinOpcode:
case ByteCode::F64X2RelaxedMinOpcode:
simdEmitOp(compiler, SimdOp::fmin | SimdOp::FD2, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F64X2MulOpcode:
Expand Down Expand Up @@ -1154,6 +1177,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitOp(compiler, SimdOp::bic, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I8X16SwizzleOpcode:
case ByteCode::I8X16RelaxedSwizzleOpcode:
simdEmitOp(compiler, SimdOp::tbl, dst, args[0].arg, args[1].arg);
break;
default:
Expand All @@ -1166,26 +1190,106 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
}
}

static void simdEmitDotAdd(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, sljit_s32 ro)
{
sljit_s32 tmpReg1 = SLJIT_TMP_FR0;

simdEmitDot(compiler, SimdOp::H8, tmpReg1, rn, rm);
simdEmitOp(compiler, SimdOp::saddlp | SimdOp::H8, tmpReg1, tmpReg1, 0);
simdEmitOp(compiler, SimdOp::add | SimdOp::S4, rd, ro, tmpReg1);
}

static void emitTernarySIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
JITArg args[3];
JITArg args[4];

sljit_s32 srcType = SLJIT_SIMD_ELEM_128;
sljit_s32 dstType = SLJIT_SIMD_ELEM_128;
bool moveToDst = true;

switch (instr->opcode()) {
case ByteCode::V128BitSelectOpcode:
srcType = SLJIT_SIMD_ELEM_128;
dstType = SLJIT_SIMD_ELEM_128;
break;
case ByteCode::I8X16RelaxedLaneSelectOpcode:
srcType = SLJIT_SIMD_ELEM_8;
dstType = SLJIT_SIMD_ELEM_8;
break;
case ByteCode::I16X8RelaxedLaneSelectOpcode:
srcType = SLJIT_SIMD_ELEM_16;
dstType = SLJIT_SIMD_ELEM_16;
break;
case ByteCode::I32X4RelaxedLaneSelectOpcode:
srcType = SLJIT_SIMD_ELEM_32;
dstType = SLJIT_SIMD_ELEM_32;
break;
case ByteCode::I64X2RelaxedLaneSelectOpcode:
srcType = SLJIT_SIMD_ELEM_64;
dstType = SLJIT_SIMD_ELEM_64;
break;
case ByteCode::I32X4DotI8X16I7X16AddSOpcode:
srcType = SLJIT_SIMD_ELEM_8;
dstType = SLJIT_SIMD_ELEM_32;
moveToDst = false;
break;
case ByteCode::F32X4RelaxedMaddOpcode:
case ByteCode::F32X4RelaxedNmaddOpcode:
srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_32;
break;
case ByteCode::F64X2RelaxedMaddOpcode:
case ByteCode::F64X2RelaxedNmaddOpcode:
srcType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
dstType = SLJIT_SIMD_FLOAT | SLJIT_SIMD_ELEM_64;
break;
default:
ASSERT_NOT_REACHED();
break;
}

simdOperandToArg(compiler, operands, args[0], SLJIT_SIMD_ELEM_128, instr->requiredReg(0));
simdOperandToArg(compiler, operands + 1, args[1], SLJIT_SIMD_ELEM_128, instr->requiredReg(1));
simdOperandToArg(compiler, operands + 2, args[2], SLJIT_SIMD_ELEM_128, instr->requiredReg(2));
simdOperandToArg(compiler, operands, args[0], srcType, instr->requiredReg(0));
simdOperandToArg(compiler, operands + 1, args[1], srcType, instr->requiredReg(1));
simdOperandToArg(compiler, operands + 2, args[2], dstType, instr->requiredReg(2));

sljit_s32 dst = instr->requiredReg(2);
args[3].set(operands + 3);
sljit_s32 dst = GET_TARGET_REG(args[3].arg, instr->requiredReg(2));

if (dst != args[2].arg) {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_LOAD | SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_128, dst, args[2].arg, args[2].argw);
if (moveToDst && dst != args[2].arg) {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_REG_128 | srcType, dst, args[2].arg, 0);
}

simdEmitOp(compiler, SimdOp::bsl, dst, args[0].arg, args[1].arg);
switch (instr->opcode()) {
case ByteCode::V128BitSelectOpcode:
case ByteCode::I8X16RelaxedLaneSelectOpcode:
case ByteCode::I16X8RelaxedLaneSelectOpcode:
case ByteCode::I32X4RelaxedLaneSelectOpcode:
case ByteCode::I64X2RelaxedLaneSelectOpcode:
simdEmitOp(compiler, SimdOp::bsl, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I32X4DotI8X16I7X16AddSOpcode:
simdEmitDotAdd(compiler, dst, args[0].arg, args[1].arg, args[2].arg);
break;
case ByteCode::F32X4RelaxedMaddOpcode:
simdEmitOp(compiler, SimdOp::fmla | SimdOp::FS4, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F32X4RelaxedNmaddOpcode:
simdEmitOp(compiler, SimdOp::fmls | SimdOp::FS4, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F64X2RelaxedMaddOpcode:
simdEmitOp(compiler, SimdOp::fmla | SimdOp::FD2, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F64X2RelaxedNmaddOpcode:
simdEmitOp(compiler, SimdOp::fmls | SimdOp::FD2, dst, args[0].arg, args[1].arg);
break;
default:
ASSERT_NOT_REACHED();
break;
}

args[2].set(operands + 3);
if (SLJIT_IS_MEM(args[2].arg)) {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_STORE | SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_128, dst, args[2].arg, args[2].argw);
if (SLJIT_IS_MEM(args[3].arg)) {
sljit_emit_simd_mov(compiler, SLJIT_SIMD_STORE | SLJIT_SIMD_REG_128 | dstType, dst, args[3].arg, args[3].argw);
}
}

Expand Down

0 comments on commit 28ce56c

Please sign in to comment.