diff --git a/chirp/projects/agile2/classifier.py b/chirp/projects/agile2/classifier.py
new file mode 100644
index 00000000..311d3f18
--- /dev/null
+++ b/chirp/projects/agile2/classifier.py
@@ -0,0 +1,190 @@
+# coding=utf-8
+# Copyright 2024 The Perch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for training and applying a linear classifier."""
+
+from typing import Any
+
+from chirp.models import metrics
+from chirp.projects.agile2 import classifier_data
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import tqdm
+
+
+def hinge_loss(pred: jax.Array, y: jax.Array, w: jax.Array) -> jax.Array:
+  """Weighted SVM hinge loss."""
+  # Convert multihot to +/- 1 labels.
+  y = 2 * y - 1
+  return w * jnp.maximum(0, 1 - y * pred)
+
+
+def bce_loss(pred: jax.Array, y: jax.Array, w: jax.Array) -> jax.Array:
+  """Weighted sigmoid binary cross-entropy loss."""
+  return w * optax.losses.sigmoid_binary_cross_entropy(pred, y)
+
+
+def infer(params, embeddings: jax.Array | np.ndarray):
+  """Apply the model to embeddings."""
+  return jnp.dot(embeddings, params['beta']) + params['beta_bias']
+
+
+def forward(
+    params,
+    batch,
+    weak_neg_weight: float,
+    l2_mu: float,
+    loss_name: str = 'hinge',
+) -> jax.Array:
+  """Forward pass for classifier training.
+
+  Args:
+    params: Model parameters, with keys 'beta' and 'beta_bias'.
+    batch: A batch providing embedding, multihot and is_labeled_mask.
+    weak_neg_weight: Loss weight for entries which are not labeled.
+    l2_mu: Weight of the regularization term.
+    loss_name: Either 'hinge' or 'bce'.
+
+  Returns:
+    The scalar training loss.
+
+  Raises:
+    ValueError: If loss_name is not recognized.
+  """
+  embeddings = batch.embedding
+  pred = infer(params, embeddings)
+  weights = (
+      batch.is_labeled_mask + (1.0 - batch.is_labeled_mask) * weak_neg_weight
+  )
+  # Loss shape is [B, C]
+  if loss_name == 'hinge':
+    loss = hinge_loss(pred=pred, y=batch.multihot, w=weights).sum()
+  elif loss_name == 'bce':
+    loss = bce_loss(pred=pred, y=batch.multihot, w=weights).sum()
+  else:
+    raise ValueError(f'Unknown loss name: {loss_name}')
+  # NOTE(review): this is the mean of the full Gram matrix, so it includes
+  # cross-class terms; confirm a plain squared-L2 penalty was not intended.
+  l2_reg = jnp.dot(params['beta'].T, params['beta']).mean()
+  loss = loss + l2_mu * l2_reg
+  return loss.mean()
+
+
+def eval_classifier(
+    params: Any,
+    data_manager: classifier_data.DataManager,
+    eval_ids: np.ndarray,
+) -> dict[str, Any]:
+  """Evaluate a classifier on a set of examples."""
+  iter_ = data_manager.batched_example_iterator(
+      eval_ids, add_weak_negatives=False, repeat=False
+  )
+  # The embedding ids may be shuffled by the iterator, so we will track the ids
+  # of the examples we are evaluating.
+  got_ids = []
+  pred_logits = []
+  true_labels = []
+  for batch in iter_:
+    pred_logits.append(infer(params, batch.embedding))
+    true_labels.append(batch.multihot)
+    got_ids.append(batch.idx)
+  pred_logits = np.concatenate(pred_logits, axis=0)
+  true_labels = np.concatenate(true_labels, axis=0)
+  got_ids = np.concatenate(got_ids, axis=0)
+
+  # Compute the top1 accuracy on examples with at least one label.
+  labeled_locs = np.where(true_labels.sum(axis=1) > 0)
+  top_preds = np.argmax(pred_logits, axis=1)
+  top1 = true_labels[np.arange(top_preds.shape[0]), top_preds]
+  top1 = top1[labeled_locs].mean()
+
+  rocs = metrics.roc_auc(
+      logits=pred_logits, labels=true_labels, sample_threshold=1
+  )
+  cmaps = metrics.cmap(
+      logits=pred_logits, labels=true_labels, sample_threshold=1
+  )
+  return {
+      'top1_acc': top1,
+      'roc_auc': rocs['macro'],
+      'roc_auc_individual': rocs['individual'],
+      'cmap': cmaps['macro'],
+      'cmap_individual': cmaps['individual'],
+      'eval_ids': got_ids,
+      'eval_preds': pred_logits,
+      'eval_labels': true_labels,
+  }
+
+
+def train_linear_classifier(
+    data_manager: classifier_data.DataManager,
+    learning_rate: float,
+    weak_neg_weight: float,
+    l2_mu: float,
+    num_train_steps: int,
+    loss_name: str = 'hinge',
+):
+  """Train a linear classifier.
+
+  Args:
+    data_manager: Provides the train/eval split and batched examples.
+    learning_rate: Learning rate for the Adam optimizer.
+    weak_neg_weight: Loss weight for entries which are not labeled.
+    l2_mu: Weight of the regularization term.
+    num_train_steps: Number of training batches to process.
+    loss_name: Either 'hinge' or 'bce'.
+
+  Returns:
+    The trained parameters and the evaluation scores on the eval split.
+  """
+  optimizer = optax.adam(learning_rate=learning_rate)
+  embedding_dim = data_manager.db.embedding_dimension()
+  num_classes = len(data_manager.target_labels)
+  params = {
+      'beta': jnp.zeros((embedding_dim, num_classes)),
+      'beta_bias': jnp.zeros((num_classes,)),
+  }
+  opt_state = optimizer.init(params)
+
+  def update(params, batch, opt_state, **kwargs) -> tuple[jax.Array, Any, Any]:
+    loss, grads = jax.value_and_grad(forward)(params, batch, **kwargs)
+    updates, opt_state = optimizer.update(grads, opt_state, params)
+    params = optax.apply_updates(params, updates)
+    return loss, params, opt_state
+
+  train_ids, eval_ids = data_manager.get_train_test_split()
+  iter_ = data_manager.batched_example_iterator(
+      train_ids, add_weak_negatives=True, repeat=True
+  )
+
+  progress = tqdm.tqdm(total=num_train_steps)
+  # Putting the range first makes zip stop before fetching an extra batch.
+  for _, batch in zip(range(num_train_steps), iter_):
+    loss, params, opt_state = update(
+        params,
+        batch,
+        opt_state,
+        weak_neg_weight=weak_neg_weight,
+        l2_mu=l2_mu,
+        loss_name=loss_name,
+    )
+    progress.update()
+    progress.set_description(f'Loss {loss:.8f}')
+  progress.close()
+
+  eval_scores = eval_classifier(params, data_manager, eval_ids)
+  return params, eval_scores
diff --git a/chirp/projects/agile2/classifier_data.py b/chirp/projects/agile2/classifier_data.py
index 5efa2b01..0558f301 100644
--- a/chirp/projects/agile2/classifier_data.py
+++ b/chirp/projects/agile2/classifier_data.py
@@ -15,6 +15,7 @@
 
 """Tools for processing data for the Agile2 classifier."""
 
+import abc
 import dataclasses
 import itertools
 from typing import Any, Iterator, Sequence
@@ -61,7 +62,111 @@ def join_batches(self, other: 'LabeledExample') -> 'LabeledExample':
 
 
 @dataclasses.dataclass
-class ClassifierDataManager:
+class DataManager:
+  """Base class for managing data for training and evaluation.
+
+  Attributes:
+    target_labels: The labels to build multihot targets for.
+    db: The database providing embeddings and labels.
+    batch_size: The number of examples per batch.
+    rng: The random number generator used for shuffling.
+  """
+
+  target_labels: tuple[str, ...]
+  db: interface.GraphSearchDBInterface
+  batch_size: int
+  rng: np.random.Generator
+
+  def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]:
+    """Create a train/test split over all labels.
+
+    Returns:
+      Two numpy arrays containing train and eval embedding ids, respectively.
+    """
+    raise NotImplementedError('get_train_test_split is not implemented.')
+
+  def get_multihot_labels(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
+    """Create the multihot label for one example.
+
+    Args:
+      idx: The embedding ID to fetch labels for.
+
+    Returns:
+      A float32 vector over target_labels holding the fraction of positive
+      annotations for each label, and a boolean mask marking the labels which
+      have at least one positive or negative annotation.
+    """
+    labels = self.db.get_labels(idx)
+    lbl_idxes = {label: i for i, label in enumerate(self.target_labels)}
+    pos = np.zeros(len(self.target_labels), dtype=np.float32)
+    neg = np.zeros(len(self.target_labels), dtype=np.float32)
+    for label in labels:
+      # Annotations for labels outside of target_labels are irrelevant here.
+      if label.label not in lbl_idxes:
+        continue
+      if label.type == interface.LabelType.POSITIVE:
+        pos[lbl_idxes[label.label]] += 1.0
+      elif label.type == interface.LabelType.NEGATIVE:
+        neg[lbl_idxes[label.label]] += 1.0
+    count = pos + neg
+    mask = count > 0
+    denom = np.maximum(count, 1.0)
+    multihot = pos / denom
+    return multihot, mask
+
+  def labeled_example_iterator(
+      self, ids: np.ndarray, repeat: bool = False
+  ) -> Iterator[LabeledExample]:
+    """Create an iterator for training a classifier for target_labels.
+
+    Args:
+      ids: The embedding IDs to iterate over.
+      repeat: If True, repeat the iterator indefinitely.
+
+    Yields:
+      LabeledExample objects.
+    """
+    if len(ids) == 0 and not repeat:
+      # Nothing to yield; avoid indexing into an empty array below.
+      return
+    ids = ids.copy()
+    self.rng.shuffle(ids)
+    q = 0
+    while True:
+      x = ids[q]
+      x_emb = self.db.get_embedding(x)
+      x_multihot, x_is_labeled = self.get_multihot_labels(x)
+      yield LabeledExample(x, x_emb, x_multihot, x_is_labeled)
+      q += 1
+      if q >= len(ids) and repeat:
+        q = 0
+        self.rng.shuffle(ids)
+      elif q >= len(ids):
+        break
+
+  def batched_example_iterator(
+      self,
+      labeled_ids: np.ndarray,
+      repeat: bool = False,
+      **unused_kwargs,
+  ) -> Iterator[LabeledExample]:
+    """Iterate over batches of labeled examples.
+
+    Args:
+      labeled_ids: The embedding IDs to iterate over.
+      repeat: If True, repeat the iterator indefinitely.
+      **unused_kwargs: Ignored options, such as add_weak_negatives.
+
+    Yields:
+      Batched LabeledExample objects.
+    """
+    example_iterator = self.labeled_example_iterator(labeled_ids, repeat=repeat)
+    for ex_batch in batched(example_iterator, self.batch_size):
+      yield LabeledExample.create_batched(ex_batch)
+
+
+@dataclasses.dataclass
+class AgileDataManager(DataManager):
   """Collects labeled data for training classifiers.
 
   Attributes:
@@ -74,14 +179,9 @@ class ClassifierDataManager:
     weak_negatives_batch_size: The batch size for weak negatives.
     rng: The random number generator to use.
   """
-
-  target_labels: tuple[str, ...]
-  db: interface.GraphSearchDBInterface
   train_ratio: float
   min_eval_examples: int
-  batch_size: int
   weak_negatives_batch_size: int
-  rng: np.random.Generator
 
   def get_single_label_train_test_split(
       self, label: str
@@ -148,59 +248,19 @@ def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]:
     all_train = np.setdiff1d(all_ids, all_eval)
     return all_train, all_eval
 
-  def get_multihot_labels(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
-    """Create the multihot label for one example."""
-    labels = self.db.get_labels(idx)
-    lbl_idxes = {label: i for i, label in enumerate(self.target_labels)}
-    pos = np.zeros(len(self.target_labels), dtype=np.float32)
-    neg = np.zeros(len(self.target_labels), dtype=np.float32)
-    for label in labels:
-      if label.type == interface.LabelType.POSITIVE:
-        pos[lbl_idxes[label.label]] += 1.0
-      elif label.type == interface.LabelType.NEGATIVE:
-        neg[lbl_idxes[label.label]] += 1.0
-    count = pos + neg
-    mask = count > 0
-    denom = np.maximum(count, 1.0)
-    multihot = pos / denom
-    return multihot, mask
-
-  def labeled_example_iterator(
-      self, ids: np.ndarray, repeat: bool = False
-  ) -> Iterator[LabeledExample]:
-    """Create an iterator for training a classifier for target_labels.
-
-    Args:
-      ids: The embedding IDs to iterate over.
-      repeat: If True, repeat the iterator indefinitely.
-
-    Yields:
-      LabeledExample objects.
-    """
-    ids = ids.copy()
-    self.rng.shuffle(ids)
-    q = 0
-    while True:
-      x = ids[q]
-      x_emb = self.db.get_embedding(x)
-      x_multihot, x_is_labeled = self.get_multihot_labels(x)
-      yield LabeledExample(x, x_emb, x_multihot, x_is_labeled)
-      q += 1
-      if q >= len(ids) and repeat:
-        q = 0
-        self.rng.shuffle(ids)
-      elif q >= len(ids):
-        break
-
   def batched_example_iterator(
       self,
       labeled_ids: np.ndarray,
-      add_weak_negatives: bool = False,
       repeat: bool = False,
+      add_weak_negatives: bool = False,
   ) -> Iterator[LabeledExample]:
     """Labeled training data iterator with weak negatives."""
     example_iterator = self.labeled_example_iterator(labeled_ids, repeat=repeat)
     example_iterator = batched(example_iterator, self.batch_size)
+    if not add_weak_negatives:
+      for ex_batch in example_iterator:
+        yield LabeledExample.create_batched(ex_batch)
+      return
 
     weak_ids = np.setdiff1d(self.db.get_embedding_ids(), labeled_ids)
     weak_iterator = self.labeled_example_iterator(weak_ids, repeat=True)
@@ -216,6 +276,64 @@ def batched_example_iterator(
       yield ex_batch
 
 
+@dataclasses.dataclass
+class FullyAnnotatedDataManager(DataManager):
+  """A DataManager for fully-annotated datasets.
+
+  Attributes:
+    train_examples_per_class: Maximum number of positive training examples to
+      select for each target label.
+    min_eval_examples: Number of shuffled positive examples per target label
+      which are reserved for evaluation and excluded from training.
+    add_unlabeled_train_examples: If True, also train on examples which have no
+      positive annotation for any target label.
+  """
+
+  train_examples_per_class: int
+  min_eval_examples: int
+  add_unlabeled_train_examples: bool
+
+  def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]:
+    """Create a train/test split over the fully-annotated dataset.
+
+    Returns:
+      Two numpy arrays containing train and eval embedding ids, respectively.
+    """
+    pos_id_sets = {}
+    eval_id_sets = {}
+    for label in self.target_labels:
+      pos_id_sets[label] = self.db.get_embeddings_by_label(
+          label, interface.LabelType.POSITIVE, None
+      )
+      self.rng.shuffle(pos_id_sets[label])
+      eval_id_sets[label] = pos_id_sets[label][: self.min_eval_examples]
+    all_eval_ids = np.concatenate(tuple(eval_id_sets.values()), axis=0)
+
+    # Now produce train sets of the desired size,
+    # avoiding the selected eval examples.
+    train_id_sets = {}
+    for label in self.target_labels:
+      pos_set = np.setdiff1d(pos_id_sets[label], all_eval_ids)
+      # np.setdiff1d returns sorted ids, so shuffle again before truncating.
+      self.rng.shuffle(pos_set)
+      train_id_sets[label] = pos_set[: self.train_examples_per_class]
+    if self.add_unlabeled_train_examples:
+      unlabeled_ids = np.setdiff1d(
+          self.db.get_embedding_ids(),
+          np.concatenate(tuple(pos_id_sets.values()), axis=0),
+      )
+      unlabeled_ids = np.setdiff1d(unlabeled_ids, all_eval_ids)
+      self.rng.shuffle(unlabeled_ids)
+      train_id_sets['UNLABELED'] = unlabeled_ids[
+          : self.train_examples_per_class
+      ]
+
+    # The final eval set is the complement of all selected training id's.
+    all_train_ids = np.concatenate(tuple(train_id_sets.values()), axis=0)
+    eval_ids = np.setdiff1d(self.db.get_embedding_ids(), all_train_ids)
+    return all_train_ids, eval_ids
+
+
 def batched(iterable: Iterator[Any], n: int) -> Iterator[Any]:
   # TODO(tomdenton): Use itertools.batched in Python 3.12+
   # batched('ABCDEFG', 3) → ABC DEF G
diff --git a/chirp/projects/agile2/tests/classifier_data_test.py b/chirp/projects/agile2/tests/classifier_data_test.py
index af03580a..23b2a3bb 100644
--- a/chirp/projects/agile2/tests/classifier_data_test.py
+++ b/chirp/projects/agile2/tests/classifier_data_test.py
@@ -70,7 +70,7 @@ def test_train_test_split_fully_labeled(self):
         positive_label_prob=0.5,
         rng=np.random.default_rng(42),
     )
-    data_manager = classifier_data.ClassifierDataManager(
+    data_manager = classifier_data.AgileDataManager(
         target_labels=test_utils.CLASS_LABELS,
         db=db,
         train_ratio=0.8,
@@ -103,7 +103,7 @@ def test_train_test_split_partially_labeled(self):
         positive_label_prob=0.5,
         rng=np.random.default_rng(42),
     )
-    data_manager = classifier_data.ClassifierDataManager(
+    data_manager = classifier_data.AgileDataManager(
         target_labels=test_utils.CLASS_LABELS,
         db=db,
         train_ratio=0.8,
@@ -136,7 +136,7 @@ def test_partial_classes(self):
         rng=np.random.default_rng(42),
     )
     # Only use three labels, which is half the length of the full class list.
-    data_manager = classifier_data.ClassifierDataManager(
+    data_manager = classifier_data.AgileDataManager(
         target_labels=test_utils.CLASS_LABELS[:3],
         db=db,
         train_ratio=0.8,
@@ -166,7 +166,7 @@ def test_multihot_labels(self):
         num_embeddings=100,
         rng=np.random.default_rng(42),
     )
-    data_manager = classifier_data.ClassifierDataManager(
+    data_manager = classifier_data.AgileDataManager(
         target_labels=test_utils.CLASS_LABELS,
         db=db,
         train_ratio=0.8,
diff --git a/chirp/taxonomy/annotations_fns.py b/chirp/taxonomy/annotations_fns.py
index c737999f..9b51322d 100644
--- a/chirp/taxonomy/annotations_fns.py
+++ b/chirp/taxonomy/annotations_fns.py
@@ -57,7 +57,9 @@ def load_caples_annotations(annotations_path: epath.Path) -> pd.DataFrame:
   return segments
 
 
-def load_cornell_annotations(annotations_path: epath.Path) -> pd.DataFrame:
+def load_cornell_annotations(
+    annotations_path: epath.Path, file_id_prefix: str = ''
+) -> pd.DataFrame:
   """Load the annotations from a Cornell Zenodo dataset."""
   start_time_fn = lambda row: float(row['Start Time (s)'])
   end_time_fn = lambda row: float(row['End Time (s)'])
@@ -66,7 +68,7 @@ def load_cornell_annotations(annotations_path: epath.Path) -> pd.DataFrame:
       row['Species eBird Code'].strip().replace('????', 'unknown')
   ]
 
-  filename_fn = lambda filepath, row: row['Filename'].strip()
+  filename_fn = lambda filepath, row: file_id_prefix + row['Filename'].strip()
   annos = annotations.read_dataset_annotations_csvs(
       [annotations_path],
       filename_fn=filename_fn,