Skip to content

Commit

Permalink
remove anchor
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Nov 29, 2024
1 parent 6103a4f commit 96ce0b8
Showing 1 changed file with 4 additions and 23 deletions.
27 changes: 4 additions & 23 deletions src/vame/preprocessing/clean_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def clean_timeseries(
pos = [0]
pos_temp = 0

pose_ref_1 = config["pose_ref_1"]
pose_ref_2 = config["pose_ref_2"]

session_names = config["session_names"]
for session in session_names:
logger.info("z-scoring of session %s" % session)
Expand All @@ -27,7 +30,7 @@ def clean_timeseries(

path_to_file = Path(config["project_path"]) / "data" / "processed" / session / session + "-aligned.nc"
ds = load_vame_dataset(path_to_file)
X = ds.position_aligned.sel(individuals="individual_0").values
X = ds.position_aligned.sel(individuals="individual_0").drop_sel(keypoints=pose_ref_1).values

# Standardize data
X_mean = np.mean(X, axis=0)
Expand All @@ -53,25 +56,3 @@ def clean_timeseries(
X_all_sessions.append(X_z)

X_all_sessions = np.concatenate(X_all_sessions, axis=0)

# Detect and delete anchors
detect_anchors = np.std(X_all_sessions, axis=0)
sort_anchors = np.sort(detect_anchors)
if sort_anchors[0] == sort_anchors[1]:
anchors = np.where(detect_anchors == sort_anchors[0])[0]
anchor_1_temp = anchors[0]
anchor_2_temp = anchors[1]
else:
anchor_1_temp = int(np.where(detect_anchors == sort_anchors[0])[0])
anchor_2_temp = int(np.where(detect_anchors == sort_anchors[1])[0])

if anchor_1_temp > anchor_2_temp:
anchor_1 = anchor_1_temp
anchor_2 = anchor_2_temp
else:
anchor_1 = anchor_2_temp
anchor_2 = anchor_1_temp

X = np.delete(X, anchor_1, 1)
X = np.delete(X, anchor_2, 1)
X = X.T

0 comments on commit 96ce0b8

Please sign in to comment.