diff --git a/train_code/utils.py b/train_code/utils.py index 472559c..f16c859 100644 --- a/train_code/utils.py +++ b/train_code/utils.py @@ -64,7 +64,7 @@ def label2rgb(label_field, image, kind='mix', bg_label=-1, bg_color=(0, 0, 0)): median = np.median(image[mask], axis=0) color = 0.5*mean + 0.5*median elif 40 < std: - color = image[mask].median(axis=0) + color = np.median(image[mask], axis=0) out[mask] = color return out