Skip to content

Commit

Permalink
Update PE skeleton
Browse files Browse the repository at this point in the history
  • Loading branch information
minseongg committed Oct 22, 2024
1 parent 5c0dd77 commit e2dac16
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 84 deletions.
101 changes: 46 additions & 55 deletions hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,62 @@
//! Processing element.
//!
//! FIXME:
//! Currently, this implementation is assuming the base configuration(i.e., inputType = SInt(8.W), accType = SInt(32.W), spatialArrayOutputType = SInt(20.W))

#![allow(unused)] // Added for assignment.

use super::*;

/// PE Row Data
/// PE row data signals.
#[derive(Debug, Clone, Copy)]
pub struct PeRowData {
/// a
/// A.
pub a: U<INPUT_BITS>,
}

/// PE Column Data
/// PE column data signals.
#[derive(Debug, Clone, Copy)]
pub struct PeColData {
/// b
/// B.
pub b: U<OUTPUT_BITS>,
/// d

/// D.
pub d: U<OUTPUT_BITS>,
}

/// Which register to use to preload the value
#[derive(Debug, Default, Clone, Copy)]
/// PE column control signals.
///
/// NOTE: The column data and control signals should be separated to handle the `flush` operation.
/// <https://github.com/ucb-bar/gemmini/blob/be2e9f26181658895ebc7ca7f7d6be6210f5cdef/src/main/scala/gemmini/ExecuteController.scala#L189-L207>
#[derive(Debug, Clone, Copy)]
pub struct PeColControl {
/// ID.
pub id: U<ID_BITS>,

/// Is this last row?
pub last: bool,

/// PE control signals.
pub control: PeControl,
}

/// Represents which register to use to preload the value.
#[derive(Debug, Default, Clone, Copy, HEq)]
pub enum Propagate {
/// use Reg1 for preload and Reg2 for computation
/// Use `Reg1` for preloading and `Reg2` for computation.
#[default]
Reg1,
/// use Reg2 for preload and Reg1 for computation

/// Use `Reg2` for preloading and `Reg1` for computation.
Reg2,
}

/// Is Dataflow Output-Stationary(OS) or Weight-Stationary(WS)?
#[derive(Debug, Clone, Copy)]
/// Represents the dataflow.
#[derive(Debug, Default, Clone, Copy, HEq)]
pub enum Dataflow {
/// Output Stationary
/// Output Stationary.
#[default]
OS,
/// Weight Stationary
WS,
}

impl Default for Dataflow {
fn default() -> Self {
Self::OS
}
/// Weight Stationary.
WS,
}

impl From<U<1>> for Dataflow {
Expand All @@ -63,43 +74,29 @@ impl From<bool> for Dataflow {
}
}

/// PE Control
/// PE control data.
#[derive(Debug, Clone, Copy)]
pub struct PeControl {
/// DataFlow
/// Dataflow.
pub dataflow: Dataflow,

/// Propagate
/// Propagate.
pub propagate: Propagate,

/// Shift
/// Shift.
pub shift: U<5>,
}

/// PE column control.
///
/// NOTE: column data and control should be separated because of the `flush` operation.
/// <https://github.com/ucb-bar/gemmini/blob/be2e9f26181658895ebc7ca7f7d6be6210f5cdef/src/main/scala/gemmini/ExecuteController.scala#L189-L207>
#[derive(Debug, Clone, Copy)]
pub struct PeColControl {
/// id
pub id: U<ID_BITS>,
/// is this last row?
pub last: bool,
/// pe control
pub control: PeControl,
/// bad_dataflow
pub bad_dataflow: bool,
}

/// PE state.
#[derive(Debug, Default, Clone, Copy)]
pub struct PeS {
/// Register 1
/// Register 1.
pub reg1: U<32>,
/// Register 2

/// Register 2.
pub reg2: U<32>,
/// Same as `last_s` in the Chisel implementation.

/// Propagate.
pub propagate: Propagate,
}

Expand All @@ -112,14 +109,7 @@ impl PeS {

/// MAC unit (computes `a * b + c`).
fn mac(a: U<8>, b: U<8>, c: U<32>) -> U<OUTPUT_BITS> {
todo!("Assignment 4")
}

/// Returns whether there was a change in the propagate option.
///
/// NOTE: This is equivalent to `prev != curr`, but hazardflow compiler does not support it (ICE).
fn propagate_flipped(prev: Propagate, curr: Propagate) -> bool {
matches!(prev, Propagate::Reg1) ^ matches!(curr, Propagate::Reg1)
todo!("assignment 4")
}

/// Same as `(val >> shamt).clippedToWidthOf(20)`.
Expand All @@ -134,11 +124,12 @@ pub fn pe(
_in_left: Valid<PeRowData>,
(_in_top_data, _in_top_control): (Valid<PeColData>, Valid<PeColControl>),
) -> (Valid<PeRowData>, (Valid<PeColData>, Valid<PeColControl>)) {
todo!("Assignment 4")
todo!("assignment 4")
}

/// Chisel PE Wrapper.
/// This module allows students to proceed with future assignments even if they have not completed Assignment4.
///
/// This module allows students to proceed with future assignments even if they have not completed assignment 4.
#[magic(ffi::PE256Wrapper())]
pub fn pe_256_chisel(
_in_left: Valid<PeRowData>,
Expand Down
29 changes: 0 additions & 29 deletions scripts/gemmini/unit_tests/pe/test_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def rounding_shift(value, shift):
"payload_Some_0_control_dataflow_discriminant",
"payload_Some_0_control_propagate_discriminant",
"payload_Some_0_control_shift",
"payload_Some_0_bad_dataflow",
],
valid_signal="payload_discriminant",
)
Expand Down Expand Up @@ -190,7 +189,6 @@ async def ws_simple(dut):
payload_Some_0_control_dataflow_discriminant=WS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
await tb.pe_col_data_req.send(
Expand All @@ -212,7 +210,6 @@ async def ws_simple(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == WS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Input data and check output data
for i in range(16):
Expand All @@ -229,7 +226,6 @@ async def ws_simple(dut):
payload_Some_0_control_dataflow_discriminant=WS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -245,7 +241,6 @@ async def ws_simple(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == WS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False


@cocotb.test(timeout_time=10, timeout_unit="ms")
Expand Down Expand Up @@ -290,7 +285,6 @@ async def ws_random(dut):
payload_Some_0_control_dataflow_discriminant=WS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)

await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
Expand All @@ -313,7 +307,6 @@ async def ws_random(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == WS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Input data and check output data
for i in range(16):
Expand All @@ -330,7 +323,6 @@ async def ws_random(dut):
payload_Some_0_control_dataflow_discriminant=WS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -346,7 +338,6 @@ async def ws_random(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == WS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False


@cocotb.test(timeout_time=10, timeout_unit="ms")
Expand Down Expand Up @@ -398,7 +389,6 @@ async def os_simple(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -414,7 +404,6 @@ async def os_simple(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Check output data
await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
Expand All @@ -428,7 +417,6 @@ async def os_simple(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -444,7 +432,6 @@ async def os_simple(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False


@cocotb.test(timeout_time=10, timeout_unit="ms")
Expand Down Expand Up @@ -496,7 +483,6 @@ async def os_random_shift(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -512,7 +498,6 @@ async def os_random_shift(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Check output data
await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
Expand All @@ -526,7 +511,6 @@ async def os_random_shift(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -542,7 +526,6 @@ async def os_random_shift(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False


@cocotb.test(timeout_time=10, timeout_unit="ms")
Expand Down Expand Up @@ -594,7 +577,6 @@ async def os_random_inp_and_shift(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -610,7 +592,6 @@ async def os_random_inp_and_shift(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Check output data
await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
Expand All @@ -624,7 +605,6 @@ async def os_random_inp_and_shift(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -640,7 +620,6 @@ async def os_random_inp_and_shift(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False


@cocotb.test(timeout_time=10, timeout_unit="ms")
Expand Down Expand Up @@ -692,7 +671,6 @@ async def os_random_inp(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -708,7 +686,6 @@ async def os_random_inp(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Check output data
await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
Expand All @@ -722,7 +699,6 @@ async def os_random_inp(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -738,7 +714,6 @@ async def os_random_inp(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False


@cocotb.test(timeout_time=10, timeout_unit="ms")
Expand Down Expand Up @@ -790,7 +765,6 @@ async def os_random_inp_and_shift(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG1,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -806,7 +780,6 @@ async def os_random_inp_and_shift(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG1
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

# Check output data
await tb.pe_row_data_req.send(PeDataRowTransaction(payload_Some_0_a=0))
Expand All @@ -820,7 +793,6 @@ async def os_random_inp_and_shift(dut):
payload_Some_0_control_dataflow_discriminant=OS,
payload_Some_0_control_propagate_discriminant=REG2,
payload_Some_0_control_shift=rnd_shift,
payload_Some_0_bad_dataflow=False,
)
)

Expand All @@ -836,4 +808,3 @@ async def os_random_inp_and_shift(dut):
assert col_ctrl_resp.payload_Some_0_control_dataflow_discriminant == OS
assert col_ctrl_resp.payload_Some_0_control_propagate_discriminant == REG2
assert col_ctrl_resp.payload_Some_0_control_shift == rnd_shift
assert col_ctrl_resp.payload_Some_0_bad_dataflow == False

0 comments on commit e2dac16

Please sign in to comment.