Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shift storage #709

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 131 additions & 1 deletion test/transactron/test_transactron_lib_storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta
from hypothesis import given, settings, Phase
from transactron.testing import *
from transactron.lib.storage import ContentAddressableMemory
from transactron.lib.storage import ContentAddressableMemory, ShiftStorage


class TestContentAddressableMemory(TestCaseWithSimulator):
Expand Down Expand Up @@ -133,3 +133,133 @@ def test_random(self, in_push, in_write, in_read, in_remove):
sim.add_sync_process(self.read_process(in_read))
sim.add_sync_process(self.write_process(in_write))
sim.add_sync_process(self.remove_process(in_remove))


class TestShiftStorage(TestCaseWithSimulator):
depth = 8
content_width = 5
test_number = 30
nop_number = 3
content_layout = data_layout(content_width)

def setUp(self):
self.circ = SimpleTestCircuit(ShiftStorage(self.content_layout, self.depth, support_update=True))
self.memory = []

def generic_process(
self,
method,
input_lst,
behaviour_check=None,
state_change=None,
input_verification=None,
settle_count=0,
name="",
):
def f():
while input_lst:
# wait till all processes will end the previous cycle
yield from self.multi_settle(4)
elem = input_lst.pop()
if isinstance(elem, OpNOP):
yield
continue
if input_verification is not None and not input_verification(elem):
yield
continue
response = yield from method.call(**elem)
yield from self.multi_settle(settle_count)
if behaviour_check is not None:
# Here accesses to circuit are allowed
ret = behaviour_check(elem, response)
if isinstance(ret, Generator):
yield from ret
if state_change is not None:
# It is standard python function by purpose to don't allow accessing circuit
state_change(elem, response)
yield

return f

def push_back_process(self, in_push):
def verify_in(elem):
return len(self.memory) < self.depth

def modify_state(elem, response):
self.memory.append(elem["data"])

return self.generic_process(
self.circ.push_back,
in_push,
state_change=modify_state,
input_verification=verify_in,
settle_count=3,
name="push_back",
)

def read_process(self, in_read):
def check(elem, response):
addr = elem["addr"]
if addr < len(self.memory):
assert response["valid"] == 1
assert response["data"] == self.memory[addr]
else:
assert response["valid"] == 0

return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read")

def delete_process(self, in_delete):
def verify_in(elem):
return len(self.memory) > 0

def modify_state(elem, response):
addr = elem["addr"]
if addr < len(self.memory):
self.memory = self.memory[:addr] + self.memory[addr + 1 :]

return self.generic_process(
self.circ.delete,
in_delete,
state_change=modify_state,
input_verification=verify_in,
settle_count=2,
name="delete",
)

def update_process(self, in_update):
def check(elem, response):
assert response["err"] == (elem["addr"] >= len(self.memory))

def modify_state(elem, response):
if elem["addr"] < len(self.memory):
self.memory[elem["addr"]] = elem["data"]

return self.generic_process(
self.circ.update,
in_update,
behaviour_check=check,
state_change=modify_state,
settle_count=1,
name="update",
)

@settings(
max_examples=30,
phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.shrink),
derandomize=True,
deadline=timedelta(milliseconds=500),
)
@given(
generate_process_input(test_number, nop_number, [("data", content_layout)]),
generate_process_input(test_number, nop_number, [("addr", range(depth)), ("data", content_layout)]),
generate_process_input(test_number, nop_number, [("addr", range(depth))]),
generate_process_input(test_number, nop_number, [("addr", range(depth))]),
)
def test_random(self, in_push, in_update, in_read, in_delete):
with self.reinitialize_fixtures():
self.setUp()
with self.run_simulation(self.circ, max_cycles=500) as sim:
sim.add_sync_process(self.push_back_process(in_push))
sim.add_sync_process(self.read_process(in_read))
sim.add_sync_process(self.update_process(in_update))
sim.add_sync_process(self.delete_process(in_delete))
97 changes: 97 additions & 0 deletions transactron/lib/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,100 @@ def _(arg):
m.d.comb += write_port.en.eq(arg.mask)

return m


class ShiftStorage(Elaboratable):
"""Ordered storage, which shifts data

This module implements a storage, which maintain all elements in continuous, ordered array
according to the age of the elements (element inserted earlier will be earlier in the list).
The order is restored after each `delete` operation by shifting all elemets which are older
that the removed one.


Attributes
----------
read: Method
The read method. Accepts an `addr` from which data should be read.
The read response method. Return `data_layout` View which was saved on `addr` at reading
time. The `valid` signal indicates if the `addr` was inside the range.
push_back : Method
Insert the new entry on the end of storage. Ready only if there is a free place.
delate: Method
Remove the element from the `addr`. Shifts the elements to restore the proper order.
update: Method
The update method. Accepts `addr` where data should be saved, `data` in form of `data_layout`.
This method is available if `support_update` is set to `True`.
"""

def __init__(self, data_layout: MethodLayout, depth: int, support_update: bool = False):
"""
Parameters
----------
data_layout: LayoutList
The format of structures stored in the ShiftStorage.
depth: int
Number of elements stored in ShiftStorage.
support_update: bool
Indicates if `update` method should be generated. Default: `False`
"""
self.data_layout = from_method_layout(data_layout)
self.depth = depth
self.support_update = support_update

self._data = Array([Signal(self.data_layout, name=f"cell_{i}") for i in range(self.depth)])
self._last = Signal(range(self.depth + 1)) # pointer on first empty cell

self.read = Method(i=[("addr", range(self.depth))], o=[("data", self.data_layout), ("valid", 1)])
self.push_back = Method(i=[("data", self.data_layout)])
self.delete = Method(i=[("addr", range(self.depth))])
if self.support_update:
self.update = Method(i=[("addr", range(self.depth)), ("data", self.data_layout)], o=[("err", 1)])

def _generate_shift(self, m: TModule, addr: Value):
for i in range(1, self.depth):
with m.If(addr < i):
m.d.sync += self._data[i - 1].eq(self._data[i])

def elaborate(self, platform):
m = TModule()

last_incr = Signal()
last_decr = Signal()
delete_address = Signal(range(self.depth))

@def_method(m, self.read)
def _(addr):
return {"data": self._data[addr], "valid": addr < self._last}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an use case for reading invalid addresses?


@def_method(m, self.delete, ready=self._last > 0)
def _(addr):
m.d.top_comb += delete_address.eq(addr)
with m.If(addr < self._last):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this If? It makes the delete not happen if a wrong address is passed. What's the use case? I believe this behavior could create silent bugs. Maybe use an assertion instead?

m.d.comb += last_decr.eq(1)
self._generate_shift(m, addr)

if self.support_update:

@def_method(m, self.update)
def _(addr, data):
update_req_valid = Signal()
m.d.top_comb += update_req_valid.eq(addr < self._last)
affected_by_delete = Signal()
m.d.top_comb += affected_by_delete.eq(addr > delete_address)
with m.If(update_req_valid & (last_decr.implies(delete_address != addr))):
m.d.sync += self._data[addr - Mux(last_decr, affected_by_delete, 0)].eq(data)
return {"err": ~update_req_valid}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. If there is no use case, I prefer an assertion to a bit which can be easily ignored.


@def_method(m, self.push_back, self._last < self.depth)
def _(data):
m.d.sync += self._data[self._last - last_decr].eq(data)
m.d.comb += last_incr.eq(1)

with m.Switch(Cat(last_incr, last_decr)):
with m.Case(1):
m.d.sync += self._last.eq(self._last + 1)
with m.Case(2):
m.d.sync += self._last.eq(self._last - 1)

return m
20 changes: 11 additions & 9 deletions transactron/testing/input_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,18 @@ def generate_nops_in_list(draw: DrawFn, max_nops: int, generate_list: SearchStra
return out_lst


@composite
def generate_method_input(draw: DrawFn, args: list[tuple[str, MethodLayout]]) -> dict[str, RecordIntDict]:
out = []
for name, layout in args:
out.append((name, draw(generate_based_on_layout(layout))))
return dict(out)
# @composite
# def generate_method_input(draw: DrawFn, args: MethodLayout) -> dict[str, RecordIntDict]:
# out = []
# for name, layout in args:
# out.append((name, draw(generate_based_on_layout(layout))))
# return dict(out)


@composite
def generate_process_input(
draw: DrawFn, elem_count: int, max_nops: int, layouts: list[tuple[str, MethodLayout]]
) -> list[dict[str, RecordIntDict] | OpNOP]:
return draw(generate_nops_in_list(max_nops, generate_shrinkable_list(elem_count, generate_method_input(layouts))))
draw: DrawFn, elem_count: int, max_nops: int, layouts: MethodLayout
) -> list[RecordIntDict | OpNOP]:
return draw(
generate_nops_in_list(max_nops, generate_shrinkable_list(elem_count, generate_based_on_layout(layouts)))
)
4 changes: 2 additions & 2 deletions transactron/utils/data_repr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable, Mapping
from ._typing import ShapeLike, MethodLayout
from ._typing import ShapeLike, LayoutList
from typing import Any, Sized
from statistics import fmean
from amaranth.lib.data import StructLayout
Expand Down Expand Up @@ -78,7 +78,7 @@ def bits_from_int(num: int, lower: int, length: int):
return (num >> lower) & ((1 << (length)) - 1)


def data_layout(val: ShapeLike) -> MethodLayout:
def data_layout(val: ShapeLike) -> LayoutList:
return [("data", val)]


Expand Down