From 7d2f4f6952cd85a786c8170fc2e613da6809da3f Mon Sep 17 00:00:00 2001 From: KatrionaGoldmann Date: Fri, 7 Jun 2024 10:08:47 +0100 Subject: [PATCH] update model description --- trapdata/ml/models/classification.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index 48b37b24..5ae43613 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -105,26 +105,20 @@ def forward(self, x): return x -class Resnet50Classifier_KG(InferenceBaseClass): +class Resnet50Classifier_Turing(InferenceBaseClass): # function to run the Turing models - logger.info("KG: Resnet50Classifier_KG") + logger.info("KG: Resnet50Classifier_Turing") input_size = 300 def get_model(self): num_classes = len(self.category_map) - logger.info(f"KG: num_classes {num_classes}") - model = Resnet50(num_classes=num_classes) model = model.to(self.device) - - # # state_dict = torch.hub.load_state_dict_from_url(weights_url) checkpoint = torch.load(self.weights, map_location=self.device) - - # # The model state dict is nested in some checkpoints, and not in others + # The model state dict is nested in some checkpoints, and not in others state_dict = checkpoint.get("model_state_dict") or checkpoint model.load_state_dict(state_dict) - # model.load_state_dict(state_dict) model.eval() return model @@ -306,11 +300,11 @@ class QuebecVermontMothSpeciesClassifierMixedResolution( ) class TuringCostaRicaSpeciesClassifier( - SpeciesClassifier, Resnet50Classifier_KG + SpeciesClassifier, Resnet50Classifier_Turing ): name = "Turing Costa Rica Species Classifier" description = ( - "Trained on 13th May 2024 by Turing team using Resnet50 model." + "Trained on 4th June 2024 by Turing team using Resnet50 model." ) weights_path = ( "turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt" @@ -321,7 +315,7 @@ class TuringCostaRicaSpeciesClassifier( class TuringUKSpeciesClassifier( - SpeciesClassifier, Resnet50Classifier_KG + SpeciesClassifier, Resnet50Classifier_Turing ): name = "Turing UK Species Classifier" description = (