Skip to content

Commit

Permalink
Merge pull request #58 from AMI-system/cr_models
Browse files Browse the repository at this point in the history
Classes for Turing Costa Rica models
  • Loading branch information
mihow authored Aug 15, 2024
2 parents 9356076 + 7d2f4f6 commit 4e9cf3a
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,7 @@ trapdata.ini

# Docker volumes
db_data/


# Test files
sample_images
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,22 @@ docker stop ami-db && docker remove ami-db

A script is available in the repo source to run the commands above.
`./scrips/start_db_container.sh`



## KG Notes for adding new models

- To add new models, save the pt and json files to:
```
~/Library/Application Support/trapdata/models
```
or wherever you set the appropriate dir in settings.
The json file is simply a dict of species name and index.

Then you need to create a class in `trapdata/ml/models/classification.py` or `trapdata/ml/models/localization.py` and add the model details.

- To clear the cache:

```
rm ~/Library/Application\ Support/trapdata/trapdata.db
```
68 changes: 68 additions & 0 deletions trapdata/ml/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,44 @@ def forward(self, x):

return x

class Resnet50Classifier_Turing(InferenceBaseClass):
# function to run the Turing models
logger.info("KG: Resnet50Classifier_Turing")
input_size = 300

def get_model(self):
num_classes = len(self.category_map)
model = Resnet50(num_classes=num_classes)
model = model.to(self.device)
checkpoint = torch.load(self.weights, map_location=self.device)
# The model state dict is nested in some checkpoints, and not in others
state_dict = checkpoint.get("model_state_dict") or checkpoint

model.load_state_dict(state_dict)
model.eval()
return model

def get_transforms(self):
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
return torchvision.transforms.Compose(
[
torchvision.transforms.Resize((self.input_size, self.input_size)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean, std),
]
)

def post_process_batch(self, output):
predictions = torch.nn.functional.softmax(output, dim=1)
predictions = predictions.cpu().numpy()

categories = predictions.argmax(axis=1)
labels = [self.category_map[cat] for cat in categories]
scores = predictions.max(axis=1).astype(float)

result = list(zip(labels, scores))
logger.debug(f"Post-processing result batch: {result}")
return result

class Resnet50Classifier(InferenceBaseClass):
input_size = 300
Expand Down Expand Up @@ -261,6 +299,36 @@ class QuebecVermontMothSpeciesClassifierMixedResolution(
"quebec-vermont_moth-category-map_19Jan2023.json"
)

class TuringCostaRicaSpeciesClassifier(
SpeciesClassifier, Resnet50Classifier_Turing
):
name = "Turing Costa Rica Species Classifier"
description = (
"Trained on 4th June 2024 by Turing team using Resnet50 model."
)
weights_path = (
"turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt"
)
labels_path = (
"03_costarica_data_category_map.json"
)


class TuringUKSpeciesClassifier(
SpeciesClassifier, Resnet50Classifier_Turing
):
name = "Turing UK Species Classifier"
description = (
"Trained on 13th May 2024 by Turing team using Resnet50 model."
)
weights_path = (
"turing-uk_v03_resnet50_2024-05-13-10-03_state.pt"
)
labels_path = (
"03_uk_data_category_map.json"
)



class UKDenmarkMothSpeciesClassifierMixedResolution(
SpeciesClassifier, Resnet50ClassifierLowRes
Expand Down

0 comments on commit 4e9cf3a

Please sign in to comment.