Skip to content

Commit

Permalink
Use remote URLs for new models. Update formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
mihow committed Aug 15, 2024
1 parent 4e9cf3a commit f5d5cc1
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions trapdata/ml/models/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import timm
import torch
import torch.utils.data
import torchvision

from trapdata import constants, logger
Expand Down Expand Up @@ -105,6 +106,7 @@ def forward(self, x):

return x


class Resnet50Classifier_Turing(InferenceBaseClass):
# function to run the Turing models
logger.info("KG: Resnet50Classifier_Turing")
Expand Down Expand Up @@ -144,6 +146,7 @@ def post_process_batch(self, output):
logger.debug(f"Post-processing result batch: {result}")
return result


class Resnet50Classifier(InferenceBaseClass):
input_size = 300

Expand Down Expand Up @@ -299,37 +302,33 @@ class QuebecVermontMothSpeciesClassifierMixedResolution(
"quebec-vermont_moth-category-map_19Jan2023.json"
)

class TuringCostaRicaSpeciesClassifier(
SpeciesClassifier, Resnet50Classifier_Turing
):

class TuringCostaRicaSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing):
name = "Turing Costa Rica Species Classifier"
description = (
"Trained on 4th June 2024 by Turing team using Resnet50 model."
)
description = "Trained on 4th June 2024 by Turing team using Resnet50 model."
weights_path = (
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
"turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt"
)
labels_path = (
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
"03_costarica_data_category_map.json"
)


class TuringUKSpeciesClassifier(
SpeciesClassifier, Resnet50Classifier_Turing
):
class TuringUKSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing):
name = "Turing UK Species Classifier"
description = (
"Trained on 13th May 2024 by Turing team using Resnet50 model."
)
description = "Trained on 13th May 2024 by Turing team using Resnet50 model."
weights_path = (
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
"turing-uk_v03_resnet50_2024-05-13-10-03_state.pt"
)
labels_path = (
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
"03_uk_data_category_map.json"
)



class UKDenmarkMothSpeciesClassifierMixedResolution(
SpeciesClassifier, Resnet50ClassifierLowRes
):
Expand Down

0 comments on commit f5d5cc1

Please sign in to comment.