diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index bef83bd52..ea114ad6c 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -1,4 +1,4 @@ -use constants::OpcodeType; +use constants::RvInstruction; use ff_ext::ExtensionField; use super::Instruction; @@ -12,5 +12,5 @@ pub mod constants; mod test; pub trait RIVInstruction: Instruction { - const OPCODE_TYPE: OpcodeType; + const OPCODE_TYPE: RvInstruction; } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 5525fdfec..1afd3aa5a 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -6,10 +6,7 @@ use itertools::Itertools; use super::{ config::ExprLtConfig, - constants::{ - OPType, OpcodeType, RegUInt, FUNCT3_ADD_SUB, FUNCT7_ADD, FUNCT7_SUB, OPCODE_OP, - PC_STEP_SIZE, - }, + constants::{RegUInt, RvInstruction, RvOpcode, PC_STEP_SIZE}, RIVInstruction, }; use crate::{ @@ -49,11 +46,11 @@ pub struct InstructionConfig { } impl RIVInstruction for AddInstruction { - const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0000000); + const OPCODE_TYPE: RvInstruction = RvInstruction::ADD; } impl RIVInstruction for SubInstruction { - const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0100000); + const OPCODE_TYPE: RvInstruction = RvInstruction::SUB; } fn add_sub_gadget( @@ -97,15 +94,20 @@ fn add_sub_gadget( let rs2_id = circuit_builder.create_witin(|| "rs2_id")?; let rd_id = circuit_builder.create_witin(|| "rd_id")?; + let opcode: RvOpcode = if IS_ADD { + AddInstruction::::OPCODE_TYPE.into() + } else { + SubInstruction::::OPCODE_TYPE.into() + }; // Fetch the instruction. circuit_builder.lk_fetch(&InsnRecord::new( pc.expr(), - OPCODE_OP.into(), + (opcode.opcode as usize).into(), rd_id.expr(), - FUNCT3_ADD_SUB.into(), + (opcode.funct3.unwrap() as usize).into(), rs1_id.expr(), rs2_id.expr(), - (if IS_ADD { FUNCT7_ADD } else { FUNCT7_SUB }).into(), + (opcode.funct7.unwrap() as usize).into(), ))?; let prev_rs1_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; @@ -239,7 +241,7 @@ fn add_sub_assignment( impl Instruction for AddInstruction { // const NAME: &'static str = "ADD"; fn name() -> String { - "ADD".into() + Self::OPCODE_TYPE.to_string() } type InstructionConfig = InstructionConfig; fn construct_circuit( @@ -260,9 +262,8 @@ impl Instruction for AddInstruction { } impl Instruction for SubInstruction { - // const NAME: &'static str = "ADD"; fn name() -> String { - "SUB".into() + Self::OPCODE_TYPE.to_string() } type InstructionConfig = InstructionConfig; fn construct_circuit( diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index 30f4cc333..cdc32c9ef 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -20,7 +20,7 @@ use crate::{ use super::{ config::ExprLtConfig, - constants::{OPType, OpcodeType, RegUInt, RegUInt8, PC_STEP_SIZE}, + constants::{RegUInt, RegUInt8, RvInstruction, PC_STEP_SIZE}, RIVInstruction, }; @@ -141,7 +141,7 @@ impl BltInput { } impl RIVInstruction for BltInstruction { - const OPCODE_TYPE: OpcodeType = OpcodeType::BType(OPType::Branch, 0x004); + const OPCODE_TYPE: RvInstruction = RvInstruction::BLT; } /// if (rs1 < rs2) PC += sext(imm) @@ -210,9 +210,8 @@ fn blt_gadget( } impl Instruction for BltInstruction { - // const NAME: &'static str = "BLT"; fn name() -> String { - "BLT".into() + Self::OPCODE_TYPE.to_string() } type InstructionConfig = InstructionConfig; fn construct_circuit( diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index e8c12cdb0..4e206f2fc 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -1,30 +1,122 @@ use std::fmt; +use strum_macros::EnumIter; use crate::uint::UInt; pub use ceno_emul::PC_STEP_SIZE; -pub const OPCODE_OP: usize = 0x33; -pub const FUNCT3_ADD_SUB: usize = 0; -pub const FUNCT7_ADD: usize = 0; -pub const FUNCT7_SUB: usize = 0x20; +/// This struct is used to define the opcode format for RISC-V instructions, +/// containing three main components: the opcode, funct3, and funct7 fields. +/// These fields are crucial for specifying the +/// exact operation and variants in the RISC-V instruction set architecture. +#[derive(Default, Clone, Debug)] +pub struct RvOpcode { + pub opcode: OPType, + pub funct3: Option, + pub funct7: Option, +} -#[allow(clippy::upper_case_acronyms)] +impl From for u64 { + fn from(opcode: RvOpcode) -> Self { + let mut result: u64 = 0; + result |= (opcode.opcode as u64) & 0xFF; + result |= ((opcode.funct3.unwrap() as u64) & 0xFF) << 8; + result |= ((opcode.funct7.unwrap() as u64) & 0xFF) << 16; + result + } +} + +#[allow(dead_code, non_camel_case_types)] +/// List all RISC-V base instruction formats: +/// R-Type, I-Type, S-Type, B-Type, U-Type, J-Type and special type. #[derive(Debug, Clone, Copy)] pub enum OPType { - Op, - Opimm, - Jal, - Jalr, - Branch, + UNKNOWN = 0x00, + + R = 0x33, + I_LOAD = 0x03, + I_ARITH = 0x13, + S = 0x63, + B = 0x23, + U_LUI = 0x37, + U_AUIPC = 0x7, + J = 0x6F, + JAR = 0x67, + SYS = 0x73, } -#[derive(Debug, Clone, Copy)] -pub enum OpcodeType { - RType(OPType, usize, usize), // (OP, func3, func7) - BType(OPType, usize), // (OP, func3) +impl Default for OPType { + fn default() -> Self { + OPType::UNKNOWN + } +} + +impl From for u8 { + fn from(opcode: OPType) -> Self { + opcode as u8 + } +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone, Copy, EnumIter)] +pub enum RvInstruction { + // Type R + ADD = 0, + SUB, + + // Type M + MUL, + DIV, + DIVU, + + // Type B + BLT, +} + +impl From for RvOpcode { + fn from(ins: RvInstruction) -> Self { + // Find the instruction format here: + // https://fraserinnovations.com/risc-v/risc-v-instruction-set-explanation/ + match ins { + // Type R + RvInstruction::ADD => RvOpcode { + opcode: OPType::R, + funct3: Some(0b000 as u8), + funct7: Some(0), + }, + RvInstruction::SUB => RvOpcode { + opcode: OPType::R, + funct3: Some(0b000 as u8), + funct7: Some(0b010_0000), + }, + + // Type M + RvInstruction::MUL => RvOpcode { + opcode: OPType::R, + funct3: Some(0b000 as u8), + funct7: Some(0b0000_0001), + }, + RvInstruction::DIV => RvOpcode { + opcode: OPType::R, + funct3: Some(0b100 as u8), + funct7: Some(0b0000_0001), + }, + RvInstruction::DIVU => RvOpcode { + opcode: OPType::R, + funct3: Some(0b101 as u8), + funct7: Some(0b0000_0001), + }, + + // Type B + RvInstruction::BLT => RvOpcode { + opcode: OPType::B, + funct3: Some(0b100 as u8), + funct7: None, + }, + } + } } -impl fmt::Display for OpcodeType { +impl fmt::Display for RvInstruction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self) }