Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Nov 29, 2024
1 parent d538d86 commit 6103a4f
Show file tree
Hide file tree
Showing 31 changed files with 251 additions and 416 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ where = ["src"]
[tool.pytest.ini_options]
pythonpath = [".", "src"]
testpaths = ["tests"]

[tool.black]
line-length = 119
1 change: 1 addition & 0 deletions src/vame/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
sys.dont_write_bytecode = True

from vame.initialize_project import init_new_project
from vame.preprocessing.preprocessing import preprocessing
from vame.model import create_trainset
from vame.model import train_model
from vame.model import evaluate_model
Expand Down
39 changes: 7 additions & 32 deletions src/vame/analysis/community_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,7 @@ def get_motif_labels(
file_labels = np.load(
os.path.join(
path_to_dir,
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)
shape = len(file_labels)
Expand All @@ -276,12 +271,7 @@ def get_motif_labels(
file_labels = np.load(
os.path.join(
path_to_dir,
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)[:min_frames]
community_label.extend(file_labels)
Expand Down Expand Up @@ -390,11 +380,7 @@ def create_cohort_community_bag(
add = input("Extend list or add in the end? (ext/end)")
if add == "ext":
motif_idx = int(input("Which motif number? "))
list_idx = int(
input(
"At which position in the list? (pythonic indexing starts at 0) "
)
)
list_idx = int(input("At which position in the list? (pythonic indexing starts at 0) "))
community_bag[list_idx].append(motif_idx)
if add == "end":
motif_idx = int(input("Which motif number? "))
Expand Down Expand Up @@ -440,9 +426,7 @@ def get_cohort_community_labels(
for j in range(len(clust)):
find_clust = np.where(motif_labels == clust[j])[0]
community_labels[find_clust] = i
community_labels = np.int64(
scipy.signal.medfilt(community_labels, median_filter_size)
)
community_labels = np.int64(scipy.signal.medfilt(community_labels, median_filter_size))
community_labels_all.append(community_labels)
return community_labels_all

Expand All @@ -468,12 +452,7 @@ def save_cohort_community_labels_per_file(
file_labels = np.load(
os.path.join(
path_to_dir,
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)
community_labels = get_cohort_community_labels(
Expand Down Expand Up @@ -640,9 +619,7 @@ def community(
),
cohort_community_bag,
)
with open(
os.path.join(path_to_dir, "hierarchy" + ".pkl"), "wb"
) as fp: # Pickling
with open(os.path.join(path_to_dir, "hierarchy" + ".pkl"), "wb") as fp: # Pickling
pickle.dump(cohort_community_bag, fp)

# Added by Luiz - 11/10/2024
Expand All @@ -659,9 +636,7 @@ def community(

# # Work in Progress - cohort is False
else:
raise NotImplementedError(
"Community analysis for cohort=False is not supported yet."
)
raise NotImplementedError("Community analysis for cohort=False is not supported yet.")
# labels = get_labels(cfg, files, model_name, n_clusters, parametrization)
# transition_matrices = compute_transition_matrices(
# files,
Expand Down
7 changes: 1 addition & 6 deletions src/vame/analysis/generative_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,12 +333,7 @@ def generative_model(
os.path.join(
path_to_file,
"",
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)
return random_generative_samples_motif(
Expand Down
15 changes: 3 additions & 12 deletions src/vame/analysis/gif_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def create_video(
frame = frames[i]
ax2.imshow(frame, cmap=cmap_reversed)
# ax2.set_title("Motif %d,\n Community: %s" % (lbl, motifs[lbl]), fontsize=10)
fig.savefig(
os.path.join(path_to_file, "gif_frames", session + "gif_%d.png") % i
)
fig.savefig(os.path.join(path_to_file, "gif_frames", session + "gif_%d.png") % i)


def gif(
Expand Down Expand Up @@ -205,9 +203,7 @@ def gif(
random_state=cfg["random_state"],
)

latent_vector = np.load(
os.path.join(path_to_file, "", "latent_vector_" + session + ".npy")
)
latent_vector = np.load(os.path.join(path_to_file, "", "latent_vector_" + session + ".npy"))

num_points = cfg["num_points"]
if num_points > latent_vector.shape[0]:
Expand All @@ -228,12 +224,7 @@ def gif(
umap_label = np.load(
os.path.join(
path_to_file,
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)
elif label == "community":
Expand Down
14 changes: 3 additions & 11 deletions src/vame/analysis/pose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,9 @@ def embedd_latent_vectors(
data_sample_np = data[:, i : temp_win + i].T
data_sample_np = np.reshape(data_sample_np, (1, temp_win, num_features))
if use_gpu:
h_n = model.encoder(
torch.from_numpy(data_sample_np)
.type("torch.FloatTensor")
.cuda()
)
h_n = model.encoder(torch.from_numpy(data_sample_np).type("torch.FloatTensor").cuda())
else:
h_n = model.encoder(
torch.from_numpy(data_sample_np).type("torch.FloatTensor").to()
)
h_n = model.encoder(torch.from_numpy(data_sample_np).type("torch.FloatTensor").to())
mu, _, _ = model.lmbda(h_n)
latent_vector_list.append(mu.cpu().data.numpy())

Expand Down Expand Up @@ -406,9 +400,7 @@ def segment_session(
)

else:
logger.info(
f"\nSegmentation with {n_clusters} k-means clusters already exists for model {model_name}"
)
logger.info(f"\nSegmentation with {n_clusters} k-means clusters already exists for model {model_name}")

if os.path.exists(
os.path.join(
Expand Down
8 changes: 2 additions & 6 deletions src/vame/analysis/tree_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def hierarchy_pos(
raise TypeError("cannot use hierarchy_pos on a graph that is not a tree")
if root is None:
if isinstance(G, nx.DiGraph):
root = next(
iter(nx.topological_sort(G))
) # allows back compatibility with nx version 1.11
root = next(iter(nx.topological_sort(G))) # allows back compatibility with nx version 1.11
else:
root = random.choice(list(G.nodes))

Expand Down Expand Up @@ -121,9 +119,7 @@ def merge_func(
for i in range(n_clusters):
for j in range(n_clusters):
try:
cost = motif_norm[i] + motif_norm[j] / np.abs(
transition_matrix[i, j] + transition_matrix[j, i]
)
cost = motif_norm[i] + motif_norm[j] / np.abs(transition_matrix[i, j] + transition_matrix[j, i])
except ZeroDivisionError:
print(
"Error: Transition probabilities between motif "
Expand Down
7 changes: 1 addition & 6 deletions src/vame/analysis/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,7 @@ def visualization(
os.path.join(
path_to_file,
"",
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)
output_figure = umap_label_vis(
Expand Down
15 changes: 3 additions & 12 deletions src/vame/analysis/videowriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,12 @@ def create_cluster_videos(
labels = np.load(
os.path.join(
path_to_file,
str(n_clusters)
+ "_"
+ segmentation_algorithm
+ "_label_"
+ session
+ ".npy",
str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy",
)
)
if flag == "community":
if cohort:
logger.info(
"Cohort community videos getting created for " + session + " ..."
)
logger.info("Cohort community videos getting created for " + session + " ...")
labels = np.load(
os.path.join(
path_to_file,
Expand All @@ -109,9 +102,7 @@ def create_cluster_videos(
)
capture = cv.VideoCapture(video_file_path)
if not capture.isOpened():
raise ValueError(
f"Video capture could not be opened. Ensure the video file is valid.\n {video_file_path}"
)
raise ValueError(f"Video capture could not be opened. Ensure the video file is valid.\n {video_file_path}")
width = capture.get(cv.CAP_PROP_FRAME_WIDTH)
height = capture.get(cv.CAP_PROP_FRAME_HEIGHT)
fps = 25 # capture.get(cv.CAP_PROP_FPS)
Expand Down
16 changes: 4 additions & 12 deletions src/vame/initialize_project/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,14 @@ def init_new_project(
for i in videos:
# Check if it is a folder
if os.path.isdir(i):
vids_in_dir = [
os.path.join(i, vp) for vp in os.listdir(i) if video_type in vp
]
vids_in_dir = [os.path.join(i, vp) for vp in os.listdir(i) if video_type in vp]
vids = vids + vids_in_dir
if len(vids_in_dir) == 0:
logger.info(f"No videos found in {i}")
logger.info(
f"Perhaps change the video_type, which is currently set to: {video_type}"
)
logger.info(f"Perhaps change the video_type, which is currently set to: {video_type}")
else:
videos = vids
logger.info(
f"{len(vids_in_dir)} videos from the directory {i} were added to the project."
)
logger.info(f"{len(vids_in_dir)} videos from the directory {i} were added to the project.")
else:
if os.path.isfile(i):
vids = vids + [i]
Expand Down Expand Up @@ -210,9 +204,7 @@ def init_new_project(

unique_num_features = list(set(num_features_list))
if len(unique_num_features) > 1:
raise ValueError(
"All pose estimation files must have the same number of features."
)
raise ValueError("All pose estimation files must have the same number of features.")

if config_kwargs is None:
config_kwargs = {}
Expand Down
15 changes: 4 additions & 11 deletions src/vame/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

class VameLogger:
LOG_FORMAT = (
"%(asctime)-15s.%(msecs)d %(levelname)-5s --- [%(threadName)s]"
" %(name)-15s : %(lineno)d : %(message)s"
"%(asctime)-15s.%(msecs)d %(levelname)-5s --- [%(threadName)s]" " %(name)-15s : %(lineno)d : %(message)s"
)
LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

Expand All @@ -19,19 +18,15 @@ def __init__(
):
self.log_level = log_level
self.file_handler = None
logging.basicConfig(
level=log_level, format=self.LOG_FORMAT, datefmt=self.LOG_DATE_FORMAT
)
logging.basicConfig(level=log_level, format=self.LOG_FORMAT, datefmt=self.LOG_DATE_FORMAT)
self.logger = logging.getLogger(f"{base_name}")
if self.logger.hasHandlers():
self.logger.handlers.clear()

self.logger.setLevel(self.log_level)
# Stream handler for logging to stdout
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(
logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT)
)
stream_handler.setFormatter(logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT))
self.logger.addHandler(stream_handler)
self.logger.propagate = False

Expand All @@ -56,9 +51,7 @@ def add_file_handler(self, file_path: str):
f.write(f"{line_break}[LOGGING STARTED AT: {handler_datetime}]")

self.file_handler = logging.FileHandler(file_path, mode="a")
self.file_handler.setFormatter(
logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT)
)
self.file_handler.setFormatter(logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT))
self.logger.addHandler(self.file_handler)

def remove_file_handler(self):
Expand Down
Loading

0 comments on commit 6103a4f

Please sign in to comment.