Skip to content

Commit

Permalink
Update segmentation_train.py
Browse files Browse the repository at this point in the history
tweaking to allow support for png mask files
  • Loading branch information
jdconey committed Dec 1, 2023
1 parent 703da85 commit 3f30647
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions code/ML/segmentation_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
from fastai.vision.all import *
import numpy as np
import pathlib
import os

#make a directory called models_out
os.mkdir('models_out')

root = pathlib.Path('~/data/ukv/train/')
root = pathlib.Path('C:/Users/mm16jdc/Documents/GitHub/LeeWaveNet/data/ukv/train/')

root = pathlib.Path('~/data/train/')

codes={0:'no wave',255:'lee wave'}

def label_func(fn):
string = str(fn.stem)[:49]+"mask.png"
return root/"mask_png"/string
return root/"masks_png"/string

def open_xarray(fname):
x = xr.open_dataarray(fname)
Expand All @@ -38,9 +41,9 @@ def open_xarray(fname):
splitter=RandomSplitter(),
batch_tfms=tfms,
)
dsets = waves_ds.datasets(root/'700hPa')
dsets = waves_ds.datasets(root/'vertical_velocities')

dls = waves_ds.dataloaders(root/"700hPa", path=root, bs=4)
dls = waves_ds.dataloaders(root/"vertical_velocities", path=root, bs=4)

learn2 = unet_learner(dls,resnet34,metrics=DiceMulti)
learn2.fine_tune(100,cbs=EarlyStoppingCallback(monitor='valid_loss', patience=5))
Expand Down

0 comments on commit 3f30647

Please sign in to comment.