Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separating shift and sign extend #130

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions pil/zisk.pil
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ require "constants.pil"
require "main/pil/main.pil"
require "binary/pil/binary.pil"
require "binary/pil/binary_table.pil"
require "binary/pil/binary_extension.pil"
require "binary/pil/binary_extension_table.pil"
require "binary/pil/shift.pil"
require "binary/pil/shift_table.pil"
require "binary/pil/sign_extension.pil"
// require "mem/pil/mem.pil"

const int OPERATION_BUS_ID = 5000;
Expand All @@ -24,10 +25,14 @@ airgroup BinaryTable {
BinaryTable(disable_fixed: 0);
}

airgroup BinaryExtension {
BinaryExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID);
airgroup Shift {
Shift(N: 2**21, operation_bus_id: OPERATION_BUS_ID);
}

airgroup BinaryExtensionTable {
BinaryExtensionTable(disable_fixed: 0);
airgroup ShiftTable {
ShiftTable(disable_fixed: 0);
}

airgroup SignExtension {
SignExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID);
}
88 changes: 53 additions & 35 deletions state-machines/binary/pil/binary.pil
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
require "std_lookup.pil"

// Coprocessor in charge of performing standard RISCV binary operations

/*
Coprocessor in charge of performing the following binary operations:

List 64-bit operations:
name │ op │ m_op │ carry │ use_last_carry │ NOTES
────────┼──────────┼──────────┼───────┼────────────────┼───────────────────────────────────
Expand Down Expand Up @@ -65,6 +64,7 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID)
// Default values
const int bits = 64;
const int bytes = bits / 8;
const int half_bytes = bytes / 2;

// Main values
const int input_chunks = 2;
Expand All @@ -80,51 +80,69 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID)

// Secondary columns
col witness use_last_carry; // 1 if the operation uses the last carry as its result
col witness op_is_min_max; // 1 if op ∈ {MINU,MIN,MAXU,MAX}
col witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations

const expr cout32 = carry[bytes/2-1];
const expr mode64 = 1 - mode32;
const expr cout32 = carry[half_bytes-1];
const expr cout64 = carry[bytes-1];
expr cout = (1-mode32) * (cout64 - cout32) + cout32;

use_last_carry * (1 - use_last_carry) === 0;
op_is_min_max * (1 - op_is_min_max) === 0;
cout32*(1 - cout32) === 0;
cout64*(1 - cout64) === 0;

// Constraints to check the correctness of each binary operation
// Auxiliary columns (primarily used to optimize lookups, but can be substituted with expressions)
col witness cout;
col witness result_is_a;
col witness use_last_carry_mode32;
col witness use_last_carry_mode64;
cout === mode64 * (cout64 - cout32) + cout32;
result_is_a === op_is_min_max * cout;
use_last_carry_mode32 === mode32 * use_last_carry;
use_last_carry_mode64 === mode64 * use_last_carry;

/*
opid last a b c cin cout
───────────────────────────────────────────────────────────────
m_op 0 a0 b0 c0 0 carry0
m_op 0 a1 b1 c1 carry0 carry1
m_op 0 a2 b2 c2 carry1 carry2
m_op 0 a3 b3 c3 carry2 carry3 + 2*use_last_carry
m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4
m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5
m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6
m_op|EXT_32 1 a7|c3 b7|0 c7 carry6 carry7 + 2*use_last_carry
Constraints to check the correctness of each binary operation
opid last a b c cin cout + flags
───────────────────────────────────────────────────────────────-------------------------------------------------
m_op 0 a0 b0 c0 0 carry0 + 2*op_is_min_max + 4*result_is_a
m_op 0 a1 b1 c1 carry0 carry1 + 2*op_is_min_max + 4*result_is_a
m_op 0 a2 b2 c2 carry1 carry2 + 2*op_is_min_max + 4*result_is_a
m_op 0|1 a3 b3 c3 carry2 carry3 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32
m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 + 2*op_is_min_max + 4*result_is_a
m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 + 2*op_is_min_max + 4*result_is_a
m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 + 2*op_is_min_max + 4*result_is_a
m_op|EXT_32 0|1 a7|c3 b7|0 c7 carry6 carry7 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64
───────────────────────────────────────────────────────────────-------------------------------------------------
Perform, at the byte level, lookups against the binary table on inputs:
[last, m_op, a, b, cin, c, cout + flags]
where last indicates whether the byte is the last one in the operation
*/

// Perform, at the byte level, lookups against the binary table on inputs:
// [last, m_op, a, b, cin, c, cout + flags]
// where last indicates whether the byte is the last one in the operation
lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]);

lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*op_is_min_max*cout]);

expr _m_op = (1-mode32) * (m_op - EXT_32_OP) + EXT_32_OP;
// More auxiliary columns
col witness m_op_or_ext;
col witness free_in_a_or_c[half_bytes];
col witness free_in_b_or_zero[half_bytes];
m_op_or_ext === mode64 * (m_op - EXT_32_OP) + EXT_32_OP;
int index = 0;
for (int i = 1; i < bytes; i++) {
expr _free_in_a = (1-mode32) * (free_in_a[i] - free_in_c[bytes/2-1]) + free_in_c[bytes/2-1];
expr _free_in_b = (1-mode32) * free_in_b[i];

if (i < bytes/2 - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]);
} else if (i == bytes/2 - 1) {
lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*mode32]);
} else if (i < bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]);
} else {
lookup_assumes(BINARY_TABLE_ID, [1-mode32, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*(1-mode32)]);
}
if (i >= half_bytes) {
index = i - half_bytes;
free_in_a_or_c[index] === mode64 * (free_in_a[i] - free_in_c[half_bytes-1]) + free_in_c[half_bytes-1];
free_in_b_or_zero[index] === mode64 * free_in_b[i];
}

if (i < half_bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]);
} else if (i == half_bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32]);
} else if (i < bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, m_op_or_ext, free_in_a_or_c[index], free_in_b_or_zero[index], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]);
} else {
lookup_assumes(BINARY_TABLE_ID, [mode64, m_op_or_ext, free_in_a_or_c[index], free_in_b_or_zero[index], carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64]);
}
}

// Constraints to make sure that this component is called from the main component
Expand Down
111 changes: 0 additions & 111 deletions state-machines/binary/pil/binary_extension.pil

This file was deleted.

99 changes: 99 additions & 0 deletions state-machines/binary/pil/shift.pil
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
require "std_lookup.pil"
require "std_range_check.pil"

/*
Coprocessor in charge of performing shift operations:

┼────────┼────────┼──────────┼
│ name │ bits │ op │
┼────────┼────────┼──────────┼
│ SLL │ 64 │ 0x0d │
│ SRL │ 64 │ 0x0e │
│ SRA │ 64 │ 0x0f │
│ SLL_W │ 32 │ 0x1d │
│ SRL_W │ 32 │ 0x1e │
│ SRA_W │ 32 │ 0x1f │
┼────────┼────────┼──────────┼

Examples:
=======================================

SLL 28
x in1[x] out[x][0] out[x][1]
---------------------------------------
0 0x11 0x10000000 0x00000001
1 0x22 0x00000000 0x00000220
2 0x33 0x00000000 0x00033000
3 0x44 0x00000000 0x04400000
4 0x55 0x00000000 0x50000000
5 0x66 0x00000000 0x00000000
6 0x77 0x00000000 0x00000000
7 0x88 0x00000000 0x00000000
---------------------------------------
Result: 0x10000000 0x54433221

SLL_W 8
x in1[x] out[x][0] out[x][1]
---------------------------------------
0 0x11 0x00001100 0x00000000
1 0x22 0x00220000 0x00000000
2 0x33 0x33000000 0x00000000
3 0x44 0x00000000 0x00000044
4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00)
5 0x66 0x00000000 0x00000000 (bytes of in1 are ignored from here)
6 0x77 0x00000000 0x00000000
7 0x88 0x00000000 0x00000000
---------------------------------------
Result: 0x33221100 0x00000000
*/

const int SHIFT_ID = 21;

airtemplate Shift(const int N = 2**18, const int operation_bus_id = SHIFT_ID) {
const int bits = 64;
const int bytes = bits / 8;
const int half_bytes = bytes / 2;

col witness op;
col witness in1[bytes];
col witness in2_low; // Note: if in2_low∊[0,2^5-1], else in2_low∊[0,2^6-1] (checked by the table)
col witness out[bytes][2];

// Constraints to check the correctness of each shift operation
for (int j = 0; j < bytes; j++) {
lookup_assumes(SHIFT_TABLE_ID, [op, j, in1[j], in2_low, out[j][0], out[j][1]]);
}

// Constraints to make sure that this component is called from the main component
expr in1_low = 0;
expr in1_high = 0;
expr out_low = 0;
expr out_high = 0;
for (int i = 0; i < half_bytes; i++) {
in1_low += in1[i] * (0xFF ** i);
in1_high += in1[i + half_bytes] * (0xFF ** i);
out_low += out[i][0] + out[i + half_bytes][0];
out_high += out[i][1] + out[i + half_bytes][1];
}

col witness in2[2];
col witness main_step;
col witness multiplicity;
lookup_proves(
operation_bus_id,
[
main_step,
op,
in1_low,
in1_high,
in2_low + 256 * in2[0],
in2[1],
out_low,
out_high,
0
],
multiplicity
);

range_check(in2[0], 0, 2**24-1);
}
Loading
Loading