Skip to content

Commit

Permalink
remove Black Sea 200k data, waiting on approval of Zenodo repository
Browse files Browse the repository at this point in the history
  • Loading branch information
Caio Stringari committed Feb 1, 2021
1 parent ea6e2f3 commit 8e69df9
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 6 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ conda env create -f environment_linux.yml
| ------------------------- | ------------------------------------------------------------------------------------------------------------------------ | ---------------- |
| **Train (10k)** | [![](badges/google_drive_badge.svg)](https://drive.google.com/file/d/1Qko68JTZT-JLHKwSJJvvKUQEjmcy0V0j/view?usp=sharing) | - |
| **Train (20k)** | [![](badges/google_drive_badge.svg)](https://drive.google.com/file/d/1uUcSW5s_jm5W-AQeeNxJKbIr6CR5fJIP/view?usp=sharing) | - |
| **Test (1k)** | Upcoming | - |
| **Black Sea (200k)** | Upcoming | - |
| **Test (1k)** | [![](badges/google_drive_badge.svg)](https://drive.google.com/file/d/1A6IK9IQjFN9JMNx3bUkcWdlO8YN8PbaC/view?usp=sharing) | - |
| **Black Sea (200k)** | **Upcoming** | - |
| **La Jument 2019 (10k)** | **Upcoming** | - |

## 3. Training
Expand Down
2 changes: 1 addition & 1 deletion segmentation/extract_detections_by_labelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
mk1 = np.zeros(frame.shape)
mk1[df_frm["i"].values, df_frm["j"].values] = 1

# lalbe the mask
# label the mask
lbl = label(mk1, connectivity=connectivity)

# get region properties
Expand Down
118 changes: 118 additions & 0 deletions segmentation/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Use a pre-trained segmentation model. Make sure your input is 256x256.
PROGRAM : predict.py
POURPOSE : Get the regions in an image where waves are actively breaking
AUTHOR : Caio Eadi Stringari
EMAIL : [email protected]
V2.0 : 06/10/2020 [Caio Stringari]
"""

import os
import argparse

from glob import glob
from natsort import natsorted

import numpy as np

from skimage.io import imread
from skimage.color import grey2rgb

import tensorflow as tf

# progress bar
from tqdm import tqdm

# quite skimage warnings
import warnings

# plot
import matplotlib.pyplot as plt

tf.get_logger().setLevel('INFO')
warnings.filterwarnings("ignore")


def display_mask(val_preds, i):
"""Display a model's prediction."""
mask = np.argmax(val_preds[i], axis=-1)
mask = np.expand_dims(mask, axis=-1)
return mask


def main():
"""Call the main program."""
# i/o
model = args.model[0] # pre-trained model
inp_data = args.input[0] # frames to be segmented
out_data = args.output[0] # output csv file

# create output
os.makedirs(out_data, exist_ok=True)

# load the model
M = tf.keras.models.load_model(model)

# verify if the input path exists,
# if it does, then get the frame names
if os.path.isdir(inp_data):
images = natsorted(glob(inp_data + "/*"))
else:
raise IOError("No such file or directory \"{}\"".format(inp_data))

# --- loop over frames ---
pbar = tqdm(total=len(images))

for k, image in enumerate(images):

# print("-- plotting frame {} of {}".format(k+1, total_frames), end="\r")

# load image
img = grey2rgb(imread(image))

# predict
pred = M.predict(np.expand_dims(img/255, axis=0)) # very important to normalize your data !
prd = np.squeeze(np.argmax(pred, axis=-1))

# plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6),
sharex=True, sharey=True)
ax1.imshow(np.squeeze(img))
ax2.imshow(np.squeeze(prd))
fig.tight_layout()
plt.savefig(os.path.join(out_data, str(k).zfill(6) + ".png"),
pad_inches=0.1, bbox_inches='tight')
plt.close()

pbar.update()


if __name__ == '__main__':

parser = argparse.ArgumentParser(description='Predict active wave breaking segmentation')

parser.add_argument('--model', "-M",
nargs=1,
dest='model',
help='pre-trained model in .h5 format',
required=True,
action='store')

parser.add_argument("--input", "-i", "--frames", "-frames",
nargs=1,
action="store",
dest="input",
required=True,
help="Input path with data.",)

parser.add_argument("--output", "-o",
nargs=1,
action="store",
dest="output",
required=True,
help="Output path.",)

args = parser.parse_args()

main()
9 changes: 6 additions & 3 deletions segmentation/predict_on_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,12 @@ def display_mask(val_preds, i):
def make_plot(img, df, roi_patch=False, total_frames=-1, out_path="plt",
block_shape=[256, 256]):
"""Plot the results."""
# crate a mask
brk_mask = np.zeros([img.shape[0], img.shape[0]]).astype(int)
brk_mask[df["i"].values, df["j"].values] = 1
# create a mask
try:
brk_mask = np.zeros([img.shape[0], img.shape[0]]).astype(int)
brk_mask[df["i"].values, df["j"].values] = 1
except Exception:
brk_mask = np.zeros([img.shape[0], img.shape[0]]).astype(int)
brk_mask = np.ma.masked_less(brk_mask, 1)
binmap = mpl.colors.ListedColormap("red")

Expand Down

0 comments on commit 8e69df9

Please sign in to comment.