Skip to content

Commit

Permalink
fix: witnesscalc build (#91)
Browse files Browse the repository at this point in the history
* `aes-gcm` compiles

* removing consts

* stash pop

* checkpoint

* uncommented in main template

Now the issues seem to only be in GCTR and KeyExpansion

* unravelled a new error

* this is working still, fold is not

* FIXED KEY EXPANSION

* VERY CLOSE

* this still works

* this still works

* it works now

YOU CANNOT HAVE A LOOP THAT EVER HAS A 0 SIZE OR ELSE THIS THING POOPS

* if logic to fix STUPID witcalc

* few notes

* update aes-gcm-fold.circom

* Update ff.circom

* MAKE `FieldInv` fast af boi

* improve FieldInv test

* Fixup the tests

* fix tests

* rm unnecessary diffs

---------

Signed-off-by: Thor Kamphefner <[email protected]>
Signed-off-by: Waylon Jepsen <[email protected]>
Co-authored-by: Waylon Jepsen <[email protected]>
Co-authored-by: devloper <[email protected]>
Co-authored-by: Sambhav Dusad <[email protected]>
  • Loading branch information
4 people authored Oct 21, 2024
1 parent f0b1dcc commit ac209d7
Show file tree
Hide file tree
Showing 13 changed files with 247 additions and 309 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ client/static/build
circuits/test/**/*.circom
circuits/test/*.circom
circuits/main/*
ir_log/*
log_input_signals.txt
*.bin
40 changes: 16 additions & 24 deletions circuits/aes-gcm/aes-gcm-fold.circom
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pragma circom 2.1.9;

include "./aes-gcm-foldable.circom";

// Compute AES-GCM
template AESGCMFOLD(bytesPerFold, totalBytes) {
// cannot fold outside chunk boundaries.
assert(bytesPerFold % 16 == 0);
Expand All @@ -14,47 +15,38 @@ template AESGCMFOLD(bytesPerFold, totalBytes) {

// Output from the last encryption step
// Always use last bytes for inputs which are not same size.
// step_in[0] => lastCounter
// step_in[1] => lastTag
// step_in[2] => foldedBlocks
signal input step_in[48];
// step_in[0..4] => lastCounter
// step_in[4..20] => lastTag
// step_in[20] => foldedBlocks
signal input step_in[21];

// For now, attempt to support variable fold size. Potential fix at 16 in the future.
component aes = AESGCMFOLDABLE(bytesPerFold, totalBytes\16);
aes.key <== key;
aes.iv <== iv;
aes.aad <== aad;
aes.key <== key;
aes.iv <== iv;
aes.aad <== aad;
aes.plainText <== plainText;

// Fold inputs
var inputIndex = bytesPerFold-4;
for(var i = 0; i < 4; i++) {
aes.lastCounter[i] <== step_in[inputIndex];
inputIndex+=1;
aes.lastCounter[i] <== step_in[i];
}

for(var i = 0; i < 16; i++) {
aes.lastTag[i] <== step_in[inputIndex];
inputIndex+=1;
aes.lastTag[i] <== step_in[4 + i];
}
// TODO: range check, assertions, stuff.
inputIndex+=15;
aes.foldedBlocks <== step_in[inputIndex];
aes.foldedBlocks <== step_in[20];

// Fold Outputs
signal output step_out[48];
var outputIndex = bytesPerFold-4;
signal output step_out[21];
for(var i = 0; i < 4; i++) {
step_out[outputIndex] <== aes.counter[i];
outputIndex+=1;
step_out[i] <== aes.counter[i];
}
for(var i = 0; i < 16; i++) {
step_out[outputIndex] <== aes.authTag[i];
outputIndex+=1;
step_out[4 + i] <== aes.authTag[i];
}
outputIndex+=15;
step_out[outputIndex] <== step_in[inputIndex] + bytesPerFold \ 16;
step_out[20] <== step_in[20] + bytesPerFold \ 16;

signal output authTag[16] <== aes.authTag;
signal output cipherText[bytesPerFold] <== aes.cipherText;
}
}
61 changes: 34 additions & 27 deletions circuits/aes-gcm/aes-gcm-foldable.circom
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ template AESGCMFOLDABLE(l, TOTAL_BLOCKS) {
}

// Step 1: Let H = CIPHK(0128)
component cipherH = Cipher(4); // 128-bit key -> 4 32-bit words -> 10 rounds
component cipherH = Cipher(); // 128-bit key -> 4 32-bit words -> 10 rounds
cipherH.key <== key;
cipherH.block <== zeroBlock.blocks[0];

Expand All @@ -75,7 +75,7 @@ template AESGCMFOLDABLE(l, TOTAL_BLOCKS) {
J0[3] <== J0WordIncrementer.out;

// Step 3: Let C = GCTRK(inc32(J0), P)
component gctr = GCTR(l, 4);
component gctr = GCTR(l);
gctr.key <== key;
gctr.initialCounterBlock <== J0;
gctr.plainText <== plainText;
Expand All @@ -89,15 +89,16 @@ template AESGCMFOLDABLE(l, TOTAL_BLOCKS) {
// len(A) => u64
// len(b) => u64 (together, 1 block)
//
var blockCount = l\16 + (l%16 > 0 ? 1 : 0); // blocksize is 16 bytes
var blockCount = l\16 + (l%16 > 0 ? 1 : 0); // blocksize is 16 bytes
var ghashBlocks = 1 + blockCount + 1;

component targetMode = SelectGhashMode(TOTAL_BLOCKS, blockCount, ghashBlocks);
component targetMode = SelectGhashMode(TOTAL_BLOCKS, blockCount, ghashBlocks);
targetMode.foldedBlocks <== foldedBlocks;

// TODO(CR 2024-10-18): THIS BLOCK IS PROBLEM CHILD SO FAR
// S = GHASHH (A || 0^v || C || 0^u || [len(A)] || [len(C)]).
component selectedBlocks = SelectGhashBlocks(l, ghashBlocks, TOTAL_BLOCKS);
selectedBlocks.aad <== aad;
selectedBlocks.aad <== aad;
selectedBlocks.cipherText <== gctr.cipherText;
selectedBlocks.targetMode <== targetMode.mode;

Expand Down Expand Up @@ -136,7 +137,7 @@ template AESGCMFOLDABLE(l, TOTAL_BLOCKS) {
}

// Step 6: Encrypt the tag. Let T = MSBt(GCTRK(J0, S))
component gctrT = GCTR(16, 4);
component gctrT = GCTR(16);
gctrT.key <== key;
gctrT.initialCounterBlock <== StartJ0.blocks[0];
gctrT.plainText <== selectTag.tag;
Expand Down Expand Up @@ -171,30 +172,30 @@ template SelectGhashBlocks(l, ghashBlocks, totalBlocks) {
signal targetBlocks[3][ghashBlocks*4*4];
signal modeToBlocks[4] <== [0, 0, 1, 2];

component start = GhashStartMode(l, totalBlocks, ghashBlocks);
start.aad <== aad;
component start = GhashStartMode(l, totalBlocks, ghashBlocks);
start.aad <== aad;
start.cipherText <== cipherText;
targetBlocks[0] <== start.blocks;
targetBlocks[0] <== start.blocks;

component stream = GhashStreamMode(l, ghashBlocks);
component stream = GhashStreamMode(l, ghashBlocks);
stream.cipherText <== cipherText;
targetBlocks[1] <== stream.blocks;
targetBlocks[1] <== stream.blocks;

component end = GhashEndMode(l, totalBlocks, ghashBlocks);
end.cipherText <== cipherText;
component end = GhashEndMode(l, totalBlocks, ghashBlocks);
end.cipherText <== cipherText;
targetBlocks[2] <== end.blocks;

component mapModeToArray = Selector(4);
mapModeToArray.in <== modeToBlocks;
mapModeToArray.index <== targetMode;
mapModeToArray.in <== modeToBlocks;
mapModeToArray.index <== targetMode;

component chooseBlocks = ArraySelector(3, ghashBlocks*4*4);
chooseBlocks.in <== targetBlocks;
chooseBlocks.index <== mapModeToArray.out;
chooseBlocks.in <== targetBlocks;
chooseBlocks.index <== mapModeToArray.out;

component toBlocks = ToBlocks(ghashBlocks*4*4);
toBlocks.stream <== chooseBlocks.out;
blocks <== toBlocks.blocks;
toBlocks.stream <== chooseBlocks.out;
blocks <== toBlocks.blocks;
}

template SelectGhashTag(ghashBlocks) {
Expand Down Expand Up @@ -230,10 +231,10 @@ template SelectGhashMode(totalBlocks, blocksPerFold, ghashBlocks) {
// May need to compute these differently due to foldedBlocks.
// i.e. using GT operator, Equal operator, etc.
signal isFinish <-- (blocksPerFold >= totalBlocks-foldedBlocks) ? 1 : 0;
signal isStart <-- (foldedBlocks == 0) ? 1: 0;
signal isStart <-- (foldedBlocks == 0) ? 1: 0;

isFinish * (isFinish - 1) === 0;
isStart * (isStart - 1) === 0;
isStart * (isStart - 1) === 0;

// case isStart && isFinish: START_END_MODE
// case isStart && !isFinish: START_MODE
Expand All @@ -247,9 +248,9 @@ template SelectGhashMode(totalBlocks, blocksPerFold, ghashBlocks) {
choice.s <== [isStart, isFinish];

signal isStartEndMode <== IsEqual()([choice.out, m.START_END_MODE]);
signal isStartMode <== IsEqual()([choice.out, m.START_MODE]);
signal isStreamMode <== IsEqual()([choice.out, m.STREAM_MODE]);
signal isEndMode <== IsEqual()([choice.out, m.END_MODE]);
signal isStartMode <== IsEqual()([choice.out, m.START_MODE]);
signal isStreamMode <== IsEqual()([choice.out, m.STREAM_MODE]);
signal isEndMode <== IsEqual()([choice.out, m.END_MODE]);

isStartEndMode + isStartMode + isStreamMode + isEndMode === 1;

Expand Down Expand Up @@ -294,7 +295,8 @@ template GhashStartMode(l, totalBlocks, ghashBlocks) {
// Insert in reversed (big endian) order.
blocks[blockIndex+7-i] <== byteValue;
}
blockIndex+=8;
// 16 + l + 8 + 8
blockIndex+=8; // TODO(CR 2024-10-18): I don't think this does anything
}

// TODO: Mildly more efficient if we add this, maybe it's needed?
Expand Down Expand Up @@ -347,5 +349,10 @@ template GhashEndMode(l, totalBlocks, ghashBlocks) {
// Insert in reversed (big endian) order.
blocks[blockIndex+7-i] <== byte_value;
}
blockIndex+=8;
}
blockIndex+=8;
// NOTE: Added this so all of blocks is written
for (var i = 0; i<16; i++) {
blocks[blockIndex] <== 0;
blockIndex += 1;
}
}
9 changes: 5 additions & 4 deletions circuits/aes-gcm/aes-gcm.circom
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ template AESGCM(l) {
}

// Step 1: Let H = CIPHK(0128)
component cipherH = Cipher(4); // 128-bit key -> 4 32-bit words -> 10 rounds
component cipherH = Cipher(); // 128-bit key -> 4 32-bit words -> 10 rounds
cipherH.key <== key;
cipherH.block <== zeroBlock.blocks[0];

Expand All @@ -65,11 +65,12 @@ template AESGCM(l) {
J0[3] <== J0WordIncrementer2.out;

// Step 3: Let C = GCTRK(inc32(J0), P)
component gctr = GCTR(l, 4);
component gctr = GCTR(l);
gctr.key <== key;
gctr.initialCounterBlock <== J0;
gctr.plainText <== plainText;


// Step 4: Let u and v
var u = 128 * (l \ 128) - l;
// when we handle dynamic aad lengths, we'll need to change this
Expand Down Expand Up @@ -156,11 +157,11 @@ template AESGCM(l) {
// log("end ghash bytes");

// Step 6: Let T = MSBt(GCTRK(J0, S))
component gctrT = GCTR(16, 4);
component gctrT = GCTR(16);
gctrT.key <== key;
gctrT.initialCounterBlock <== J0;
gctrT.plainText <== bytes;

authTag <== gctrT.cipherText;
cipherText <== gctr.cipherText;
}
}
38 changes: 19 additions & 19 deletions circuits/aes-gcm/aes/cipher.circom
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,27 @@ include "mix_columns.circom";
//
// Ciphertext


// @param nk: number of keys which can be 4, 6, 8
// @inputs block: 4x4 matrix representing the input block
// @inputs key: array of nk*4 bytes representing the key
// @outputs cipher: 4x4 matrix representing the output block
template Cipher(nk){
assert(nk == 4 || nk == 6 || nk == 8 );
template Cipher(){
signal input block[4][4];
signal input key[nk * 4];
signal input key[16];
signal output cipher[4][4];

var nr = Rounds(nk);
// var nr = Rounds(nk);

component keyExpansion = KeyExpansion(nk,nr);
component keyExpansion = KeyExpansion();
keyExpansion.key <== key;

component addRoundKey[nr+1];
component subBytes[nr];
component shiftRows[nr];
component mixColumns[nr-1];
component addRoundKey[11];
component subBytes[10];
component shiftRows[10];
component mixColumns[9];

signal interBlock[nr][4][4];
signal interBlock[10][4][4];

addRoundKey[0] = AddRoundKey();
addRoundKey[0].state <== block;
Expand All @@ -59,7 +59,7 @@ template Cipher(nk){
}

interBlock[0] <== addRoundKey[0].newState;
for (var i = 1; i < nr; i++) {
for (var i = 1; i < 10; i++) {
subBytes[i-1] = SubBlock();
subBytes[i-1].state <== interBlock[i-1];

Expand All @@ -78,19 +78,19 @@ template Cipher(nk){
interBlock[i] <== addRoundKey[i].newState;
}

subBytes[nr-1] = SubBlock();
subBytes[nr-1].state <== interBlock[nr-1];
subBytes[9] = SubBlock();
subBytes[9].state <== interBlock[9];

shiftRows[nr-1] = ShiftRows();
shiftRows[nr-1].state <== subBytes[nr-1].newState;
shiftRows[9] = ShiftRows();
shiftRows[9].state <== subBytes[9].newState;

addRoundKey[nr] = AddRoundKey();
addRoundKey[nr].state <== shiftRows[nr-1].newState;
addRoundKey[10] = AddRoundKey();
addRoundKey[10].state <== shiftRows[9].newState;
for (var i = 0; i < 4; i++) {
addRoundKey[nr].roundKey[i] <== keyExpansion.keyExpanded[i + (nr * 4)];
addRoundKey[10].roundKey[i] <== keyExpansion.keyExpanded[i + (40)];
}

cipher <== addRoundKey[nr].newState;
cipher <== addRoundKey[10].newState;
}

// @param nk: number of keys which can be 4, 6, 8
Expand Down
Loading

0 comments on commit ac209d7

Please sign in to comment.