Commit

Merge branch 'master' of github.com:jurjen93/lofar_helpers

jurjen93 committed Sep 25, 2024
2 parents 89a2b58 + 2c4e5fc commit b9371bf
Showing 8 changed files with 825 additions and 631 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -41,3 +41,8 @@ mergemerge.py
runs/
out/
venv/
.cache/

_cache/
grid_search/
public.spider.surfsara.nl/
13 changes: 9 additions & 4 deletions neural_networks/multi_train.sh
@@ -1,23 +1,28 @@
#!/bin/bash
#SBATCH --job-name=cortex_multi_node
#SBATCH --partition=gpu
#SBATCH --time=00:30:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=18
#SBATCH --output=out/multi_cortex%A_%a.out
#SBATCH --gpu-bind=None

set -e

cd ~/projects/lofar_helpers/neural_networks

module load 2023
# module load NCCL/2.18.3-GCCcore-12.3.0-CUDA-12.1.1
# module load PyTorch/2.1.2-foss-2023a-CUDA-12.1.1
# module load libjpeg-turbo/2.1.5.1-GCCcore-12.3.0
# module load torchvision/0.16.0-foss-2023a-CUDA-12.1.1
source venv/bin/activate

export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
# export NCCL_SOCKET_IFNAME='eno2np0' # Change for a100
echo "MASTER_ADDR:MASTER_PORT="${MASTER_ADDR}:${MASTER_PORT}

NCCL_DEBUG=INFO srun python train_nn_multi.py
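
The exported MASTER_ADDR/MASTER_PORT pair is the rendezvous point for the multi-node run (2 nodes x 4 tasks per node). A minimal sketch of how a training script would typically consume these variables to initialize torch.distributed — this is an assumption about train_nn_multi.py, whose contents the diff does not show:

# Hypothetical sketch: standard torch.distributed init from the SLURM/env
# variables exported by multi_train.sh above. Not taken from train_nn_multi.py.
import os

import torch
import torch.distributed as dist


def init_distributed():
    rank = int(os.environ["SLURM_PROCID"])         # global rank (0..7 here)
    world_size = int(os.environ["SLURM_NTASKS"])   # 2 nodes x 4 tasks = 8
    local_rank = int(os.environ["SLURM_LOCALID"])  # GPU index within the node
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size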
191 changes: 154 additions & 37 deletions neural_networks/pre_processing_for_ml.py
@@ -7,9 +7,13 @@
import numpy as np
import torch
from astropy.io import fits
from joblib import Parallel, delayed, Memory
import joblib
from matplotlib.colors import SymLogNorm
from torch.utils.data import Dataset
from concurrent.futures import ThreadPoolExecutor, as_completed

cache = joblib.Memory(location="_cache", verbose=0)
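
The new module-level Memory object gives disk-backed memoization under _cache/ (which this commit also adds to .gitignore). A minimal sketch of the idiom, with a hypothetical function name:

import joblib
import numpy as np

cache = joblib.Memory(location="_cache", verbose=0)

@cache.cache()
def expensive_stats(path):  # illustrative function, not repository code
    # First call runs the body and pickles the result under _cache/;
    # later calls with identical arguments load the cached result instead.
    return float(np.load(path)["arr_0"].mean())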


def get_rms(data: np.ndarray, maskSup=1e-7):
@@ -26,7 +30,7 @@ def get_rms(data: np.ndarray, maskSup=1e-7):
m = mIn[np.abs(mIn) > maskSup]
rmsold = np.std(m)
diff = 1e-1
cut = 3.0
med = np.median(m)

for i in range(10):
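
The get_rms hunk above is truncated in this view; the function is an iterative sigma-clipping RMS estimator. A self-contained sketch of the general technique — the exact loop body is not shown in this diff, so the details here are assumptions:

import numpy as np

def sigma_clipped_rms(data, mask_sup=1e-7, cut=3.0, tol=1e-1, max_iter=10):
    """Estimate noise RMS by repeatedly clipping >cut*rms outliers (sketch)."""
    m = data[np.abs(data) > mask_sup]  # drop exact-zero / masked pixels
    rms_old = np.std(m)
    med = np.median(m)
    for _ in range(max_iter):
        rms = np.std(m[np.abs(m - med) < cut * rms_old])
        if np.abs((rms - rms_old) / rms_old) < tol:  # converged
            break
        rms_old = rms
    return rms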
@@ -74,81 +78,124 @@ def normalize_fits(image_data: np.ndarray):

# Pre-processing
rms = get_rms(image_data)
norm_f = SymLogNorm(
linthresh=rms * 2, linscale=2, vmin=-rms, vmax=rms * 50000, base=10
)

image_data = norm_f(image_data)

image_data = -image_data + 1 # make the peak exist at zero
image_data = image_data - image_data.min()
image_data = np.clip(image_data, a_min=0, a_max=1)

return image_data[..., None]
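
After this change normalize_fits returns a single-channel array (the previous RdBu_r colormap/RGB step is gone): values go through a symmetric-log stretch, are flipped so the peak sits at zero, clipped to [0, 1], and given a trailing channel axis. A quick sanity-check sketch of that output contract (file name illustrative, import path assumed from this diff's file layout):

import numpy as np
from astropy.io import fits
from pre_processing_for_ml import normalize_fits

with fits.open("example.fits") as hdul:  # hypothetical input file
    image_data = hdul[0].data

out = normalize_fits(image_data)
assert out.shape == image_data.shape + (1,)   # trailing channel axis
assert out.min() >= 0.0 and out.max() <= 1.0  # clipped to [0, 1]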


def transform_data(root_dir, classes=("continue", "stop"), modes=("", "_val")):

def process_fits(fits_path):
with fits.open(fits_path) as hdul:
image_data = hdul[0].data

transformed = normalize_fits(image_data)

np.savez_compressed(
fits_path.with_suffix(".npz"), transformed.astype(np.float32)
)

root_dir = Path(root_dir)
assert root_dir.exists()

Parallel(n_jobs=len(os.sched_getaffinity(0)))(
delayed(process_fits)(fits_path)
for cls, mode in itertools.product(classes, modes)
for fits_path in (root_dir / (cls + mode)).glob("*.fits")
)



class FitsDataset(Dataset):
def __init__(self, root_dir, mode="train"):
"""
Args:
root_dir (string): Directory with good/bad folders in it.
"""

modes = ("train", "val")
assert mode in modes

classes = {"stop": 0, "continue": 1}

root_dir = Path(root_dir)
assert root_dir.exists(), f"'{root_dir}' doesn't exist!"

ext = ".npz"
glob_ext = "*" + ext

for folder in (
root_dir / (cls + ("" if mode == "train" else "_val")) for cls in classes
):
assert (
folder.exists()
), f"root folder doesn't exist, got: '{str(folder.resolve())}'"
assert (
len(list(folder.glob(glob_ext))) > 0
), f"no '{ext}' files were found in '{str(folder.resolve())}'"

# Yes this code is way overengineered. Yes I also derive pleasure from writing it :) - RJS
#
# Actual documentation:
# You want all 'self.x' variables to be non-python objects such as numpy arrays,
# otherwise you get memory leaks in the PyTorch dataloader
self.data_paths, self.labels = map(
np.asarray,
list(
zip(
*(
(str(file), val)
for cls, val in classes.items()
for file in (
root_dir / (cls + ("" if mode == "train" else "_val"))
).glob(glob_ext)
)
)
),
)

assert len(self.data_paths) > 0

sources = ", ".join(sorted([str(elem).split('/')[-1].strip(ext) for elem in self.data_paths]))
self.sources = ", ".join(
sorted([str(elem).split("/")[-1].strip(ext) for elem in self.data_paths])
)
self.mode = mode
_, counts = np.unique(self.labels, return_counts=True)
self.label_ratio = counts[0] / counts[1]
# print(f'{mode}: using the following sources: {sources}')

def compute_statistics(self, normalize):
self.mean, self.std = FitsDataset._compute_statistics(self, normalize)
return self.mean, self.std

@staticmethod
@cache.cache()
def _compute_statistics(loader, normalize, verbose=True):
if not normalize:
return torch.asarray([0]), torch.asarray([1])
if verbose:
print("Computing dataset statistics")
means = []
sums_of_squares = []
f = (lambda x: torch.log(x + 1e-10)) if normalize == 2 else lambda x: x
for i, (imgs, _) in enumerate(loader):
imgs = f(imgs)
means.append(torch.mean(imgs, dim=(1, 2)))
sums_of_squares.append((imgs**2).sum(dim=(1, 2)))

mean = torch.stack(means).mean(0)
sums_of_squares = torch.stack(sums_of_squares).sum(0)
variance = (sums_of_squares / (len(loader) * imgs.shape[1] * imgs.shape[2])) - (
mean**2
)
if verbose:
print("Finished Computing dataset statistics")
return mean, torch.sqrt(variance)
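
_compute_statistics accumulates per-batch means and sums of squares, then converts them with the identity Var[x] = E[x^2] - (E[x])^2. A tiny self-check of that identity on synthetic data (not repository code):

import torch

batches = [torch.randn(8, 64, 64) for _ in range(10)]
n = sum(b.numel() for b in batches)
mean = sum(b.sum() for b in batches) / n
sum_sq = sum((b ** 2).sum() for b in batches)
var = sum_sq / n - mean ** 2                              # E[x^2] - (E[x])^2
ref = torch.cat([b.flatten() for b in batches]).var(unbiased=False)
assert torch.allclose(var, ref, atol=1e-4)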

@staticmethod
def transform_data(image_data):
@@ -171,19 +218,89 @@ def __getitem__(self, idx):
npy_path = self.data_paths[idx]
label = self.labels[idx]

image_data = np.load(npy_path)["arr_0"] # there is always only one array

# Pre-processing
image_data = self.transform_data(image_data)


return image_data, label


def make_histogram(root_dir):
root_dir = Path(root_dir)
assert root_dir.exists(), f"'{root_dir}' doesn't exist!"

ext = ".npz"
glob_ext = "*" + ext

classes = {"stop": 0, "continue": 1}
file_paths = []
for mode in ["train", "val"]:
for folder in (
root_dir / (cls + ("" if mode == "train" else "_val")) for cls in classes
):
assert (
folder.exists()
), f"root folder doesn't exist, got: '{str(folder.resolve())}'"
assert (
len(list(folder.glob(glob_ext))) > 0
), f"no '{ext}' files were found in '{str(folder.resolve())}'"

# Yes this code is way overengineered. Yes I also derive pleasure from writing it :) - RJS
#
# Actual documentation:
# You want all 'self.x' variables to be non-python objects such as numpy arrays,
# otherwise you get memory leaks in the PyTorch dataloader
file_paths += [
str(file)
for cls, _ in classes.items()
for file in (root_dir / (cls + ("" if mode == "train" else "_val"))).glob(
glob_ext
)
]

result = np.asarray(
list(
map(
lambda file_path: file_path.split("/")[-1].split("_")[0].split("-")[0],
file_paths,
)
)
)
unique, inverse, counts = np.unique(result, return_inverse=True, return_counts=True)
print(f"number of unique stations: {len(unique)}")
plt.ylabel("Count")
plt.xlabel("Number of images per station")
plt.hist(counts, bins=len(np.unique(counts)))
plt.savefig("image_count_histogram.png")

plt.hist(inverse, bins=len(unique))
plt.ylabel("Number of images")
plt.xlabel("Station")
plt.savefig("images_per_station.png")


if __name__ == "__main__":
root = f"/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
# transform_data(root)

# make_histogram(root)
dataset = FitsDataset(root, mode="val")
dataset.compute_statistics(normalize=1)

# dataset = FitsDataset(root, mode='train', normalize=1)
# sources = dataset.sources
# hash(dataset)
# print(sources)
# imgs, label = dataset[0]
# from PIL import Image
# plt.imshow(imgs.permute(1, 2, 0).to(torch.float32).numpy())
# plt.savefig('test.png')

# for img, label in dataset:
# print(img.shape)
# exit()

# images = np.concatenate([image.flatten() for image, label in Idat])
# print("creating hist")
# plt.hist(images)
71 changes: 71 additions & 0 deletions neural_networks/test_scripts/inspect_data.py
@@ -0,0 +1,71 @@
from ..train_nn import *
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

@lru_cache(maxsize=1)
def get_transforms():
return v2.Compose([
v2.ColorJitter(brightness=.5, hue=.3, saturation=0.1, contrast=0.1),
v2.RandomInvert(),
v2.RandomEqualize(),
v2.RandomVerticalFlip(p=0.5),
v2.RandomHorizontalFlip(p=0.5),
])

def compute_statistics(loader, normalize: int):
if not normalize:
return torch.asarray([0, 0, 0]), torch.asarray([1, 1, 1])
means = []
sums_of_squares = []
f = torch.log if normalize==2 else lambda x: x
for i, (imgs, _) in enumerate(loader):
print(i, len(loader))
imgs = imgs.to('cuda')
imgs = f(imgs)
means.append(torch.mean(imgs, dim=(0, 2, 3)))
sums_of_squares.append((imgs**2).sum(dim=(0, 2, 3)))
mean = torch.stack(means).mean(0)
sums_of_squares = torch.stack(sums_of_squares).sum(0)
variance = (sums_of_squares / (len(loader) * imgs.shape[0] * imgs.shape[2] * imgs.shape[3])) - (mean ** 2)
return mean, torch.sqrt(variance)

def plot_image(img, fname):
img = img[0].cpu().permute(1, 2, 0).to(torch.float32).numpy()
plt.imshow((img - np.min(img))/(np.max(img)-np.min(img)))
plt.savefig(fname)

if __name__=='__main__':
dataset_root = 'public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data'
batch_size = 32
normalize = 2
train_dataloader, val_loader = get_dataloaders(dataset_root, batch_size)
print('computing_statistics')
mean, std = compute_statistics(train_dataloader, normalize=normalize)
print('done_with_statistics')

prepare_data_f = partial(prepare_data, resize=0, device='cuda', mean=mean, std=std, normalize=normalize)

output_folder = 'image_samples/'
os.makedirs(output_folder, exist_ok=True)
data, labels = next(iter(val_loader))

data_norm, labels = prepare_data_f(data, labels)
plot_image(data_norm, f"{output_folder}/normalize_{normalize}")
transforms = {'brightness': [v2.ColorJitter(brightness=0.5)],
'hue': [v2.ColorJitter(hue=0.3)],
'saturation': [v2.ColorJitter(saturation=0.1)],
'contrast': [v2.ColorJitter(contrast=0.1)],
'colorjitter': [v2.ColorJitter(brightness=0.5, hue=0.3, saturation=0.1, contrast=0.1)],
'invert': [v2.RandomInvert(1)],
'equalize': [v2.RandomEqualize(1)],
'all': [v2.ColorJitter(brightness=0.5, hue=0.3, saturation=0.1, contrast=0.1), v2.RandomInvert(1),v2.RandomEqualize(1)]}
for t_name, transformations in transforms.items():
transform_f = v2.Compose(transformations)
data_transformed = transform_f(data_norm)
plot_image(data_transformed, f"{output_folder}/normalize_{normalize}_{t_name}")

print('done')
exit()
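
Note: because of the relative import (from ..train_nn import *), this script presumably has to run as a package module from the repository root, e.g. "python -m neural_networks.test_scripts.inspect_data" — an assumption; the diff does not show how it is invoked.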



File renamed without changes.