Skip to content

Commit

Permalink
cbpy update cbSPKCACHE usage.
Browse files Browse the repository at this point in the history
  • Loading branch information
cboulay committed Aug 14, 2023
1 parent 21442bd commit 8a45fd3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
22 changes: 12 additions & 10 deletions cerebus/cbpy.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ cdef extern from "numpy/arrayobject.h":

cdef class SpikeCache:
cdef readonly int inst, chan, n_samples, n_pretrig
cdef cbSPKCACHE *cache
cdef cbSPKCACHE *p_cache
cdef int last_valid
# cache
# .chid: ID of the Channel
Expand All @@ -1023,33 +1023,35 @@ cdef class SpikeCache:
def __cinit__(self, int channel=1, int instance=0):
self.inst = instance
self.chan = channel
cdef cbSPKCACHE ignoreme # Just so self.cache is not NULL... but this won't be used by anything
self.cache = &ignoreme # because cbSdkGetSpkCache changes what self.cache is pointing to.
cdef cbSPKCACHE ignoreme # Just so self.p_cache is not NULL... but this won't be used by anything
self.p_cache = &ignoreme # because cbSdkGetSpkCache changes what self.p_cache is pointing to.
self.reset_cache()
sys_config_dict = get_sys_config(instance)
self.n_samples = sys_config_dict['spklength']
self.n_pretrig = sys_config_dict['spkpretrig']

def reset_cache(self):
cdef cbSdkResult res = cbSdkGetSpkCache(self.inst, self.chan, &self.cache)
cdef cbSdkResult res = cbSdkGetSpkCache(self.inst, self.chan, &self.p_cache)
handle_result(res)
self.last_valid = self.cache.valid
self.last_valid = self.p_cache.valid

# This function needs to be FAST!
@cython.boundscheck(False) # turn off bounds-checking for entire function
def get_new_waveforms(self):
cdef int new_valid = self.cache.valid
cdef int new_head = self.cache.head
# cdef everything!
cdef int new_valid = self.p_cache.valid
cdef int new_head = self.p_cache.head
cdef int n_new = min(max(new_valid - self.last_valid, 0), 400)
cdef np.ndarray[np.int16_t, ndim=2, mode="c"] np_waveforms = np.empty((n_new, self.n_samples), dtype=np.int16)
cdef np.ndarray[np.uint8_t, ndim=1] np_unit_ids = np.empty(n_new, dtype=np.uint8)
cdef np.ndarray[np.uint16_t, ndim=1] np_unit_ids = np.empty(n_new, dtype=np.uint16)
cdef int wf_ix, pkt_ix, samp_ix
for wf_ix in range(n_new):
pkt_ix = (new_head - 2 - n_new + wf_ix) % 400
np_unit_ids[wf_ix] = self.cache.spkpkt[pkt_ix].unit
np_unit_ids[wf_ix] = self.p_cache.spkpkt[pkt_ix].cbpkt_header.type
# Instead of per-sample copy, we could copy the pointer for the whole wave to the buffer of a 1-d np array,
# then use memory view copying from 1-d array into our 2d matrix. But below is pure-C so should be fast too.
for samp_ix in range(self.n_samples):
np_waveforms[wf_ix, samp_ix] = self.cache.spkpkt[pkt_ix].wave[samp_ix]
np_waveforms[wf_ix, samp_ix] = self.p_cache.spkpkt[pkt_ix].wave[samp_ix]
#unit_ids_out = [<int>unit_ids[wf_ix] for wf_ix in range(n_new)]
PyArray_ENABLEFLAGS(np_waveforms, np.NPY_OWNDATA)
self.last_valid = new_valid
Expand Down
4 changes: 2 additions & 2 deletions cerebus/cbsdk_cython.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ cdef extern from "cbproto.h":
ctypedef struct cbPKT_HEADER:
uint64_t time # system clock timestamp
uint16_t chid # channel identifier
uint8_t type # packet type
uint16_t type # packet type
uint16_t dlen # length of data field in 32-bit chunks
uint8_t instrument # instrument number to transmit this packets
uint8_t reserved[2] # reserved for future
uint8_t reserved # reserved for future

ctypedef struct cbPKT_CHANINFO:
cbPKT_HEADER cbpkt_header
Expand Down
11 changes: 10 additions & 1 deletion samples/Python/fetch_data_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
)

try:
spk_cache = {}
while True:
result, data = cbpy.trial_event(reset=True)
if len(data) > 0:
print(data)
for ev in data:
chid = ev[0]
ev_dict = ev[1]
timestamps = ev_dict["timestamps"]
print(f"Ch {chid} unit 0 has {len(timestamps[0])} events.")
if chid not in spk_cache:
spk_cache[chid] = cbpy.SpikeCache(channel=chid)
temp_wfs, unit_ids = spk_cache[chid].get_new_waveforms()
print(f"Waveform shape: {temp_wfs.shape} on unit_ids {unit_ids}")
except KeyboardInterrupt:
cbpy.close()

0 comments on commit 8a45fd3

Please sign in to comment.