-
Notifications
You must be signed in to change notification settings - Fork 3
/
run_test_dataset.py
134 lines (110 loc) · 5.88 KB
/
run_test_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import shlex
import shutil
import subprocess
import sys

import jax
# Positional, presence-only command-line switches (their values are ignored):
#   any 1st argument -> also run the optional extended steps (do_all_tests)
#   any 2nd argument -> keep test_dataset/ instead of deleting it on success
do_all_tests = len(sys.argv) >= 2
delete_everything = len(sys.argv) < 3

# Directory (relative to the current working directory) holding the scripts
# invoked below.
RECOVAR_PATH = './'

# Labels of steps that succeeded / failed; run_command() appends to these.
passed_functions = []
failed_functions = []
def error_message():
    """Print installation help for an undetected GPU, then exit with status 1.

    Uses ``sys.exit`` instead of the ``exit`` builtin: ``exit`` is injected by
    the ``site`` module for the interactive shell and is not available when
    Python runs with ``-S`` or in some embedded/frozen interpreters.

    Raises:
        SystemExit: always, with exit code 1.
    """
    print("--------------------------------------------")
    print("--------------------------------------------")
    print("No GPU devices found by JAX. Please ensure that JAX is properly configured with CUDA and a compatible GPU. Some info from the JAX website: (https://jax.readthedocs.io/en/latest/installation.html): \n You must first install the NVIDIA driver. You’re recommended to install the newest driver available from NVIDIA, but the driver version must be >= 525.60.13 for CUDA 12 on Linux. \n The info below is outdated, and this problem should not happen anymore with the newest versions of JAX but I am leaving it in case it could be useful.... \n Typically, the problem was during the installation of JAX which could not find the CUDA libraries. Sometimes, this can be fixed by setting the correct paths to the CUDA libraries in the environment variables, or module load depending on your system. Note that you may have to reinstall JAX after setting the correct paths. E.g. run the following:")
    print("pip uninstall jax jaxlib; pip install -U \"jax[cuda12_pip]\"==0.4.23 -f https://storage.googleapis.com/jax-releases/" )
    print("--------------------------------------------")
    print("--------------------------------------------")
    sys.exit(1)
def check_gpu():
    """Abort via error_message() unless JAX reports at least one GPU device.

    Any exception raised while querying the GPU backend is printed and then
    handled exactly like an empty device list. error_message() terminates the
    process with SystemExit, which is a BaseException and therefore is not
    intercepted by the ``except Exception`` clause below.
    """
    try:
        gpu_list = jax.devices('gpu')
        if not gpu_list:
            error_message()
        else:
            print("GPU devices found:", gpu_list)
    except Exception as exc:
        print("Error occurred while checking for GPU devices:", exc)
        error_message()


# Fail fast, before any long-running step, when no GPU is visible to JAX.
check_gpu()
def run_command(command, description, function_name):
    """Run one shell command as a test step and record whether it succeeded.

    Args:
        command: Shell command line, executed with ``shell=True``.
        description: Human-readable step name used in the log lines.
        function_name: Label appended to the module-level
            ``passed_functions`` (exit status 0) or ``failed_functions``
            (any other exit status) list.

    Returns:
        True if the command exited with status 0, False otherwise. Existing
        callers that ignore the return value are unaffected.
    """
    print(f"Running: {description}")
    # Flush before spawning the child: when stdout is redirected to a file or
    # pipe it is block-buffered, and without the flush these header lines
    # would land *after* the child's own output in the log.
    print(f"Command: {command}\n", flush=True)
    result = subprocess.run(command, shell=True)
    succeeded = result.returncode == 0
    if succeeded:
        print(f"Success: {description}\n")
        passed_functions.append(function_name)
    else:
        print(f"Failed: {description}\n")
        failed_functions.append(function_name)
    return succeeded
# Interpreter for every sub-script: the Python running this file (the one whose
# JAX was checked for GPUs above), so all steps use the same environment even
# when a bare `python` on PATH is missing (python3-only systems) or belongs to a
# different install. Shell-quoted because run_command() uses shell=True.
PYTHON = shlex.quote(sys.executable) if sys.executable else 'python'

# Generate a small test dataset - should take about 30 sec
run_command(
    f'{PYTHON} {RECOVAR_PATH}/make_test_dataset.py',
    'Generate a small test dataset',
    'make_test_dataset.py'
)

# Run pipeline, should take about 2 min
run_command(
    f'{PYTHON} {RECOVAR_PATH}/pipeline.py test_dataset/particles.64.mrcs --poses test_dataset/poses.pkl --ctf test_dataset/ctf.pkl --correct-contrast -o test_dataset/pipeline_output --mask=from_halfmaps --lazy',
    'Run pipeline',
    'pipeline.py'
)

# Run analyze.py with 2D embedding and no regularization on latent space (better for density estimation)
# Should take about 5 min
run_command(
    f'{PYTHON} {RECOVAR_PATH}/analyze.py test_dataset/pipeline_output --zdim=2 --no-z-regularization --n-clusters=3 --n-trajectories=0',
    'Run analyze.py',
    'analyze.py'
)

# Estimate conformational density
run_command(
    f'{PYTHON} {RECOVAR_PATH}/estimate_conformational_density.py test_dataset/pipeline_output --pca_dim 2',
    'Estimate conformational density',
    'estimate_conformational_density.py'
)

if do_all_tests:
    # Run analyze.py with 2D embedding and no regularization on latent space (better for density estimation) and trajectory estimation
    # Should take about 5 min
    run_command(
        f'{PYTHON} {RECOVAR_PATH}/analyze.py test_dataset/pipeline_output --zdim=2 --no-z-regularization --n-clusters=3 --n-trajectories=1 --density test_dataset/pipeline_output/density/deconv_density_knee.pkl --skip-centers',
        'Run analyze.py with density',
        # Distinct from the first analyze.py label so the failure summary
        # identifies which of the two runs failed.
        'analyze.py (with density)'
    )

    # Compute trajectory - option 1
    run_command(
        f'{PYTHON} {RECOVAR_PATH}/compute_trajectory.py test_dataset/pipeline_output -o test_dataset/pipeline_output/trajectory1 --endpts test_dataset/pipeline_output/analysis_2_noreg/kmeans_center_coords.txt --ind=0,1 --density test_dataset/pipeline_output/density/deconv_density_knee.pkl --zdim=2 --n-vols-along-path=3',
        'Compute trajectory - option 1',
        'compute_trajectory.py (option 1)'
    )

    # Compute trajectory - option 2
    run_command(
        f'{PYTHON} {RECOVAR_PATH}/compute_trajectory.py test_dataset/pipeline_output -o test_dataset/pipeline_output/trajectory2 --z_st test_dataset/pipeline_output/analysis_2_noreg/kmeans_center_volumes/vol0000/latent_coords.txt --z_end test_dataset/pipeline_output/analysis_2_noreg/kmeans_center_volumes/vol0002/latent_coords.txt --density test_dataset/pipeline_output/density/deconv_density_knee.pkl --zdim=2 --n-vols-along-path=0',
        'Compute trajectory - option 2',
        'compute_trajectory.py (option 2)'
    )

    # Estimate stable states from one of the densities written by the
    # estimate_conformational_density.py step above.
    run_command(
        f'{PYTHON} {RECOVAR_PATH}/estimate_stable_states.py test_dataset/pipeline_output/density/all_densities/deconv_density_1.pkl --percent_top=10 --n_local_maxs=-1 -o test_dataset/pipeline_output/stable_states',
        'estimate stable states',
        'estimate_stable_states.py'
    )
# Summarize the run. The process exits with a non-zero status when any step
# failed, so shell scripts and CI jobs can detect a broken install without
# parsing the log.
if failed_functions:
    print("The following functions failed:")
    for func in failed_functions:
        print(f"- {func}")
    print("\nPlease check the output above for details.")
    sys.exit(1)
else:
    print("All functions completed successfully!")
    # Delete the test_dataset directory since all steps passed, unless a second
    # command-line argument was given (that sets delete_everything to False).
    if delete_everything and os.path.exists('test_dataset'):
        shutil.rmtree('test_dataset')
        print("Test dataset directory 'test_dataset' has been deleted.")

# One way to make sure everything went well is that the states in test_dataset/pipeline_output/output/analysis_2_noreg/kmeans_center_volumes/all_volumes
# should be similar to the simulated ones in recovar/data/vol*.mrc (the order doesn't matter, though).
# To inspect them after a successful run, pass a second command-line argument
# so test_dataset is not deleted above.
# NOTE(review): the commands in this script read pipeline_output/analysis_2_noreg/...
# without an `output/` component; confirm which path analyze.py actually writes.