diff --git a/code/ML/segmentation_train.py b/code/ML/segmentation_train.py index 0a6c921..f3489c3 100644 --- a/code/ML/segmentation_train.py +++ b/code/ML/segmentation_train.py @@ -9,16 +9,19 @@ from fastai.vision.all import * import numpy as np import pathlib +import os +#make a directory called models_out +os.mkdir('models_out') -root = pathlib.Path('~/data/ukv/train/') -root = pathlib.Path('C:/Users/mm16jdc/Documents/GitHub/LeeWaveNet/data/ukv/train/') + +root = pathlib.Path('~/data/train/') codes={0:'no wave',255:'lee wave'} def label_func(fn): string = str(fn.stem)[:49]+"mask.png" - return root/"mask_png"/string + return root/"masks_png"/string def open_xarray(fname): x = xr.open_dataarray(fname) @@ -38,9 +41,9 @@ def open_xarray(fname): splitter=RandomSplitter(), batch_tfms=tfms, ) -dsets = waves_ds.datasets(root/'700hPa') +dsets = waves_ds.datasets(root/'vertical_velocities') -dls = waves_ds.dataloaders(root/"700hPa", path=root, bs=4) +dls = waves_ds.dataloaders(root/"vertical_velocities", path=root, bs=4) learn2 = unet_learner(dls,resnet34,metrics=DiceMulti) learn2.fine_tune(100,cbs=EarlyStoppingCallback(monitor='valid_loss', patience=5))