diff --git a/transactron/lib/storage.py b/transactron/lib/storage.py index e402df7ec..0b25dd6c2 100644 --- a/transactron/lib/storage.py +++ b/transactron/lib/storage.py @@ -1,6 +1,7 @@ from amaranth import * from amaranth.utils import * import amaranth.lib.memory as memory +import amaranth_types.memory as amemory from transactron.utils.transactron_helpers import from_method_layout, make_layout from ..core import * @@ -49,6 +50,7 @@ def __init__( transparent: bool = False, read_ports: int = 1, write_ports: int = 1, + memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory, src_loc: int | SrcLoc = 0, ): """ @@ -82,6 +84,7 @@ def __init__( self.transparent = transparent self.reads_ports = read_ports self.writes_ports = write_ports + self.memory_type = memory_type self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] write_layout = [("addr", self.addr_width), ("data", self.data_layout)] @@ -102,7 +105,7 @@ def __init__( def elaborate(self, platform) -> TModule: m = TModule() - m.submodules.mem = mem = memory.Memory(shape=self.width, depth=self.elem_count, init=[]) + m.submodules.mem = mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[]) write_port = [mem.write_port() for _ in range(self.writes_ports)] read_port = [ mem.read_port(transparent_for=write_port if self.transparent else []) for _ in range(self.reads_ports) @@ -275,6 +278,7 @@ def __init__( granularity: Optional[int] = None, read_ports: int = 1, write_ports: int = 1, + memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory, src_loc: int | SrcLoc = 0, ): """ @@ -303,6 +307,7 @@ def __init__( self.addr_width = bits_for(self.elem_count - 1) self.reads_ports = read_ports self.writes_ports = write_ports + self.memory_type = memory_type self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] write_layout = [("addr", self.addr_width), ("data", self.data_layout)] @@ -323,7 +328,7 @@ def __init__( def elaborate(self, platform) -> TModule: m = TModule() - mem = memory.Memory(shape=self.width, depth=self.elem_count, init=[]) + mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[]) m.submodules.mem = mem write_port = [mem.write_port() for _ in range(self.writes_ports)] read_port = [mem.read_port(domain="comb") for _ in range(self.reads_ports)]