Skip to content

Commit

Permalink
test head score
Browse files Browse the repository at this point in the history
  • Loading branch information
lepmik committed Mar 28, 2019
1 parent 5302690 commit 3109909
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
3 changes: 0 additions & 3 deletions head_direction/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def head_direction_rate(spike_train, head_angles, t,

with np.errstate(divide='ignore', invalid='ignore'):
rate_in_ang = np.divide(spikes_in_ang, time_in_ang)

rate_in_ang = moving_average(rate_in_ang, avg_window)
return ang_bins[:-1], rate_in_ang

Expand All @@ -62,8 +61,6 @@ def head_direction_score(head_angle_bins, rate):
"""
import math
import pycircstat as pc
# if any(np.isnan(rate)):
# raise ValueError('Nan not supported')
nanIndices = np.where(np.isnan(rate))
head_angle_bins = np.delete(head_angle_bins, nanIndices)
mean_ang = pc.mean(head_angle_bins, w=rate)
Expand Down
18 changes: 17 additions & 1 deletion head_direction/tests/test_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,20 @@ def test_head_rate():
sptr = np.linspace(0,1,100)
bins, rate = head_direction_rate(sptr, a, t, n_bins=8, avg_window=1)
assert bins[1] == np.pi / 4
assert rate[1] == 100
assert abs(rate[1] - 100) < .5


def test_head_score():
from head_direction.head import (
head_direction, head_direction_rate, head_direction_score)
x1 = np.linspace(.01,1,10)
y1 = x1
x2 = x1 + .01 # 1cm between
y2 = x1 - .01
t = np.linspace(0,1,10)
a, t = head_direction(x1, y1, x2, y2, t)
sptr = np.linspace(0,1,100)
bins, rate = head_direction_rate(sptr, a, t, n_bins=100, avg_window=2)
ang, score = head_direction_score(bins, rate)
assert abs(score - 1) < 0.001
assert abs(ang - np.pi / 4) < 0.00001
1 change: 1 addition & 0 deletions head_direction/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def moving_average(vector, N):
>>> moving_average(a, 5)
array([4., 3., 2., 3., 4., 5., 6., 7., 6., 5.])
"""
vector[np.isnan(vector)] = 0
if N * 2 > len(vector):
raise ValueError('Window must be at least half of "len(vector)"')
vector = np.concatenate((vector[-N:], vector, vector[:N]))
Expand Down

0 comments on commit 3109909

Please sign in to comment.