forked from hp240920/Traffic-Light-Detection-System
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lif_model.py
92 lines (75 loc) · 2.56 KB
/
lif_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import matplotlib.pyplot as plt
class lif:
    """Parameters and membrane-potential buffer of a leaky integrate-and-fire neuron.

    The simulation functions in this module read these attributes and write the
    simulated potential trace into ``voltage``.

    Args:
        I: Constant input (synaptic) current driving the neuron.
        **kwargs: Optional overrides for any of the default model parameters
            named in ``DEFAULTS`` (e.g. ``threshold=0.2`` or ``length=50``).

    Raises:
        TypeError: If a keyword argument does not name a known model parameter.
    """

    # Default model parameters (time quantities are in ms).
    DEFAULTS = {
        'const_cap': 10,   # membrane capacitance C
        'resistance': 1,   # membrane resistance R
        'threshold': 0.1,  # potential at which a spike is emitted
        'v_spike': 0.5,    # amount added to the sample that crosses threshold
        'ref_time': 4,     # refractory time (ms) after a spike
        'reset_v': 0,      # potential the membrane is reset to after a spike
        'length': 100,     # simulated duration (ms)
        'mini_time': 1,    # integration time step (ms)
    }

    def __init__(self, I=0, **kwargs):
        unknown = sorted(set(kwargs) - set(self.DEFAULTS))
        if unknown:
            raise TypeError(
                'lif() got unexpected keyword argument(s): ' + ', '.join(unknown))
        params = dict(self.DEFAULTS)
        params.update(kwargs)
        for name, value in params.items():
            setattr(self, name, value)
        self.I = I
        # Product R*C (the membrane time constant); not read by the
        # simulation functions in this module.
        self.time_cost = self.const_cap * self.resistance
        # One sample per entry of the time axis the simulations build, so the
        # trace stays aligned with it when length/mini_time are overridden.
        self.voltage = np.zeros(
            len(np.arange(0, self.length, self.mini_time)), dtype=float)
def plot_potential_decay(obj):
    """Simulate ``obj`` and plot its membrane potential against time.

    The simulation is delegated to ``count_spikes``, which writes the potential
    trace into ``obj.voltage``. Reusing it keeps the plotted trace identical to
    the one the spike counter produces. Previously this function kept a second
    copy of the same update loop. The returned spike count is not needed here.

    Args:
        obj: Neuron carrying the attributes set up by ``lif.__init__``.
    """
    count_spikes(obj)
    arr_time = np.arange(0, obj.length, obj.mini_time)
    plt.plot(arr_time, obj.voltage)
    plt.xlabel('time (ms)')
    plt.ylabel('Output (mV)')
    plt.show()
def count_spikes(obj):
    """Run the LIF simulation for ``obj`` and return how many spikes it fires.

    The potential is integrated with forward Euler over the time axis
    ``np.arange(0, obj.length, obj.mini_time)``:
    ``V[k+1] = V[k] + dt * (I - V[k] / R) / C`` with ``dt = obj.mini_time``.
    The trace is written into ``obj.voltage``, and ``obj.voltage[0]`` is the
    initial potential.

    When a new sample reaches ``obj.threshold``:
    - ``obj.v_spike`` is added to that sample;
    - the following sample is set to ``obj.reset_v``;
    - the potential is held at ``obj.reset_v`` until ``obj.ref_time`` has
      elapsed since the spiking step.

    Args:
        obj: Neuron carrying the attributes set up by ``lif.__init__``.

    Returns:
        int: Number of threshold crossings during the simulation.
    """
    arr_time = np.arange(0, obj.length, obj.mini_time)
    n_steps = len(arr_time)
    # The trace needs exactly one sample per time step. Reallocate when the
    # buffer does not match the time axis, keeping the initial potential.
    if len(obj.voltage) != n_steps:
        trace = np.zeros(n_steps, dtype=float)
        if n_steps and len(obj.voltage):
            trace[0] = obj.voltage[0]
        obj.voltage = trace

    spike_count = 0
    just_spiked = False
    refractory_until = 0
    for i in range(n_steps - 1):
        if just_spiked:
            # Sample right after a spike: reset the membrane.
            obj.voltage[i + 1] = obj.reset_v
            just_spiked = False
            continue
        if arr_time[i] < refractory_until:
            # Refractory: hold at the reset potential. Leaving the sample
            # unwritten would integrate later steps from stale buffer contents
            # instead of reset_v.
            obj.voltage[i + 1] = obj.reset_v
            continue
        obj.voltage[i + 1] = obj.voltage[i] + obj.mini_time * (
            (obj.I - (obj.voltage[i] / obj.resistance)) / obj.const_cap)
        if obj.voltage[i + 1] >= obj.threshold:
            obj.voltage[i + 1] += obj.v_spike
            refractory_until = arr_time[i] + obj.ref_time
            spike_count += 1
            just_spiked = True
    return spike_count
def plot_spiking_behavior(start=1, stop=5, step=0.05):
    """Plot the spike count produced by a sweep of constant input currents.

    A fresh ``lif`` neuron is simulated for every current in
    ``np.arange(start, stop, step)``. The number of spikes reported by
    ``count_spikes`` is then plotted against the current. The defaults
    reproduce the original fixed sweep.

    Args:
        start: First input current of the sweep (inclusive).
        stop: End of the sweep (exclusive, as in ``np.arange``).
        step: Spacing between consecutive currents.
    """
    arr_current = np.arange(start, stop, step)
    num_spikes = [count_spikes(lif(current)) for current in arr_current]
    plt.plot(arr_current, num_spikes)
    plt.xlabel('Synaptic current (I)')
    plt.ylabel('Number of spikes')
    plt.show()
if __name__ == '__main__':
    # Demo: one neuron driven by a constant current of 0.9 -- print its spike
    # count, then plot its membrane potential over the simulated window.
    demo_neuron = lif(0.9)
    n_spikes = count_spikes(demo_neuron)
    print(n_spikes)
    plot_potential_decay(demo_neuron)
    # To plot spike count versus input current instead, call:
    # plot_spiking_behavior()