From ee628e9f249577f99cdaf871bc8ef79f58145b2b Mon Sep 17 00:00:00 2001 From: Cameron Bodine Date: Wed, 4 Dec 2024 04:46:17 -0500 Subject: [PATCH] Kludge to fix tf.shape for shadow model #128 --- src/class_portstarObj.py | 4 ++-- src/funcs_model.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/class_portstarObj.py b/src/class_portstarObj.py index ec0de5b..e6c7903 100644 --- a/src/class_portstarObj.py +++ b/src/class_portstarObj.py @@ -1824,8 +1824,8 @@ def _detectShadow(self, remShadow, i, USE_GPU, doPlot=True, tileFile='.jpg'): ############### # Do prediction - port_label, port_prob = doPredict(model, MODEL, self.port.sonDat, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRESHOLD) - star_label, star_prob = doPredict(model, MODEL, self.star.sonDat, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRESHOLD) + port_label, port_prob = doPredict(model, MODEL, self.port.sonDat, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRESHOLD, shadow=True) + star_label, star_prob = doPredict(model, MODEL, self.star.sonDat, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRESHOLD, shadow=True) # Set shadow to 0, else 1 port_label = np.where(port_label==0,1,0) diff --git a/src/funcs_model.py b/src/funcs_model.py index 99cddf0..19d5e18 100644 --- a/src/funcs_model.py +++ b/src/funcs_model.py @@ -140,7 +140,7 @@ def initModel(weights, configfile, USE_GPU=False): ################################################ #======================================================================= -def doPredict(model, MODEL, arr, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRESHOLD): +def doPredict(model, MODEL, arr, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRESHOLD, shadow=False): ''' ''' @@ -152,6 +152,11 @@ def doPredict(model, MODEL, arr, N_DATA_BANDS, NCLASSES, TARGET_SIZE, OTSU_THRES image = standardize(image.numpy()).squeeze() + # Kludge to fix error noted in Issue #128 + if shadow: + image = image[:,:,0] + image = tf.expand_dims(image, 2) + if NCLASSES == 2: E0, E1 = est_label_binary(image, model, MODEL, False, NCLASSES, TARGET_SIZE, w, h)