forked from EischLab/deeplabcut-social-interaction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
interaction_zone.py
254 lines (225 loc) · 14.1 KB
/
interaction_zone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import glob
import math
import os
from tkinter import filedialog
import pandas as pd
from ast import literal_eval
def interaction_zone_area(arena_top_corner, arena_bottom_corner, side, pixel_per_cm, interaction_dist,
                          interaction_width, interaction_length):
    """
    A function that returns the coordinates for the interaction zone area.

    The zone's y-extent is the span between the two arena corners, inset by
    ``interaction_dist`` at each end. Its x-extent is ``interaction_width`` wide.
    For the left arena the corners form the zone's left edge and the zone grows
    toward +x. For the right arena they form its right edge and it grows toward -x.
    Computed coordinates are truncated with int().

    :param arena_top_corner: The arena top corner coordinate as (x, y) in pixels
    :param arena_bottom_corner: The arena bottom corner coordinate as (x, y) in pixels
    :param side: 'left' or 'right' arena
    :param pixel_per_cm: The number of pixels per centimetre
    :param interaction_dist: The inset from each corner along y, in cm
    :param interaction_width: The width of the interaction zone area, in cm
    :param interaction_length: The length of the interaction zone area; currently unused,
        because the y-extent comes from the corners and ``interaction_dist``
    :return: A tuple (top_left, top_right, bottom_left, bottom_right) of (x, y) coordinates
    :raises ValueError: If ``side`` is neither 'left' nor 'right'
    """
    if side not in ('left', 'right'):
        # Previously fell through and returned None, which failed later as an unpacking TypeError.
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    inset = interaction_dist * pixel_per_cm
    width = interaction_width * pixel_per_cm
    # The arena-side edge of the zone: the corners moved inward along y by the inset.
    top_edge = (arena_top_corner[0], int(arena_top_corner[1] + inset))
    bottom_edge = (arena_bottom_corner[0], int(arena_bottom_corner[1] - inset))
    if side == 'left':
        tl_corner, bl_corner = top_edge, bottom_edge
        tr_corner = (int(tl_corner[0] + width), tl_corner[1])
        br_corner = (int(bl_corner[0] + width), bl_corner[1])
    else:
        tr_corner, br_corner = top_edge, bottom_edge
        tl_corner = (int(tr_corner[0] - width), tr_corner[1])
        bl_corner = (int(br_corner[0] - width), br_corner[1])
    return tl_corner, tr_corner, bl_corner, br_corner
def check_zone(mouse_coord, interaction_zone_corners):
    """
    Report whether the mouse's nose lies inside the interaction zone.

    Only the first corner (top-left) and the fourth corner (bottom-right) are used.
    The zone is treated as the axis-aligned rectangle between them, and points on
    the border count as inside. NaN coordinates always fall outside.

    :param mouse_coord: The (x, y) coordinates of the mouse's nose
    :param interaction_zone_corners: Zone corners ordered (top_left, top_right, bottom_left, bottom_right)
    :return: True if the point is inside the zone, otherwise False
    """
    top_left, bottom_right = interaction_zone_corners[0], interaction_zone_corners[3]
    x, y = mouse_coord[0], mouse_coord[1]
    # Guard clause: outside the horizontal span, so the vertical span need not be checked.
    if not top_left[0] <= x <= bottom_right[0]:
        return False
    return bool(top_left[1] <= y <= bottom_right[1])
def update_counters(first_col, second_col, total_frames, interaction_frames,
                    interaction_entries, zone, current_frame):
    """
    A function that advances all interaction counters by one tracking frame.

    :param first_col: The value holding the x position of the mouse's nose (converted with float())
    :param second_col: The value holding the y position of the mouse's nose (converted with float())
    :param total_frames: The number of frames so far with a valid (non-NaN) nose position
    :param interaction_frames: The number of frames so far with the nose inside the zone
    :param interaction_entries: The number of entries into the zone so far
    :param zone: The list of zone coordinates, as accepted by check_zone
    :param current_frame: The length of the current run of consecutive in-zone frames
    :return: The updated (interaction_frames, interaction_entries, total_frames, current_frame)
    """
    mouse_coord = (float(first_col), float(second_col))
    if math.isnan(mouse_coord[0]) or math.isnan(mouse_coord[1]):
        # A missing detection is skipped entirely. It is not counted toward total_frames
        # and does not end an ongoing in-zone run. Without this early return, the run-tracking
        # logic below would read an unassigned variable and raise UnboundLocalError.
        return interaction_frames, interaction_entries, total_frames, current_frame
    total_frames += 1
    if check_zone(mouse_coord, zone):
        interaction_frames += 1
        current_frame += 1
        # The first frame of a new in-zone run marks one entry.
        if current_frame == 1:
            interaction_entries += 1
    else:
        current_frame = 0
    return interaction_frames, interaction_entries, total_frames, current_frame
def _tally_trial(df_csv, left_zone, right_zone):
    """
    Run update_counters over one trial's rows for every (tracked animal, zone) pairing.

    :param df_csv: The trial DataFrame as read by pandas.read_csv
    :param left_zone: The left interaction zone corners
    :param right_zone: The right interaction zone corners
    :return: A dict mapping 'left', 'right', 'left_missed' and 'right_missed' to the final
        (interaction_frames, interaction_entries, total_frames, current_frame) tuple
    """
    # (x position, y position, zone). Positions index an itertuples() row, where position 0
    # is the DataFrame index. The '_missed' pairings score each zone with the other animal's columns.
    pairings = {
        'left': (14, 15, left_zone),
        'right': (2, 3, right_zone),
        'left_missed': (2, 3, left_zone),
        'right_missed': (14, 15, right_zone),
    }
    tallies = {name: (0, 0, 0, 0) for name in pairings}
    # NOTE(review): the first three rows are presumably extra tracker header rows
    # parsed as data -- TODO confirm against the input CSV layout.
    for row in df_csv[3:].itertuples():
        for name, (x_pos, y_pos, zone) in pairings.items():
            frames, entries, total, current = tallies[name]
            tallies[name] = update_counters(row[x_pos], row[y_pos], total, frames, entries, zone, current)
    return tallies


def interaction_zone(enclosure_length_cm, enclosure_length_pix, interact_dist, interact_width, interact_length,
                     left_arena_top, left_arena_bot, right_arena_top, right_arena_bot, fps=25):
    """
    A function that produces a CSV with information about the time in the interaction zone and entries in the zone.

    The user is prompted for a directory of CSV files, one trial per file, processed in sorted
    filename order. The user is then prompted for an output path. Each trial yields the rows
    ``trial_<n>_mouse_1`` (left zone) and ``trial_<n>_mouse_2`` (right zone). A mouse with zero
    time and zero entries has its row replaced by the result computed from the other animal's
    tracking columns. This is presumably meant to recover trials where the tracker swapped the
    two identities. Nothing is read or written if either dialog is cancelled.

    :param enclosure_length_cm: Widget (with .get()) holding the enclosure length in cm
    :param enclosure_length_pix: Widget holding the enclosure length in pixels
    :param interact_dist: Widget holding the distance from the arena corner to the interaction zone, in cm
    :param interact_width: Widget holding the interaction zone width, in cm
    :param interact_length: Widget holding the interaction zone length, in cm
    :param left_arena_top: Widget holding the left arena top corner as "(x, y)"
    :param left_arena_bot: Widget holding the left arena bottom corner as "(x, y)"
    :param right_arena_top: Widget holding the right arena top corner as "(x, y)"
    :param right_arena_bot: Widget holding the right arena bottom corner as "(x, y)"
    :param fps: Video frame rate used to convert frame counts to seconds (default 25)
    """
    # float() also accepts fractional centimetre values, which int() rejected.
    pix_per_cm = float(enclosure_length_pix.get()) / float(enclosure_length_cm.get())
    dist_cm = float(interact_dist.get())
    width_cm = float(interact_width.get())
    length_cm = float(interact_length.get())
    # literal_eval only accepts Python literals, so "(x, y)" text is parsed safely.
    l_arena_top, l_arena_bot = literal_eval(left_arena_top.get()), literal_eval(left_arena_bot.get())
    r_arena_top, r_arena_bot = literal_eval(right_arena_top.get()), literal_eval(right_arena_bot.get())
    # create interaction zone area
    left_interaction_zone = list(interaction_zone_area(l_arena_top, l_arena_bot, 'left', pix_per_cm,
                                                       dist_cm, width_cm, length_cm))
    right_interaction_zone = list(interaction_zone_area(r_arena_top, r_arena_bot, 'right', pix_per_cm,
                                                        dist_cm, width_cm, length_cm))

    file_path = filedialog.askdirectory()
    if not file_path:
        # Dialog cancelled. Without this guard, the glob below would scan the current working directory.
        return
    # Sorted so trial numbering is reproducible; glob order is arbitrary.
    files = sorted(glob.glob(os.path.join(file_path, '*.csv')))

    mouse_entry = {}
    mouse_entry_missed = {}
    for trial_number, file in enumerate(files, start=1):
        df_csv = pd.read_csv(file, index_col=False)
        tallies = _tally_trial(df_csv, left_interaction_zone, right_interaction_zone)
        for mouse_number, side in enumerate(('left', 'right'), start=1):
            key = f'trial_{trial_number}_mouse_{mouse_number}'
            frames, entries = tallies[side][:2]
            mouse_entry[key] = [frames / fps, entries]
            missed_frames, missed_entries = tallies[side + '_missed'][:2]
            mouse_entry_missed[key] = [missed_frames / fps, missed_entries]

    # convert the dictionaries into df, then into CSV file
    columns = ['Time in Interaction Zone (s)', 'Entries in Interaction Zone']
    interaction_zone_df = pd.DataFrame.from_dict(mouse_entry, orient='index', columns=columns)
    interaction_zone_missed_df = pd.DataFrame.from_dict(mouse_entry_missed, orient='index', columns=columns)
    no_interaction = (interaction_zone_df[columns[0]] == 0) & (interaction_zone_df[columns[1]] == 0)
    fallback_rows = interaction_zone_missed_df.index.isin(interaction_zone_df.index[no_interaction])
    interaction_zone_df.update(interaction_zone_missed_df[fallback_rows])

    save_file_path = filedialog.asksaveasfilename(defaultextension='.csv', title='Save the file')
    if not save_file_path:
        # Dialog cancelled. to_csv('') would raise instead of doing nothing.
        return
    interaction_zone_df.to_csv(save_file_path)
def make_interaction_zone_buttons(tk, root):
    """
    Lay out the interaction-zone input form and its action button on ``root``.

    Rows 0-4 hold the enclosure and zone-size entries, row 5 is a blank spacer, and rows 6-9
    hold the arena corner entries. Row 10 holds the 'Make IZ CSV' button, which passes every
    entry widget to interaction_zone.

    :param tk: The tkinter module (or a namespace providing Label, Entry and Button)
    :param root: The parent widget that the form is gridded into
    :return: None
    """
    def add_row(row_index, prompt):
        # Label in column 0, centred 30-character entry in column 1; returns the entry.
        tk.Label(root, text=prompt).grid(row=row_index, column=0)
        field = tk.Entry(root, width=30, justify='center')
        field.grid(row=row_index, column=1)
        return field

    pixels_field = add_row(0, 'Enter enclosure length in pixels:')
    cm_field = add_row(1, 'Enter enclosure length in cm:')
    distance_field = add_row(2, 'Enter distance from corner to interaction zone in cm:')
    width_field = add_row(3, 'Enter interaction zone width in cm:')
    length_field = add_row(4, 'Enter interaction zone length in cm:')
    tk.Label(root, text='').grid(row=5, column=0)  # visual spacer between the two groups
    left_top_field = add_row(6, 'Enter left arena top corner as (x,y):')
    left_bottom_field = add_row(7, 'Enter left arena bottom corner as (x,y):')
    right_top_field = add_row(8, 'Enter right arena top corner as (x,y):')
    right_bottom_field = add_row(9, 'Enter right arena bottom corner as (x,y)')

    def on_make_csv():
        # Note: the cm entry comes before the pixel entry, matching interaction_zone's signature.
        return interaction_zone(cm_field, pixels_field,
                                distance_field, width_field, length_field,
                                left_top_field, left_bottom_field,
                                right_top_field, right_bottom_field)

    make_button = tk.Button(root, text='Make IZ CSV', command=on_make_csv)
    make_button.grid(row=10, column=0, columnspan=2)