From 4aa0bfbc325877790530486a2ba1de36df331a88 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Fri, 5 Apr 2024 07:41:59 -0700 Subject: [PATCH] Use Kaggle Models URL for downloading Perch. This allows loading Version 8. Also adds a convenience method for loading the model from just a version number. PiperOrigin-RevId: 622176242 --- chirp/inference/models.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/chirp/inference/models.py b/chirp/inference/models.py index 8cd3c852..0b8a33d8 100644 --- a/chirp/inference/models.py +++ b/chirp/inference/models.py @@ -32,7 +32,11 @@ import tensorflow.compat.v1 as tf1 import tensorflow_hub as hub -PERCH_TF_HUB_URL = 'https://tfhub.dev/google/bird-vocalization-classifier' +PERCH_TF_HUB_URL = ( + 'https://www.kaggle.com/models/google/' + 'bird-vocalization-classifier/frameworks/TensorFlow2/' + 'variations/bird-vocalization-classifier/versions' +) def model_class_map() -> dict[str, Any]: @@ -287,6 +291,10 @@ def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': raise ValueError( 'Exactly one of tfhub_version and model_path should be set.' ) + if config.tfhub_version in (5, 6, 7): + # Due to SNAFUs uploading the new model version to KaggleModels, + # some version numbers were skipped. + raise ValueError('TFHub version 5, 6, and 7 do not exist.') model_url = f'{PERCH_TF_HUB_URL}/{config.tfhub_version}' # This model behaves exactly like the usual saved_model. @@ -303,6 +311,20 @@ def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': model=model, class_list=class_lists, batchable=batchable, **config ) + @classmethod + def load_version( + cls, tfhub_version: int, hop_size_s: float = 5.0 + ) -> 'TaxonomyModelTF': + cfg = config_dict.ConfigDict({ + 'model_path': '', + 'sample_rate': 32000, + 'window_size_s': 5.0, + 'hop_size_s': hop_size_s, + 'target_peak': 0.25, + 'tfhub_version': tfhub_version, + }) + return cls.from_tfhub(cfg) + @classmethod def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': logging.info('Loading taxonomy model...')