From f5d5cc149e905ed8c02e7855d9e0bd8eae6ce69c Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 14 Aug 2024 21:58:07 -0700 Subject: [PATCH] Use remote URLs for new models. Update formatting. --- trapdata/ml/models/classification.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index 5ae43613..2813683e 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -1,5 +1,6 @@ import timm import torch +import torch.utils.data import torchvision from trapdata import constants, logger @@ -105,6 +106,7 @@ def forward(self, x): return x + class Resnet50Classifier_Turing(InferenceBaseClass): # function to run the Turing models logger.info("KG: Resnet50Classifier_Turing") @@ -144,6 +146,7 @@ def post_process_batch(self, output): logger.debug(f"Post-processing result batch: {result}") return result + class Resnet50Classifier(InferenceBaseClass): input_size = 300 @@ -299,37 +302,33 @@ class QuebecVermontMothSpeciesClassifierMixedResolution( "quebec-vermont_moth-category-map_19Jan2023.json" ) -class TuringCostaRicaSpeciesClassifier( - SpeciesClassifier, Resnet50Classifier_Turing -): + +class TuringCostaRicaSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): name = "Turing Costa Rica Species Classifier" - description = ( - "Trained on 4th June 2024 by Turing team using Resnet50 model." - ) + description = "Trained on 4th June 2024 by Turing team using Resnet50 model." weights_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" "turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt" ) labels_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" "03_costarica_data_category_map.json" ) -class TuringUKSpeciesClassifier( - SpeciesClassifier, Resnet50Classifier_Turing -): +class TuringUKSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): name = "Turing UK Species Classifier" - description = ( - "Trained on 13th May 2024 by Turing team using Resnet50 model." - ) + description = "Trained on 13th May 2024 by Turing team using Resnet50 model." weights_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" "turing-uk_v03_resnet50_2024-05-13-10-03_state.pt" ) labels_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" "03_uk_data_category_map.json" ) - class UKDenmarkMothSpeciesClassifierMixedResolution( SpeciesClassifier, Resnet50ClassifierLowRes ):