-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add option to finetune pos embedding, slightly more general way to save model params for loading in cortexchange. Skip non 2k images in data processing
- Loading branch information
Showing
7 changed files
with
215 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Smoke test for a cortexchange-hosted model.

Initialises the cortexchange downloader, fetches the ``surf/TransferLearning``
architecture, instantiates the ``surf/dinov2_october_09902_lora`` model on the
CPU, and prints the prepared input shape and the prediction for one FITS image.

Usage::

    python <this script> [path/to/image.fits]

Optional environment overrides (the defaults reproduce the original script):

* ``CORTEXCHANGE_LOGIN`` / ``CORTEXCHANGE_PASSWORD`` -- WebDAV credentials.
* ``CORTEXCHANGE_CACHE`` -- local cache directory, e.g. ``.cache/cortexchange``.
"""
import os
import sys
from typing import Type

from cortexchange.wdclient import init_downloader

# NOTE(review): the fallback login/password below are committed to the
# repository. They are kept only so the script behaves as before when the
# environment variables are unset; rotate them and drop the defaults.
init_downloader(
    url="https://researchdrive.surfsara.nl/public.php/webdav/",
    login=os.environ.get("CORTEXCHANGE_LOGIN", "JeofximLVcr8Ttm"),
    password=os.environ.get("CORTEXCHANGE_PASSWORD", "?CortexAdminTest1?"),
    cache=os.environ.get("CORTEXCHANGE_CACHE", "/home/larsve/.cache/cortexchange"),
)

# Kept after init_downloader(), as in the original script; presumably the
# architecture lookup relies on the downloader being configured -- TODO confirm.
from cortexchange.architecture import get_architecture, Architecture  # noqa: E402

# get_architecture() hands back a class (it is called like a constructor on
# the next line), so the annotation is "a class of Architecture", not the
# result of calling type() on Architecture.
TransferLearning: Type[Architecture] = get_architecture("surf/TransferLearning")
model = TransferLearning(device="cpu", model_name="surf/dinov2_october_09902_lora")

# Image used when no path is given on the command line. Another input that
# was used with this script:
#   ILTJ160454.72+555949.7_selfcal/selfcal_007-MFS-image.fits
DEFAULT_IMAGE = "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/stop/ILTJ142906.77+334820.3_image_009-MFS-image.fits"
image_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE

torch_tensor = model.prepare_data(image_path)
print(torch_tensor.shape)
result = model.predict(torch_tensor)
print(result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#!/bin/bash
#SBATCH --job-name=cortex_grid_search
#SBATCH -p gpu
#SBATCH -t 08:00:00
#SBATCH --gpus 1
#SBATCH --output=out/multi_cortex%A_%a.out

# Grid-search launcher for train_nn.py.
#
# Each SLURM array task reads one line of parameters.txt (line number ==
# SLURM_ARRAY_TASK_ID, defaulting to 1 outside an array job). The line holds
# 13 whitespace-separated fields:
#   model lr normalize dropout_p batch_size label_smoothing
#   stochastic_smoothing use_lora rank alpha resize lift flip_augmentations
# The 0/1 fields stochastic_smoothing, use_lora and flip_augmentations are
# turned into the matching command-line switches.

set -e

cd ~/projects/lofar_helpers/neural_networks

module load 2023
source ../../lofar_venv/bin/activate

# Read the parameter file
PARAM_FILE=parameters.txt

# Set default value for SLURM_ARRAY_TASK_ID
SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID:=1}
# Extract the specific line corresponding to the SLURM_ARRAY_TASK_ID
PARAMS=$(sed -n "${SLURM_ARRAY_TASK_ID}p" "$PARAM_FILE")

# Fail early instead of launching a training run with empty arguments.
if [ -z "$PARAMS" ]; then
    echo "No parameters on line ${SLURM_ARRAY_TASK_ID} of ${PARAM_FILE}" >&2
    exit 1
fi

# Parse the parameters (-r: keep backslashes literal)
read -r model lr normalize dropout_p batch_size label_smoothing stochastic_smoothing use_lora rank alpha resize lift flip_augmentations <<< "$PARAMS"

# Print the switch in $2 when the 0/1 field in $1 equals 1; print nothing otherwise.
flag_if_enabled() {
    if [ "$1" -eq 1 ]; then
        printf '%s' "$2"
    fi
}

LORA_ARG=$(flag_if_enabled "$use_lora" "--use_lora")
STOCHASTIC_SMOOTHING=$(flag_if_enabled "$stochastic_smoothing" "--stochastic_smoothing")
FLIP_AUGMENTATIONS=$(flag_if_enabled "$flip_augmentations" "--flip_augmentations")

# Small learning rates get more epochs. The comparison is done on floats
# inside awk: the previous `echo "$lr * 1000000" | awk '{printf("%d", $1)}'`
# only printed the first field, so the multiplication never happened, both
# sides truncated to 0 and EPOCHS was always 250.
LR_THRESHOLD="4e-05"
if awk -v lr="$lr" -v threshold="$LR_THRESHOLD" 'BEGIN { exit !((lr + 0) <= (threshold + 0)) }'; then
    EPOCHS="250"
else
    EPOCHS="120"
fi

DATA_INPUT_PATH="/scratch-shared/CORTEX/"
# find $DATA_INPUT_PATH -name '*npz' | xargs -n 1 -P 18 -i rsync -R {} '/dev/shm/'

DATA_TRAINDATA_PATH="/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/"

# Build the argument list once so the logged command and the executed command
# cannot drift apart. ${VAR:+"$VAR"} expands to nothing when the optional
# switch is empty.
TRAIN_ARGS=(
    "$DATA_TRAINDATA_PATH"
    --model "$model"
    --lr "$lr"
    --normalize "$normalize"
    --dropout_p "$dropout_p"
    --batch_size "$batch_size"
    --log_path grid_search_lora
    --label_smoothing "$label_smoothing"
    --rank "$rank"
    --resize "$resize"
    --alpha "$alpha"
    ${LORA_ARG:+"$LORA_ARG"}
    ${STOCHASTIC_SMOOTHING:+"$STOCHASTIC_SMOOTHING"}
    -d
    --epochs "$EPOCHS"
    --lift "$lift"
    ${FLIP_AUGMENTATIONS:+"$FLIP_AUGMENTATIONS"}
)

# Execute your Python script with the given parameters
echo "${TRAIN_ARGS[@]}"
python train_nn.py "${TRAIN_ARGS[@]}"