Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hash functions and CountHashTab #727

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion test/transactron/test_transactron_lib_storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta
from hypothesis import given, settings, Phase
from transactron.testing import *
from transactron.lib.storage import ContentAddressableMemory
from transactron.lib.storage import ContentAddressableMemory, CountHashTab


class TestContentAddressableMemory(TestCaseWithSimulator):
Expand Down Expand Up @@ -133,3 +133,79 @@ def test_random(self, in_push, in_write, in_read, in_remove):
sim.add_sync_process(self.read_process(in_read))
sim.add_sync_process(self.write_process(in_write))
sim.add_sync_process(self.remove_process(in_remove))


@pytest.mark.parametrize("size", [9, 16])
class TestCountHashTab(TestCaseWithSimulator):
    """Randomized tests for `CountHashTab` against a Python reference model."""

    @pytest.fixture(autouse=True)
    def setup(self):
        # Cross-process synchronization flags and a log of table states,
        # reset before every test case.
        self.inserted = False
        self.insert_end = False
        self.query_req_end = False
        self.table_snapshots = []

    def take_hash_table_snapshot(self):
        """Record the current value of every counter in the DUT's table."""
        table = []
        for i in range(self.size):
            table.append((yield self.circ._dut.counters[i]))
        self.table_snapshots.append(table)

    def insert_process(self, input):
        """Insert every element of `input`, snapshotting the table after each call."""

        def f():
            self.insert_end = False
            for in_val in input:
                yield from self.circ.insert.call(data=in_val)
                yield from self.take_hash_table_snapshot()
            # The insert method updates counters with one extra cycle of delay,
            # so wait one more cycle before taking the final snapshot.
            yield
            yield Settle()
            yield from self.take_hash_table_snapshot()
            self.insert_end = True

        return f

    def query_req_process(self, input):
        """After all inserts are done, issue a query request for every element."""

        def f():
            self.query_req_end = False
            while not self.insert_end:
                yield
            for i in input:
                yield from self.circ.query_req.call(data=i)
            self.query_req_end = True

        return f

    def query_resp_process(self, input):
        """Collect query responses and check them against the reference model."""

        def f():
            input_rev = list(reversed(input))
            while not self.query_req_end:
                count = yield from self.circ.query_resp.call()
                in_val = input_rev.pop()
                # There must be at least as many elements counted as we inserted
                # with this value (there can be more because of hash aliasing).
                assert count["count"] >= input.count(in_val)
                # The returned count must be one of the counters of the final table.
                assert count["count"] in self.table_snapshots[-1]

        return f

    @pytest.mark.parametrize("input_width", [6, 10, 32, 64])
    def test_random_one_value(self, size, input_width):
        random.seed(14)
        self.size = size
        n = random.randrange(1 << input_width)
        input = [n] * random.randrange(16)
        self.circ = SimpleTestCircuit(CountHashTab(size, 5, input_width))
        with self.run_simulation(self.circ) as sim:
            sim.add_sync_process(self.insert_process(input))
            sim.add_sync_process(self.query_req_process(input))
            sim.add_sync_process(self.query_resp_process(input))

    @pytest.mark.parametrize("input_width", [6, 10, 32, 64])
    def test_random_many_values(self, size, input_width):
        random.seed(14)
        self.size = size
        input_length = 300
        # Bug fix: draw a fresh random value for every element. The original
        # `[random.randrange(...)] * n` repeated a single value n times, so this
        # test never exercised more than one distinct input.
        input = [random.randrange(1 << input_width) for _ in range(random.randrange(input_length))]
        self.circ = SimpleTestCircuit(CountHashTab(size, 9, input_width))
        with self.run_simulation(self.circ) as sim:
            sim.add_sync_process(self.insert_process(input))
            sim.add_sync_process(self.query_req_process(input))
            sim.add_sync_process(self.query_resp_process(input))
50 changes: 50 additions & 0 deletions test/transactron/utils/test_hw_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import math
from transactron.testing import *
import random
from transactron.utils.amaranth_ext.hw_hash import *
import pytest


@pytest.mark.parametrize("model", [JenkinsHash96Bits, lambda: SipHash64Bits(2, 4), lambda: SipHash64Bits(1, 3)])
@pytest.mark.parametrize("bits", [16, 32, 64, 96])
@pytest.mark.parametrize("seed", [None, 0x573A8CF])
class TestHash(TestCaseWithSimulator):
    """Statistical uniformity test for the hardware hash functions.

    Hashes `test_number` random inputs, buckets the results modulo `modulo`,
    and runs a test of proportion on every bucket to detect gross
    non-uniformity of the hash output.
    """

    test_number = 800
    modulo = 16

    def process(self):
        hash_mod_list = [0] * self.modulo
        if self.seed is not None:
            yield self.circ.seed.eq(self.seed)
        for _ in range(self.test_number):
            in_val = random.randrange(1 << self.bits)
            yield self.circ.value.eq(in_val)
            yield Settle()
            yield Delay(1e-9)  # pretty print
            hash_mod_list[(yield self.circ.out_hash) % self.modulo] += 1

        for count in hash_mod_list:
            p_test = count / self.test_number
            p_expected = 1 / self.modulo

            # Test of proportion: z-statistic of the observed bucket frequency
            # against the uniform expectation.
            stat = (p_test - p_expected) / math.sqrt(p_expected * (1 - p_expected)) * math.sqrt(self.test_number)

            # Assuming p=0.05 for the whole test and a Bonferroni correction with
            # k=16, each bucket needs p=0.003. For N(0,1) this is circa 3 standard
            # deviations. Sadly these hash algorithms have problems for small
            # input data sizes, so the threshold is set at 4.
            assert abs(stat) < 4

    def test_random(self, seed, bits, model):
        # Seed the RNG so the statistical test is deterministic and not flaky.
        random.seed(14)
        self.seed = seed
        self.bits = bits
        self.circ = model()
        with self.run_simulation(self.circ) as sim:
            sim.add_process(self.process)
116 changes: 114 additions & 2 deletions transactron/lib/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from ..core import *
from ..utils import SrcLoc, get_src_loc, MultiPriorityEncoder
from typing import Optional
from transactron.utils import assign, AssignType, LayoutList, MethodLayout
from transactron.utils import assign, AssignType, LayoutList, MethodLayout, ValueLike
from .reqres import ArgumentsToResultsZipper
from transactron.utils.amaranth_ext.hw_hash import JenkinsHash96Bits

__all__ = ["MemoryBank", "ContentAddressableMemory", "AsyncMemoryBank"]
__all__ = ["MemoryBank", "ContentAddressableMemory", "AsyncMemoryBank", "CountHashTab"]


class MemoryBank(Elaboratable):
Expand Down Expand Up @@ -307,3 +308,114 @@ def _(arg):
m.d.comb += write_port.en.eq(arg.mask)

return m


class CountHashTab(Elaboratable):
    """Hash table that counts inputs.

    This is a hash table that can be used as a part of a CountSketch. It uses the
    JenkinsHash96Bits function to hash the input data.
    The resulting 32-bit Jenkins hash is first truncated to ceil_log2(size) bits
    (the nearest not-smaller power of two) and, if the truncated value is greater
    than or equal to the size of the hash table, the size is subtracted from it.
    This can lead to unbalanced hashes when `size` is not a power of two!


    Attributes
    ----------
    insert: Method
        The insert method. Accepts a `data` to be hashed and counted in hash table.
        The update is synchronous with an additional delay cycle.
    query_req: Method
        Receives `data` which frequency to read from the hash table.
        The response can be queried with the `query_resp` method.
    query_resp : Method
        Ready only if there is a response available to read. Returns `count` value that
        represents the number of hashed elements read from the table.
    clear : Method
        Zeros all fields in the hash table. No conflicts with other methods, but it has the highest priority.
    update_seed : Method
        Gets an argument `seed` to be used to update the seed in the hash function. Please
        note that this invalidates all stored values, so the table should be cleared.
    """

    def __init__(self, size: int, counter_width: int, input_data_width: int):
        """
        Parameters
        ----------
        size : int
            The number of fields in the hash table. It is recommended to be power of two.
        counter_width : int
            The width in bits of a field of the hash table. After the counter reaches maximum
            it will overflow back to 0.
        input_data_width : int
            The width in bits of the input data. Cannot be greater than 96.
        """
        if input_data_width > 96:
            raise ValueError(
                "CountHashTab doesn't support input data longer than 96 bits because of "
                + "implementation limits of hash functions."
            )
        self.size = size
        self.counter_width = counter_width
        self.input_data_width = input_data_width

        self.insert = Method(i=[("data", self.input_data_width)])
        self.query_req = Method(i=[("data", self.input_data_width)])
        self.query_resp = Method(o=[("count", self.counter_width)])
        self.clear = Method()
        self.update_seed = Method(i=[("seed", 64)])

        # Let a pending response be read out in the same cycle a new request
        # is accepted (see the `query_req` readiness condition in `elaborate`).
        self.query_resp.schedule_before(self.query_req)

    def postprocess_hash(self, m, hash_org: ValueLike) -> Signal:
        """Map a raw hash value onto a valid table index.

        Keeps the ceil_log2(size) least significant bits of `hash_org`; if
        `size` is not a power of two, out-of-range values are wrapped by
        subtracting `size`, which makes low indices slightly more probable.
        """
        lsb_hash = Signal(ceil_log2(self.size))
        out_hash = Signal(ceil_log2(self.size))
        m.d.top_comb += lsb_hash.eq(hash_org)
        m.d.av_comb += out_hash.eq(lsb_hash)
        # Check if self.size is not power of two
        if self.size & (self.size - 1):
            with m.If(lsb_hash >= self.size):
                m.d.av_comb += out_hash.eq(lsb_hash - self.size)
        return out_hash

    def elaborate(self, platform) -> TModule:
        m = TModule()

        insert_valid = Signal()
        resp_valid = Signal()
        insert_data = Signal(self.input_data_width)
        query_data = Signal(self.input_data_width)
        self.counters = Array([Signal(self.counter_width) for _ in range(self.size)])
        hash_insert = self.postprocess_hash(m, JenkinsHash96Bits.create(m, insert_data, "hash_insert"))
        hash_query = self.postprocess_hash(m, JenkinsHash96Bits.create(m, query_data, "hash_query"))

        # Registered insert path: the counter increment happens one cycle after
        # the `insert` call (the input is latched first), which is the extra
        # delay cycle documented on the `insert` method.
        with m.If(insert_valid):
            m.d.sync += self.counters[hash_insert].eq(self.counters[hash_insert] + 1)
            m.d.sync += insert_valid.eq(0)

        @def_method(m, self.insert)
        def _(data):
            # Latch the input; the increment is applied on the next cycle.
            m.d.sync += insert_data.eq(data)
            m.d.sync += insert_valid.eq(1)

        @def_method(m, self.query_resp, ready=resp_valid)
        def _():
            m.d.sync += resp_valid.eq(0)
            return {"count": self.counters[hash_query]}

        # Ready when no response is pending, or when the pending one is being
        # read out in the same cycle (`query_resp` is scheduled before us).
        @def_method(m, self.query_req, ready=~resp_valid | self.query_resp.run)
        def _(data):
            m.d.sync += query_data.eq(data)
            m.d.sync += resp_valid.eq(1)

        @def_method(m, self.update_seed)
        def _(seed):
            # Seed registers live inside the hash submodules added by create()
            # under the names passed above.
            m.d.sync += m.submodules.hash_insert.seed.eq(seed)
            m.d.sync += m.submodules.hash_query.seed.eq(seed)

        @def_method(m, self.clear)
        def _():
            for i in range(self.size):
                m.d.sync += self.counters[i].eq(0)

        return m
21 changes: 3 additions & 18 deletions transactron/utils/amaranth_ext/elaboratables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Iterable
from amaranth import *
from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike
from transactron.utils.assign import add_to_submodules

__all__ = [
"OneHotSwitchDynamic",
Expand Down Expand Up @@ -316,14 +317,7 @@ def create(
on the second position.
"""
prio_encoder = MultiPriorityEncoder(input_width, outputs_count)
if name is None:
m.submodules += prio_encoder
else:
try:
getattr(m.submodules, name)
raise ValueError(f"Name: {name} is already in use, so MultiPriorityEncoder can not be added with it.")
except AttributeError:
setattr(m.submodules, name, prio_encoder)
add_to_submodules(m, prio_encoder, name)
m.d.comb += prio_encoder.input.eq(input)
return list(zip(prio_encoder.outputs, prio_encoder.valids))

Expand Down Expand Up @@ -478,16 +472,7 @@ def create(
on the second position.
"""
prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
if name is None:
m.submodules += prio_encoder
else:
try:
getattr(m.submodules, name)
raise ValueError(
f"Name: {name} is already in use, so RingMultiPriorityEncoder can not be added with it."
)
except AttributeError:
setattr(m.submodules, name, prio_encoder)
add_to_submodules(m, prio_encoder, name)
m.d.comb += prio_encoder.input.eq(input)
m.d.comb += prio_encoder.first.eq(first)
m.d.comb += prio_encoder.last.eq(last)
Expand Down
Loading