Skip to content

Commit

Permalink
Add option to finetune pos embedding, slightly more general way to sa…
Browse files Browse the repository at this point in the history
…ve model params for loading in cortexchange. Skip non 2k images in data processing
  • Loading branch information
LVeefkind committed Nov 20, 2024
1 parent 1f346a6 commit a71f66f
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 19 deletions.
6 changes: 6 additions & 0 deletions neural_networks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ python pre_processing_for_ml.py <path-to-fits-dir>
```

## (Optional) Copy files to /dev/shm for fast dataloading
Copy the pre-processed `.npz` data files to `/dev/shm`, preserving their directory structure:
```shell
find <path-to-fits-dir> -type f -name "*.npz" | xargs -n 1 -P 8 -i rsync -R {} /dev/shm
```
Copy the cached dataset statistics (the `_cache` directories) as well:
```shell
find <path-to-fits-dir> -type d -name "_cache" | xargs -n 1 -P 8 -I {} rsync -aR {}/ /dev/shm

```

## Run neural network training
Expand Down
4 changes: 2 additions & 2 deletions neural_networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def prepare_batch(self, batch: torch.Tensor, mean=None, std=None) -> torch.Tenso
batch, size=self.resize, mode="bilinear", align_corners=False
)
if mean is None:
mean = self.mean
mean = self.args["dataset_mean"]
if std is None:
std = self.std
std = self.args["dataset_std"]
batch = normalize_inputs(batch, mean, std, normalize=1)
return batch

Expand Down
55 changes: 55 additions & 0 deletions neural_networks/parameters.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,58 @@ efficientnet_v2_l 5e-05 1 0.2 32 0.3 0 0 16 16 0 stack 0
efficientnet_v2_l 5e-05 1 0.25 32 0.1 0 0 16 16 0 stack 0
efficientnet_v2_l 5e-05 1 0.25 32 0.2 0 0 16 16 0 stack 0
efficientnet_v2_l 5e-05 1 0.25 32 0.3 0 0 16 16 0 stack 0


dinov2_vitl14_reg 0.0001 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vitl14_reg 0.0001 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vitl14_reg 0.0001 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vitl14_reg 0.0001 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vitl14_reg 0.0001 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vitb14_reg 0.0001 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vitb14_reg 0.0001 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vitb14_reg 0.0001 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vitb14_reg 0.0001 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vitb14_reg 0.0001 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vits14_reg 0.0001 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vits14_reg 0.0001 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vits14_reg 0.0001 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vits14_reg 0.0001 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vits14_reg 0.0001 1 0.1 32 0.1 0 1 256 256 560 stack 0

dinov2_vitl14_reg 0.00001 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vitl14_reg 0.00001 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vitl14_reg 0.00001 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vitl14_reg 0.00001 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vitl14_reg 0.00001 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vits14_reg 0.00001 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vits14_reg 0.00001 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vits14_reg 0.00001 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vits14_reg 0.00001 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vits14_reg 0.00001 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vitl14_reg 0.00005 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vitl14_reg 0.00005 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vitl14_reg 0.00005 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vitl14_reg 0.00005 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vitl14_reg 0.00005 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vitb14_reg 0.00005 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vitb14_reg 0.00005 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vitb14_reg 0.00005 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vitb14_reg 0.00005 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vitb14_reg 0.00005 1 0.1 32 0.1 0 1 256 256 560 stack 0
dinov2_vits14_reg 0.00005 1 0.1 32 0.1 0 1 16 16 560 stack 0
dinov2_vits14_reg 0.00005 1 0.1 32 0.1 0 1 32 32 560 stack 0
dinov2_vits14_reg 0.00005 1 0.1 32 0.1 0 1 64 64 560 stack 0
dinov2_vits14_reg 0.00005 1 0.1 32 0.1 0 1 128 128 560 stack 0
dinov2_vits14_reg 0.00005 1 0.1 32 0.1 0 1 256 256 560 stack 0

dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 32 32 784 stack 0
dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 1 64 64 784 stack 0

dinov2_vitb14_reg 0.00001 1 0.1 32 0.1 0 0 32 32 784 stack 0
dinov2_vitb14_reg 0.00005 1 0.1 32 0.1 0 0 32 32 784 stack 0
dinov2_vitb14_reg 0.0001 1 0.1 32 0.1 0 0 32 32 784 stack 0
27 changes: 18 additions & 9 deletions neural_networks/pre_processing_for_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,15 @@ def transform_data(root_dir, classes=("continue", "stop"), modes=("", "_val")):
def process_fits(fits_path):
with fits.open(fits_path) as hdul:
image_data = hdul[0].data
# assert image_data.shape[2] == 2048, (image_data.shape, image_data.shape[2], fits_path)
transformed = normalize_fits(image_data)
if image_data.shape[2] == 2048:
transformed = normalize_fits(image_data)

np.savez_compressed(
fits_path.with_suffix(".npz"), transformed.astype(np.float32)
)
np.savez_compressed(
fits_path.with_suffix(".npz"), transformed.astype(np.float32)
)
else:
print(f"Skipping {fits_path}. Improper image size {image_data.shape}")
# print(fits_path)

root_dir = Path(root_dir)
assert root_dir.exists()
Expand Down Expand Up @@ -174,6 +177,13 @@ def compute_statistics(self, normalize):
self.mean, self.std = cached_compute(self, normalize)
return self.mean, self.std

# For caching
def __reduce__(self):
return (
self.__class__,
(self.labels, self.mode, self.label_ratio, self.sources),
)

@staticmethod
def _compute_statistics(loader, normalize, verbose=True):
if verbose:
Expand Down Expand Up @@ -218,7 +228,6 @@ def __getitem__(self, idx):

npy_path = self.data_paths[idx]
label = self.labels[idx]

image_data = np.load(npy_path)["arr_0"] # there is always only one array

# Pre-processing
Expand Down Expand Up @@ -288,7 +297,7 @@ def make_histogram(root_dir):
# make_histogram(root)
# dataset.compute_statistics(normalize=1)

# dataset = FitsDataset(root, mode='train', normalize=1)
# dataset = FitsDataset(root, mode="train")
# sources = dataset.sources
# hash(dataset)
# print(sources)
Expand All @@ -298,8 +307,8 @@ def make_histogram(root_dir):
# plt.savefig('test.png')

# for img, label in dataset:
# print(img.shape)
# exit()
# if img.shape[2] != 2048:
# exit()

# images = np.concatenate([image.flatten() for image, label in Idat])
# print("creating hist")
Expand Down
24 changes: 24 additions & 0 deletions neural_networks/test_scripts/test_cortexchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Manual smoke test for loading a TransferLearning model through cortexchange.

Initialises the cortexchange downloader, instantiates the
``surf/dinov2_october_09902_lora`` model on the CPU, runs it on a single FITS
image and prints the prepared tensor's shape and the prediction.

The endpoint, credentials, cache directory and input image can be overridden
through environment variables. Apart from the cache directory (see below) the
defaults are the values that used to be hard-coded in this script.
"""
import os
from typing import Type

from cortexchange.wdclient import init_downloader

# NOTE(review): the login/password defaults are credentials committed to the
# repository; they are kept only so the script behaves as before when no
# environment variables are set. Consider removing them from version control.
init_downloader(
    url=os.environ.get(
        "CORTEXCHANGE_URL",
        "https://researchdrive.surfsara.nl/public.php/webdav/",
    ),
    login=os.environ.get("CORTEXCHANGE_LOGIN", "JeofximLVcr8Ttm"),
    password=os.environ.get("CORTEXCHANGE_PASSWORD", "?CortexAdminTest1?"),
    # Previously hard-coded to /home/larsve/.cache/cortexchange, which only
    # exists for one user; resolve the same location relative to the current
    # user's home directory instead.
    cache=os.environ.get(
        "CORTEXCHANGE_CACHE",
        os.path.expanduser("~/.cache/cortexchange"),
    ),
)

# Imported only after init_downloader() has run, exactly as in the original
# script (presumably the architecture lookup relies on the configured
# downloader -- TODO confirm before moving this import to the top).
from cortexchange.architecture import get_architecture, Architecture  # noqa: E402

# FITS image fed to the model; override with CORTEXCHANGE_TEST_FITS, e.g. a
# local "..._selfcal/selfcal_007-MFS-image.fits" file.
FITS_PATH = os.environ.get(
    "CORTEXCHANGE_TEST_FITS",
    "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/stop/ILTJ142906.77+334820.3_image_009-MFS-image.fits",
)

# get_architecture returns a class (not an instance), hence Type[...].
TransferLearning: Type[Architecture] = get_architecture("surf/TransferLearning")
model = TransferLearning(device="cpu", model_name="surf/dinov2_october_09902_lora")

torch_tensor = model.prepare_data(FITS_PATH)
print(torch_tensor.shape)
result = model.predict(torch_tensor)
print(result)
55 changes: 47 additions & 8 deletions neural_networks/train_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ def __init__(
use_compile: bool = True,
lift: str = "stack",
use_lora: bool = False,
alpha=16,
rank=16,
alpha: float = 16.0,
rank: int = 16,
tune_pos_embed: bool = False,
):
super().__init__()

Expand All @@ -167,6 +168,9 @@ def __init__(
"use_compile": use_compile,
"lift": lift,
"use_lora": use_lora,
"rank": rank,
"alpha": alpha,
"tune_pos_embed": tune_pos_embed,
}

if lift == "stack":
Expand Down Expand Up @@ -242,6 +246,7 @@ def eval(self):
self.dino.eval()
else:
self.dino.decoder.eval()
self.dino.encoder.pos_embed.requires_grad = False
else:
self.classifier.eval()

Expand All @@ -253,12 +258,14 @@ def train(self):
self.dino.train()
else:
self.dino.decoder.train()
# Finetune learnable pos_embedding
self.dino.encoder.pos_embed.requires_grad = self.kwargs["tune_pos_embed"]
else:
self.classifier.train()


def get_dataloaders(dataset_root, batch_size):
num_workers = min(12, len(os.sched_getaffinity(0)))
num_workers = min(18, len(os.sched_getaffinity(0)))

prefetch_factor, persistent_workers = (
(2, True) if num_workers > 0 else (None, False)
Expand Down Expand Up @@ -382,6 +389,7 @@ def main(
log_path: Path = "runs",
epochs: int = 120,
flip_augmentations: bool = False,
tune_pos_embed: bool = False,
):
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -430,6 +438,7 @@ def main(
use_lora=use_lora,
alpha=alpha,
rank=rank,
tune_pos_embed=tune_pos_embed,
)

# noinspection PyArgumentList
Expand Down Expand Up @@ -500,6 +509,16 @@ def main(
"dataset_mean": mean,
"dataset_std": std,
},
model_args={
"model_name": model_name,
"use_compile": use_compile,
"lift": lift,
"use_lora": use_lora,
"rank": rank,
"alpha": alpha,
"dropout_p": dropout_p,
"tune_pos_embed": tune_pos_embed,
},
)

best_val_loss = torch.inf
Expand Down Expand Up @@ -742,12 +761,26 @@ def load_checkpoint(ckpt_path, device="cuda"):

# strip 'model_' from the name
model_name = ckpt_dict["args"]["model_name"]
lr = ckpt_dict["args"]["lr"]
dropout_p = ckpt_dict["args"]["dropout_p"]

model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p).to(device)
if "model_args" in ckpt_dict["args"]:
model = ckpt_dict["model"](**ckpt_dict["model_args"]).to(device)
else:
dropout_p = ckpt_dict["args"]["dropout_p"]
use_lora = ckpt_dict["args"]["use_lora"]
rank = ckpt_dict["args"]["rank"]
alpha = ckpt_dict["args"]["alpha"]
lift = ckpt_dict["args"]["lift"]
model_name = ckpt_dict["args"]["model_name"]

model = ckpt_dict["model"](
model_name=model_name,
dropout_p=dropout_p,
use_lora=use_lora,
lift=lift,
alpha=alpha,
rank=rank,
).to(device)
model.load_state_dict(ckpt_dict["model_state_dict"])

lr = ckpt_dict["args"]["lr"]
try:
# FIXME: add optim class and args to state dict
optim = ckpt_dict.get("optimizer", torch.optim.AdamW)(
Expand Down Expand Up @@ -860,6 +893,12 @@ def get_argparser():
help="Whether to use LoRA if applicable.",
)

parser.add_argument(
"--tune_pos_embed",
action="store_true",
help="Whether to fine-tune the positional embedding if applicable",
)

parser.add_argument("--rank", type=int, default=16, help="rank of LoRA")

parser.add_argument(
Expand Down
63 changes: 63 additions & 0 deletions neural_networks/train_nn_lora.job
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/bin/bash
#SBATCH --job-name=cortex_grid_search
#SBATCH -p gpu
#SBATCH -t 08:00:00
#SBATCH --gpus 1
#SBATCH --output=out/multi_cortex%A_%a.out

# Grid-search training job.
#
# Reads line number $SLURM_ARRAY_TASK_ID (default: 1) from parameters.txt,
# splits it into the hyper-parameters listed in the `read` below and launches
# train_nn.py with them. Runs with a learning rate at or below LR_THRESHOLD
# are given more epochs.

set -e

cd ~/projects/lofar_helpers/neural_networks


module load 2023
source ../../lofar_venv/bin/activate

# Read the parameter file
PARAM_FILE=parameters.txt

# Set default value for SLURM_ARRAY_TASK_ID
SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID:=1}
# Extract the specific line corresponding to the SLURM_ARRAY_TASK_ID
PARAMS=$(sed -n "${SLURM_ARRAY_TASK_ID}p" "$PARAM_FILE")

# parameters.txt contains blank separator lines; fail early with a clear
# message instead of launching the training script with empty arguments.
if [ -z "${PARAMS//[[:space:]]/}" ]; then
    echo "Line ${SLURM_ARRAY_TASK_ID} of ${PARAM_FILE} is empty or missing; nothing to run." >&2
    exit 1
fi

# Parse the parameters
read -r model lr normalize dropout_p batch_size label_smoothing stochastic_smoothing use_lora rank alpha resize lift flip_augmentations <<< "$PARAMS"

if [ "$use_lora" -eq 1 ]; then
    LORA_ARG="--use_lora"
else
    LORA_ARG=""
fi

if [ "$stochastic_smoothing" -eq 1 ]; then
    STOCHASTIC_SMOOTHING="--stochastic_smoothing"
else
    STOCHASTIC_SMOOTHING=""
fi

if [ "$flip_augmentations" -eq 1 ]; then
    FLIP_AUGMENTATIONS="--flip_augmentations"
else
    FLIP_AUGMENTATIONS=""
fi

# Low learning rates get a longer schedule.
# The comparison is done numerically inside awk. The previous version piped
# "$lr * 1000000" into awk and printed field 1 with %d, which never performed
# the multiplication: both sides truncated to 0 and EPOCHS was always 250.
LR_THRESHOLD="4e-05"
if awk -v lr="$lr" -v threshold="$LR_THRESHOLD" 'BEGIN { exit !((lr + 0) <= (threshold + 0)) }'; then
    EPOCHS="250"
else
    EPOCHS="120"
fi

DATA_INPUT_PATH="/scratch-shared/CORTEX/"
# find $DATA_INPUT_PATH -name '*npz' | xargs -n 1 -P 18 -i rsync -R {} '/dev/shm/'

DATA_TRAINDATA_PATH="/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/"

# Build the argument list once so the logged command and the executed command
# cannot drift apart. Optional flags are only added when they are non-empty.
TRAIN_ARGS=(
    "$DATA_TRAINDATA_PATH"
    --model "$model"
    --lr "$lr"
    --normalize "$normalize"
    --dropout_p "$dropout_p"
    --batch_size "$batch_size"
    --log_path grid_search_lora
    --label_smoothing "$label_smoothing"
    --rank "$rank"
    --resize "$resize"
    --alpha "$alpha"
    ${LORA_ARG:+"$LORA_ARG"}
    ${STOCHASTIC_SMOOTHING:+"$STOCHASTIC_SMOOTHING"}
    -d
    --epochs "$EPOCHS"
    --lift "$lift"
    ${FLIP_AUGMENTATIONS:+"$FLIP_AUGMENTATIONS"}
)

# Execute your Python script with the given parameters
echo "${TRAIN_ARGS[@]}"
python train_nn.py "${TRAIN_ARGS[@]}"

0 comments on commit a71f66f

Please sign in to comment.