Skip to content

Commit

Permalink
ZIR-306: Add SHA-2 accelerator, fix division overconstraint bug, add …
Browse files Browse the repository at this point in the history
…rv32im compliance tests to C++ (#144)

* Sha2 passes smoke test: todo, add some wrapper code and more test vectors

* Improve sha tests to cover all the relevant cases

* Fixed div

* Remove unused code

* Fixes for CI

---------

Co-authored-by: Frank Laub <[email protected]>
  • Loading branch information
jbruestle and flaub authored Dec 29, 2024
1 parent 38a7797 commit cc4d252
Show file tree
Hide file tree
Showing 28 changed files with 1,152 additions and 61 deletions.
47 changes: 47 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/arr.zir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// This file contains utilities that work with bits and twits.
// RUN: zirgen --test %s

// Vector / List functions

// Shifts + Rotates
component RotateLeft<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i - n, SIZE)) { in[i - n] } else { in[SIZE + i - n] }
}
}

component RotateRight<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i + n, SIZE)) { in[i + n] } else { in[i + n - SIZE] }
}
}

component ShiftLeft<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i - n, SIZE)) { in[i - n] } else { 0 }
}
}

component ShiftRight<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i + n, SIZE)) { in[i + n] } else { 0 }
}
}

component EqArr<SIZE: Val>(a: Array<Val, SIZE>, b: Array<Val, SIZE>) {
for i : 0..SIZE {
a[i] = b[i];
}
}

// Tests....

test ShiftAndRotate {
// TODO: Now that these support non-bit values, maybe make new tests
// Remember: array entry 0 is the low bit, so there seem backwards
EqArr<8>(ShiftRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [0, 0, 1, 1, 1, 0, 1, 0]);
EqArr<8>(RotateRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 1, 1]);
EqArr<8>(RotateLeft<8>([1, 1, 1, 0, 1, 0, 0, 1], 2), [0, 1, 1, 1, 1, 0, 1, 0]);
}

10 changes: 7 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/bits.zir
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ function AssertTwit(val: Val) {
val * (1 - val) * (2 - val) * (3 - val) = 0;
}

// Simple bit ops
component BitAnd(a: Val, b: Val) {
Reg(a * b)
a * b
}

component BitOr(a: Val, b: Val) {
Reg(1 - (1 - a) * (1 - b))
a + b - a * b
}

component BitXor(a: Val, b: Val) {
a + b - 2 * a * b
}

// Set a register nodeterministically, and then verify it is a twit
Expand Down Expand Up @@ -81,4 +86,3 @@ test TwitInRange{
test_fails TwitOutOfRange {
AssertTwit(4);
}

9 changes: 8 additions & 1 deletion zirgen/circuit/rv32im/v2/dsl/consts.zir
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ component StatePoseidonStoreState() { 23 }
component StatePoseidonExtRound() { 24 }
component StatePoseidonIntRounds() { 25 }

component StateDecode() { 32 }
component StateShaEcall() { 32 }
component StateShaLoadState() { 33 }
component StateShaLoadData() { 34 }
component StateShaMix() { 35 }
component StateShaStoreState() { 36 }

component StateDecode() { 40 }

component RegA0() { 10 }
component RegA1() { 11 }
component RegA2() { 12 }
component RegA3() { 13 }
component RegA4() { 14 }

component RegA7() { 17 }
23 changes: 20 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/inst_div.zir
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,36 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) {
settings := MultiplySettings(signed, signed, signed);
// Do the accumulate
mul := MultiplyAccumulate(quot, denom, rem, settings);
// Check the main result (numer = quot * denom + rem
// Check the main result (numer = quot * denom + rem)
AssertEqU32(mul.outLow, numer);
// The top bits should all be 0 or all be 1
topBitType := NondetBitReg(1 - Isz(mul.outHigh.low));
AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType));
// Check if denom is zero
isZero := IsZero(denom.low + denom.high);
// Get top bit of numerator
topNum := NondetBitReg((numer.high & 0x8000) / 0x8000);
// Verify we got it right
U16Reg((numer.high - 0x8000 * topNum) * 2);
numNeg := topNum * signed;
// Get the absolute value of the denominator
denomNeg := mul.bNeg;
denomAbs := NormalizeU32(DenormedValU32(
denomNeg * (0x10000 - denom.low) + (1 - denomNeg) * denom.low,
denomNeg * (0xffff - denom.high) + (1 - denomNeg) * denom.high
));
// Flip the sign of the remainder if the numerator is negative
remNormal := NormalizeU32(DenormedValU32(
numNeg * (0x10000 - rem.low) + (1 - numNeg) * rem.low,
numNeg * (0xffff - rem.high) + (1 - numNeg) * rem.high
));
// Decide if we need to swap order of
// If non-zero, make sure 0 <= rem < denom
if (isZero) {
AssertEqU32(rem, numer);
} else {
cmp := CmpLessThanUnsigned(rem, denom);
cmp.is_less_than = 1;
lt := CmpLessThanUnsigned(remNormal, denomAbs);
lt.is_less_than = 1;
};
DivideReturn(quot, rem)
}
Expand Down
8 changes: 5 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ component MachineECall(cycle: Reg, input: InstInput, pc_addr: Val) {
input.mode = 1;
dispatch_idx := MemoryRead(cycle, MachineRegBase() + RegA7());
dispatch_idx.high = 0;
dispatch := OneHot<4>(dispatch_idx.low);
dispatch := OneHot<5>(dispatch_idx.low);
state := dispatch -> (
StateTerminate(),
StateHostReadSetup(),
StateHostWrite(),
StatePoseidonEcall()
StatePoseidonEcall(),
StateShaEcall()
);
ECallOutput(state, 0, 0, 0)
}
Expand Down Expand Up @@ -172,6 +173,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) {
s2 := Reg(output.s2);
isDecode := IsZero(output.state - StateDecode());
isP2Entry := IsZero(output.state - StatePoseidonEcall());
addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry) * 4, 0)));
isShaEcall := IsZero(output.state - StateShaEcall());
addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry + isShaEcall) * 4, 0)));
InstOutput(addPC, output.state, 1)
}
2 changes: 2 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/inst_p2.zir
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ component PoseidonPaging(cycle: Reg, mode: Val, prev: PoseidonState) {

component Poseidon0(cycle:Reg, inst_input: InstInput) {
DoCycleTable(cycle);
inst_input.state = StatePoseidonEntry() + inst_input.minor;
state : PoseidonState;
state := inst_input.minor_onehot -> (
PoseidonEntry(cycle, inst_input.pc_u32, inst_input.mode),
Expand All @@ -480,6 +481,7 @@ component Poseidon0(cycle:Reg, inst_input: InstInput) {

component Poseidon1(cycle:Reg, inst_input: InstInput) {
DoCycleTable(cycle);
inst_input.state = StatePoseidonExtRound() + inst_input.minor;
state : PoseidonState;
state := inst_input.minor_onehot -> (
PoseidonExtRound(state@1),
Expand Down
Loading

0 comments on commit cc4d252

Please sign in to comment.