Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Aug 16, 2022
2 parents 6f93ce8 + e73107c commit 5aa00af
Show file tree
Hide file tree
Showing 8 changed files with 937 additions and 278 deletions.
322 changes: 174 additions & 148 deletions hs_detection/detect/Detection.cpp

Large diffs are not rendered by default.

250 changes: 130 additions & 120 deletions hs_detection/detect/Detection.h
Original file line number Diff line number Diff line change
@@ -1,120 +1,130 @@
#ifndef DETECTION_H
#define DETECTION_H

#include <string>

#include "ProbeLayout.h"
#include "TraceWrapper.h"
#include "RollingArray.h"
#include "SpikeQueue.h"

namespace HSDetection
{
class Detection
{
private:
friend SpikeQueue; // allow access to the whole param set

// constants
static constexpr IntVolt initBase = 0; // initial value of baseline
static constexpr IntVolt initDev = 400; // initial value of deviation
static constexpr IntVolt tauBase = 4; // time constant for baseline update
static constexpr IntVolt devChange = 1; // changing for deviation update
static constexpr IntVolt minDev = 200; // minimum level of deviation

static constexpr IntCalc thrQuant = 256; // 8bit precision

static constexpr size_t channelAlign = 32; // align IntVolt=16bit to 64B (assume FloatRaw is wider)

static constexpr IntChannel alignChannel(IntChannel x) { return (x + (channelAlign - 1)) & (-channelAlign); }

// input data
TraceWrapper traceRaw; // input trace
IntChannel numChannels; // number of probe channels
IntFrame chunkSize; // size of each chunk, only the last chunk can be of a different (smaller) size
IntFrame chunkLeftMargin; // margin on the left of each chunk

// rescaling
bool rescale; // whether to scale the input
FloatRaw *scale; // scale for rescaling
FloatRaw *offset; // offset for rescaling
RollingArray trace; // rescaled and quantized trace to be used

// common reference
bool medianReference; // whether to use CMR (overrides CAR)
bool averageReference; // whether to use CAR
RollingArray commonRef; // common median/average reference

// running estimation
RollingArray runningBaseline; // running estimation of baseline (33 percentile)
RollingArray runningDeviation; // running estimation of deviation from baseline

// detection
IntFrame *spikeTime; // counter for time since spike peak
IntVolt *spikeAmp; // spike peak amplitude
IntCalc *spikeArea; // area under spike used for average amplitude, actually integral*fps
bool *hasAHP; // flag for AHP existence

IntFrame spikeDur; // duration of a spike since peak
IntFrame ampAvgDur; // duration to average amplitude
IntCalc threshold; // threshold to detect spikes, used as multiplier of deviation
IntCalc minAvgAmp; // threshold for average amplitude of peak, used as multiplier of deviation
IntCalc maxAHPAmp; // threshold for voltage level of AHP, used as multiplier of deviation

// queue processing
SpikeQueue *pQueue; // spike queue, must be a pointer to be new-ed later

ProbeLayout probeLayout; // geometry for probe layout

std::vector<Spike> result; // detection result, use vector to expand as needed

IntFrame jitterTol; // tolerance of jitter in electrical signal
IntFrame riseDur; // duration that a spike rises to peak

// decay filtering
bool decayFilter; // whether to use decay filtering instead of normal one
FloatRatio decayRatio; // ratio of amplitude to be considered as decayed

// localization
bool localize; // whether to turn on localization

// save shape
bool saveShape; // whether to save spike shapes to file
std::string filename; // filename for saving
IntFrame cutoutStart; // the start of spike shape cutout
IntFrame cutoutEnd; // the end of cutout

private:
void traceScaleCast(IntFrame chunkStart, IntFrame chunkLen);
void traceCast(IntFrame chunkStart, IntFrame chunkLen);
void commonMedian(IntFrame chunkStart, IntFrame chunkLen);
void commonAverage(IntFrame chunkStart, IntFrame chunkLen);
void runningEstimation(IntFrame chunkStart, IntFrame chunkLen);
void detectSpikes(IntFrame chunkStart, IntFrame chunkLen);

public:
Detection(IntChannel numChannels, IntFrame chunkSize, IntFrame chunkLeftMargin,
bool rescale, const FloatRaw *scale, const FloatRaw *offset,
bool medianReference, bool averageReference,
IntFrame spikeDur, IntFrame ampAvgDur,
FloatRatio threshold, FloatRatio minAvgAmp, FloatRatio maxAHPAmp,
const FloatGeom *channelPositions, FloatGeom neighborRadius, FloatGeom innerRadius,
IntFrame jitterTol, IntFrame riseDur,
bool decayFiltering, FloatRatio decayRatio, bool localize,
bool saveShape, std::string filename, IntFrame cutoutStart, IntFrame cutoutEnd);
~Detection();

// copy constructor deleted to protect internals
Detection(const Detection &) = delete;
// copy assignment deleted to protect internals
Detection &operator=(const Detection &) = delete;

void step(FloatRaw *traceBuffer, IntFrame chunkStart, IntFrame chunkLen);
IntResult finish();
const Spike *getResult() const;

}; // class Detection

} // namespace HSDetection

#endif
#ifndef DETECTION_H
#define DETECTION_H

#include <string>

#include "ProbeLayout.h"
#include "TraceWrapper.h"
#include "RollingArray.h"
#include "SpikeQueue.h"

namespace HSDetection
{
class Detection
{
private:
friend SpikeQueue; // allow access to the whole param set

// constants
static constexpr IntVolt initBase = 0; // initial value of baseline
static constexpr IntVolt initDev = 400; // initial value of deviation
static constexpr IntVolt tauBase = 4; // time constant for baseline update
static constexpr IntVolt devChange = 1; // changing for deviation update
static constexpr IntVolt minDev = 200; // minimum level of deviation

static constexpr IntCalc thrQuant = 256; // 8bit precision

static constexpr IntChannel channelAlign = 32; // align IntVolt=16bit to 64B (assume FloatRaw is wider)

static constexpr IntChannel alignChannel(IntChannel x) { return (x + (channelAlign - 1)) / channelAlign; }

// input data
TraceWrapper traceRaw; // input trace
IntChannel numChannels; // number of probe channels
IntChannel alignedChannels; // number of slices of aligned channels
IntFrame chunkSize; // size of each chunk, only the last chunk can be of a different (smaller) size
IntFrame chunkLeftMargin; // margin on the left of each chunk

// rescaling
bool rescale; // whether to scale the input
FloatRaw *scale; // scale for rescaling
FloatRaw *offset; // offset for rescaling
RollingArray trace; // rescaled and quantized trace to be used

// common reference
bool medianReference; // whether to use CMR (overrides CAR)
bool averageReference; // whether to use CAR
RollingArray commonRef; // common median/average reference

// running estimation
RollingArray runningBaseline; // running estimation of baseline (33 percentile)
RollingArray runningDeviation; // running estimation of deviation from baseline

// detection
IntFrame *spikeTime; // counter for time since spike peak
IntVolt *spikeAmp; // spike peak amplitude
IntCalc *spikeArea; // area under spike used for average amplitude, actually integral*fps
bool *hasAHP; // flag for AHP existence

IntFrame spikeDur; // duration of a spike since peak
IntFrame ampAvgDur; // duration to average amplitude
IntCalc threshold; // threshold to detect spikes, used as multiplier of deviation
IntCalc minAvgAmp; // threshold for average amplitude of peak, used as multiplier of deviation
IntCalc maxAHPAmp; // threshold for voltage level of AHP, used as multiplier of deviation

// queue processing
SpikeQueue *pQueue; // spike queue, must be a pointer to be new-ed later

ProbeLayout probeLayout; // geometry for probe layout

std::vector<Spike> result; // detection result, use vector to expand as needed

IntFrame jitterTol; // tolerance of jitter in electrical signal
IntFrame riseDur; // duration that a spike rises to peak

// decay filtering
bool decayFilter; // whether to use decay filtering instead of normal one
FloatRatio decayRatio; // ratio of amplitude to be considered as decayed

// localization
bool localize; // whether to turn on localization

// save shape
bool saveShape; // whether to save spike shapes to file
std::string filename; // filename for saving
IntFrame cutoutStart; // the start of spike shape cutout
IntFrame cutoutEnd; // the end of cutout

private:
inline void scaleCast(IntVolt *trace, const FloatRaw *input);
inline void noscaleCast(IntVolt *trace, const FloatRaw *input);
inline void commonMedian(IntVolt *ref, const IntVolt *trace,
IntVolt *buffer, IntChannel mid);
inline void commonAverage(IntVolt *ref, const IntVolt *trace);
void scaleAndAverage(IntFrame chunkStart, IntFrame chunkLen);
void castAndCommonref(IntFrame chunkStart, IntFrame chunkLen);
inline void estimation(IntVolt *baselines, IntVolt *deviations,
const IntVolt *trace, const IntVolt *ref,
const IntVolt *basePrev, const IntVolt *devPrev,
IntChannel alignedStart, IntChannel alignedEnd);
inline void detection(const IntVolt *trace, const IntVolt *ref,
const IntVolt *baselines, const IntVolt *deviations,
IntChannel channelStart, IntChannel channelEnd, IntFrame t);
void estimateAndDetect(IntFrame chunkStart, IntFrame chunkLen);

public:
Detection(IntChannel numChannels, IntFrame chunkSize, IntFrame chunkLeftMargin,
bool rescale, const FloatRaw *scale, const FloatRaw *offset,
bool medianReference, bool averageReference,
IntFrame spikeDur, IntFrame ampAvgDur,
FloatRatio threshold, FloatRatio minAvgAmp, FloatRatio maxAHPAmp,
const FloatGeom *channelPositions, FloatGeom neighborRadius, FloatGeom innerRadius,
IntFrame jitterTol, IntFrame riseDur,
bool decayFiltering, FloatRatio decayRatio, bool localize,
bool saveShape, std::string filename, IntFrame cutoutStart, IntFrame cutoutEnd);
~Detection();

// copy constructor deleted to protect internals
Detection(const Detection &) = delete;
// copy assignment deleted to protect internals
Detection &operator=(const Detection &) = delete;

void step(FloatRaw *traceBuffer, IntFrame chunkStart, IntFrame chunkLen);
IntResult finish();
const Spike *getResult() const;

}; // class Detection

} // namespace HSDetection

#endif
2 changes: 1 addition & 1 deletion hs_detection/detect/RollingArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace HSDetection
IntFrame frameMask; // rolling length will be 2^n and mask is 2^n-1 for bit ops
IntChannel numChannels;

static constexpr std::align_val_t memAlign = std::align_val_t(4096); // align to 4K anyway
static constexpr std::align_val_t memAlign = std::align_val_t(512); // align to 4K/8 to avoid 4K alias

static constexpr IntFrame getMask(IntFrame x) // get minimum 0...01...1 >= x
{
Expand Down
30 changes: 26 additions & 4 deletions hs_detection/detect/SpikeQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ using namespace std;
namespace HSDetection
{
SpikeQueue::SpikeQueue(Detection *pDet)
: queue(), queProcs(), spkProcs(), pRresult(&pDet->result),
procDelay(max(pDet->cutoutEnd, pDet->riseDur + pDet->spikeDur) + pDet->jitterTol + 1)
: spikes((Spike *)new char[pDet->chunkSize * pDet->numChannels * sizeof(Spike)]), spikeCnt(0),
queue(), queProcs(), spkProcs(), pRresult(&pDet->result),
procDelay(max(pDet->cutoutEnd - pDet->spikeDur, pDet->riseDur) + pDet->jitterTol + 1)
{
SpikeProcessor *pSpkProc;
QueueProcessor *pQueProc;
Expand Down Expand Up @@ -64,9 +65,11 @@ namespace HSDetection
for_each(spkProcs.begin(), spkProcs.end(),
[](SpikeProcessor *pSpkProc)
{ delete pSpkProc; });

delete[](char *) spikes;
}

void SpikeQueue::process()
void SpikeQueue::procFront()
{
for_each(queProcs.begin(), queProcs.end(),
[this](QueueProcessor *pQueProc)
Expand All @@ -76,11 +79,30 @@ namespace HSDetection
queue.erase(queue.begin());
}

void SpikeQueue::process()
{
sort(spikes, spikes + spikeCnt,
[](const Spike &lhs, const Spike &rhs)
{ return lhs.frame < rhs.frame || (lhs.frame == rhs.frame && lhs.channel < rhs.channel); });

for (IntResult i = 0; i < spikeCnt; i++)
{
while (!queue.empty() && queue.front().frame < spikes[i].frame - procDelay)
{
procFront();
}

queue.push_back(move(spikes[i]));
}

spikeCnt = 0; // reset for next chunk
}

void SpikeQueue::finalize()
{
while (!queue.empty())
{
process();
procFront();
}
}

Expand Down
15 changes: 13 additions & 2 deletions hs_detection/detect/SpikeQueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace HSDetection
class SpikeQueue
{
private:
Spike *spikes; // buffer for detected spikes
IntResult spikeCnt; // count of detected spikes

std::list<Spike> queue; // list has constant-time erase and also bi-directional iter

std::vector<QueueProcessor *> queProcs; // content created and released here
Expand All @@ -26,6 +29,9 @@ namespace HSDetection

IntFrame procDelay; // delayed frames from push to process

void procFront();
// cannot inline procFront because no definition of Processor here

public:
SpikeQueue(Detection *pDet); // passing the whole param set altogether
~SpikeQueue();
Expand All @@ -35,10 +41,15 @@ namespace HSDetection
// copy assignment deleted to protect container content
SpikeQueue &operator=(const SpikeQueue &) = delete;

bool checkDelay(IntFrame curFrame) { return !queue.empty() && queue.front().frame < curFrame - procDelay; }
void addSpike(Spike &&spike)
{
IntResult spkIdx;
#pragma omp atomic capture
spkIdx = spikeCnt++;
spikes[spkIdx] = std::move(spike);
}
void process();
void finalize();
// cannot inline process because no definition of Processor here

// wrappers of container interface

Expand Down
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def cythonize(module_list, **kwargs):


PROFILE = 0
NATIVE_OPTIM = True


def get_version() -> str:
Expand Down Expand Up @@ -67,8 +68,9 @@ def get_version() -> str:
sources = glob.glob('hs_detection/detect/**/[A-Z]*.cpp', recursive=True)
sources += [os.path.join('hs_detection/detect', fn) for fn in ext_src]

extra_compile_args = ['-std=c++17', '-O3']
link_extra_args = []
extra_compile_args = ['-std=c++17', '-O3', '-fopenmp'] + \
['-march=native', '-mtune=native'] * NATIVE_OPTIM
link_extra_args = ['-fopenmp']
# OS X support
if platform.system() == 'Darwin':
extra_compile_args += ['-mmacosx-version-min=10.14', '-F.']
Expand All @@ -79,7 +81,8 @@ def get_version() -> str:
Extension(name='hs_detection.detect.detect',
sources=sources,
include_dirs=[numpy_include],
define_macros=[('CYTHON_TRACE_NOGIL', '1' if PROFILE >= 2 else '0')],
define_macros=[
('CYTHON_TRACE_NOGIL', '1' if PROFILE >= 2 else '0')],
extra_compile_args=extra_compile_args,
extra_link_args=link_extra_args,
language='c++'),
Expand Down
Loading

0 comments on commit 5aa00af

Please sign in to comment.