Skip to content

Commit

Permalink
Linear probe for agile v2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 656436223
  • Loading branch information
sdenton4 authored and copybara-github committed Jul 26, 2024
1 parent e682e0a commit 26830db
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 57 deletions.
161 changes: 161 additions & 0 deletions chirp/projects/agile2/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for training and applying a linear classifier."""

from typing import Any

from chirp.models import metrics
from chirp.projects.agile2 import classifier_data
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm


def hinge_loss(pred: jax.Array, y: jax.Array, w: jax.Array) -> jax.Array:
  """Compute the element-wise weighted SVM hinge loss.

  Args:
    pred: Predicted logits.
    y: Multihot targets; rescaled internally to +/- 1 labels.
    w: Weights multiplied element-wise into the loss.

  Returns:
    The weighted hinge loss with no reduction applied.
  """
  # Rescale multihot targets so 0 maps to -1 and 1 maps to +1.
  signed_targets = 2 * y - 1
  margin_violation = 1 - signed_targets * pred
  return w * jnp.maximum(0, margin_violation)


def bce_loss(pred: jax.Array, y: jax.Array, w: jax.Array) -> jax.Array:
  """Compute the element-wise weighted sigmoid binary cross-entropy.

  Args:
    pred: Predicted logits.
    y: Multihot targets.
    w: Weights multiplied element-wise into the loss.

  Returns:
    The weighted cross-entropy with no reduction applied.
  """
  unweighted = optax.losses.sigmoid_binary_cross_entropy(pred, y)
  return w * unweighted


def infer(params, embeddings: jax.Array | np.ndarray):
"""Apply the model to embeddings."""
return jnp.dot(embeddings, params['beta']) + params['beta_bias']


def forward(
    params,
    batch,
    weak_neg_weight: float,
    l2_mu: float,
    loss_name: str = 'hinge',
) -> jax.Array:
  """Compute the regularised training loss for one batch.

  Args:
    params: Classifier parameters, as consumed by `infer`.
    batch: Batch exposing `embedding`, `multihot` and `is_labeled_mask`.
    weak_neg_weight: Weight given to entries where `is_labeled_mask` is unset.
    l2_mu: Multiplier for the regularisation term.
    loss_name: Either 'hinge' or 'bce'.

  Returns:
    The summed weighted loss plus `l2_mu` times the regulariser.

  Raises:
    ValueError: If `loss_name` is not a known loss.
  """
  logits = infer(params, batch.embedding)
  labeled = batch.is_labeled_mask
  # Labeled entries get weight 1; the remainder get weak_neg_weight.
  elementwise_weights = labeled + (1.0 - labeled) * weak_neg_weight

  loss_fns = {'hinge': hinge_loss, 'bce': bce_loss}
  if loss_name not in loss_fns:
    raise ValueError(f'Unknown loss name: {loss_name}')
  # The per-element loss has shape [B, C]; reduce it to a scalar.
  data_loss = loss_fns[loss_name](
      pred=logits, y=batch.multihot, w=elementwise_weights
  ).sum()

  # NOTE(review): this averages the whole Gram matrix beta^T beta, cross-class
  # terms included, rather than summing squared weights -- confirm intended.
  beta = params['beta']
  l2_reg = jnp.dot(beta.T, beta).mean()
  total = data_loss + l2_mu * l2_reg
  return total.mean()


def eval_classifier(
    params: Any,
    data_manager: classifier_data.DataManager,
    eval_ids: np.ndarray,
) -> dict[str, Any]:
  """Evaluate a classifier on a set of examples.

  Args:
    params: Classifier parameters, as consumed by `infer`.
    data_manager: Provides the batched iterator over the eval examples.
    eval_ids: Embedding ids to evaluate on.

  Returns:
    A dict with 'top1_acc', the macro and per-class ROC-AUC ('roc_auc',
    'roc_auc_individual') and cmap ('cmap', 'cmap_individual') values, plus
    the row-aligned 'eval_ids', 'eval_preds' and 'eval_labels' arrays.

  Raises:
    ValueError: If the iterator yields no batches for `eval_ids`.
  """
  iter_ = data_manager.batched_example_iterator(
      eval_ids, add_weak_negatives=False, repeat=False
  )
  # The embedding ids may be shuffled by the iterator, so we track the ids of
  # the examples actually evaluated alongside their predictions and labels.
  got_ids = []
  pred_logits = []
  true_labels = []
  for batch in iter_:
    pred_logits.append(infer(params, batch.embedding))
    true_labels.append(batch.multihot)
    got_ids.append(batch.idx)
  if not got_ids:
    # np.concatenate would fail with an opaque message on an empty list.
    raise ValueError('No examples to evaluate: the iterator was empty.')
  pred_logits = np.concatenate(pred_logits, axis=0)
  true_labels = np.concatenate(true_labels, axis=0)
  got_ids = np.concatenate(got_ids, axis=0)

  # Compute the top1 accuracy on examples with at least one label.
  has_label = true_labels.sum(axis=1) > 0
  top_preds = np.argmax(pred_logits, axis=1)
  top1_hits = true_labels[np.arange(top_preds.shape[0]), top_preds]
  if has_label.any():
    top1 = top1_hits[has_label].mean()
  else:
    # Accuracy is undefined with no labeled rows; report NaN explicitly
    # instead of taking the mean of an empty array (which also warns).
    top1 = float('nan')

  rocs = metrics.roc_auc(
      logits=pred_logits, labels=true_labels, sample_threshold=1
  )
  cmaps = metrics.cmap(
      logits=pred_logits, labels=true_labels, sample_threshold=1
  )
  return {
      'top1_acc': top1,
      'roc_auc': rocs['macro'],
      'roc_auc_individual': rocs['individual'],
      'cmap': cmaps['macro'],
      'cmap_individual': cmaps['individual'],
      'eval_ids': got_ids,
      'eval_preds': pred_logits,
      'eval_labels': true_labels,
  }


def train_linear_classifier(
    data_manager: classifier_data.DataManager,
    learning_rate: float,
    weak_neg_weight: float,
    l2_mu: float,
    num_train_steps: int,
    loss_name: str = 'hinge',
):
  """Train a linear classifier and evaluate it on the held-out split.

  Args:
    data_manager: Supplies the train/eval split and the batched iterators.
    learning_rate: Learning rate for the Adam optimizer.
    weak_neg_weight: Loss weight for weak-negative entries (see `forward`).
    l2_mu: Multiplier for the regularisation term (see `forward`).
    num_train_steps: Number of optimizer steps to run.
    loss_name: Loss to optimise; passed through to `forward`.

  Returns:
    A (params, eval_scores) pair: the trained parameters and the dict
    produced by `eval_classifier` on the eval split.
  """
  optimizer = optax.adam(learning_rate=learning_rate)
  embedding_dim = data_manager.db.embedding_dimension()
  num_classes = len(data_manager.target_labels)
  params = {
      'beta': jnp.zeros((embedding_dim, num_classes)),
      'beta_bias': jnp.zeros((num_classes,)),
  }
  opt_state = optimizer.init(params)

  def update(params, batch, opt_state, **kwargs) -> tuple[jax.Array, Any, Any]:
    """Run one optimizer step; returns (loss, new_params, new_opt_state)."""
    loss, grads = jax.value_and_grad(forward)(params, batch, **kwargs)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return loss, params, opt_state

  train_ids, eval_ids = data_manager.get_train_test_split()
  iter_ = data_manager.batched_example_iterator(
      train_ids, add_weak_negatives=True, repeat=True
  )

  # The bar is driven manually, so it only needs the total; the context
  # manager guarantees it is closed even if a step raises.
  with tqdm.tqdm(total=num_train_steps) as progress:
    # zip() consults the range first, so once the step budget is used up no
    # further batch is pulled from the (infinite) training iterator.
    for _, batch in zip(range(num_train_steps), iter_):
      loss, params, opt_state = update(
          params,
          batch,
          opt_state,
          weak_neg_weight=weak_neg_weight,
          l2_mu=l2_mu,
          loss_name=loss_name,
      )
      progress.update()
      progress.set_description(f'Loss {loss:.8f}')

  eval_scores = eval_classifier(params, data_manager, eval_ids)
  return params, eval_scores
173 changes: 122 additions & 51 deletions chirp/projects/agile2/classifier_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Tools for processing data for the Agile2 classifier."""

import abc
import dataclasses
import itertools
from typing import Any, Iterator, Sequence
Expand Down Expand Up @@ -61,7 +62,80 @@ def join_batches(self, other: 'LabeledExample') -> 'LabeledExample':


@dataclasses.dataclass
class ClassifierDataManager:
class DataManager:
"""Base class for managing data for training and evaluation."""

target_labels: tuple[str, ...]
db: interface.GraphSearchDBInterface
batch_size: int
rng: np.random.Generator

def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]:
  """Create a train/test split over all labels.

  The base class defines no splitting strategy; subclasses override this.

  Returns:
    Two numpy arrays containing train and eval embedding ids, respectively.

  Raises:
    NotImplementedError: Always, when called on the base class.
  """
  message = 'get_train_test_split is not implemented.'
  raise NotImplementedError(message)

def get_multihot_labels(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
  """Create the multihot label and labeled-mask for one example.

  Args:
    idx: Embedding id whose labels are fetched from the database.

  Returns:
    A (multihot, mask) pair with one entry per target label. `multihot` is
    the fraction of that label's annotations that are positive (0 where there
    are none); `mask` is True where at least one positive or negative
    annotation exists.
  """
  num_targets = len(self.target_labels)
  lbl_idxes = {label: i for i, label in enumerate(self.target_labels)}
  pos = np.zeros(num_targets, dtype=np.float32)
  neg = np.zeros(num_targets, dtype=np.float32)
  for label in self.db.get_labels(idx):
    target_idx = lbl_idxes.get(label.label)
    if target_idx is None:
      # The example may carry labels outside target_labels; they have no
      # slot in the output, so skip them rather than raising KeyError.
      continue
    if label.type == interface.LabelType.POSITIVE:
      pos[target_idx] += 1.0
    elif label.type == interface.LabelType.NEGATIVE:
      neg[target_idx] += 1.0
  count = pos + neg
  mask = count > 0
  # Clamp the denominator so unlabeled classes yield 0 instead of 0/0.
  denom = np.maximum(count, 1.0)
  multihot = pos / denom
  return multihot, mask

def labeled_example_iterator(
    self, ids: np.ndarray, repeat: bool = False
) -> Iterator[LabeledExample]:
  """Create an iterator for training a classifier for target_labels.

  The ids are copied and shuffled with `self.rng` before every pass, so the
  caller's array is never modified.

  Args:
    ids: The embedding IDs to iterate over.
    repeat: If True, repeat the iterator indefinitely.

  Yields:
    LabeledExample objects, one per id per pass. Nothing is yielded when
    `ids` is empty, regardless of `repeat`.
  """
  ids = ids.copy()
  if len(ids) == 0:
    # Without this guard, indexing the first element raises IndexError.
    return
  while True:
    self.rng.shuffle(ids)
    for x in ids:
      x_emb = self.db.get_embedding(x)
      x_multihot, x_is_labeled = self.get_multihot_labels(x)
      yield LabeledExample(x, x_emb, x_multihot, x_is_labeled)
    if not repeat:
      break

def batched_example_iterator(
    self,
    labeled_ids: np.ndarray,
    repeat: bool = False,
    **unused_kwargs,
) -> Iterator[LabeledExample]:
  """Iterate over batches of labeled examples.

  Args:
    labeled_ids: The embedding IDs to draw examples from.
    repeat: If True, repeat the underlying example iterator indefinitely.
    **unused_kwargs: Accepted for call compatibility and ignored here.

  Yields:
    Batched LabeledExample objects built from groups of `self.batch_size`
    examples.
  """
  examples = self.labeled_example_iterator(labeled_ids, repeat=repeat)
  grouped = batched(examples, self.batch_size)
  yield from map(LabeledExample.create_batched, grouped)


@dataclasses.dataclass
class AgileDataManager(DataManager):
"""Collects labeled data for training classifiers.
Attributes:
Expand All @@ -74,14 +148,9 @@ class ClassifierDataManager:
weak_negatives_batch_size: The batch size for weak negatives.
rng: The random number generator to use.
"""

target_labels: tuple[str, ...]
db: interface.GraphSearchDBInterface
train_ratio: float
min_eval_examples: int
batch_size: int
weak_negatives_batch_size: int
rng: np.random.Generator

def get_single_label_train_test_split(
self, label: str
Expand Down Expand Up @@ -148,59 +217,19 @@ def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]:
all_train = np.setdiff1d(all_ids, all_eval)
return all_train, all_eval

def get_multihot_labels(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
"""Create the multihot label for one example."""
labels = self.db.get_labels(idx)
lbl_idxes = {label: i for i, label in enumerate(self.target_labels)}
pos = np.zeros(len(self.target_labels), dtype=np.float32)
neg = np.zeros(len(self.target_labels), dtype=np.float32)
for label in labels:
if label.type == interface.LabelType.POSITIVE:
pos[lbl_idxes[label.label]] += 1.0
elif label.type == interface.LabelType.NEGATIVE:
neg[lbl_idxes[label.label]] += 1.0
count = pos + neg
mask = count > 0
denom = np.maximum(count, 1.0)
multihot = pos / denom
return multihot, mask

def labeled_example_iterator(
self, ids: np.ndarray, repeat: bool = False
) -> Iterator[LabeledExample]:
"""Create an iterator for training a classifier for target_labels.
Args:
ids: The embedding IDs to iterate over.
repeat: If True, repeat the iterator indefinitely.
Yields:
LabeledExample objects.
"""
ids = ids.copy()
self.rng.shuffle(ids)
q = 0
while True:
x = ids[q]
x_emb = self.db.get_embedding(x)
x_multihot, x_is_labeled = self.get_multihot_labels(x)
yield LabeledExample(x, x_emb, x_multihot, x_is_labeled)
q += 1
if q >= len(ids) and repeat:
q = 0
self.rng.shuffle(ids)
elif q >= len(ids):
break

def batched_example_iterator(
self,
labeled_ids: np.ndarray,
add_weak_negatives: bool = False,
repeat: bool = False,
add_weak_negatives: bool = False,
) -> Iterator[LabeledExample]:
"""Labeled training data iterator with weak negatives."""
example_iterator = self.labeled_example_iterator(labeled_ids, repeat=repeat)
example_iterator = batched(example_iterator, self.batch_size)
if not add_weak_negatives:
for ex_batch in example_iterator:
yield LabeledExample.create_batched(ex_batch)
return

weak_ids = np.setdiff1d(self.db.get_embedding_ids(), labeled_ids)
weak_iterator = self.labeled_example_iterator(weak_ids, repeat=True)
Expand All @@ -216,6 +245,48 @@ def batched_example_iterator(
yield ex_batch


@dataclasses.dataclass
class FullyAnnotatedDataManager(DataManager):
  """A DataManager for fully-annotated datasets.

  Attributes:
    train_examples_per_class: Maximum number of training ids taken per target
      label (and from the unlabeled pool, when enabled).
    min_eval_examples: Number of positive ids per label reserved for eval
      before the training ids are chosen.
    add_unlabeled_train_examples: If True, also train on ids that carry no
      positive target label.
  """

  train_examples_per_class: int
  min_eval_examples: int
  add_unlabeled_train_examples: bool

  def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]:
    """Create a train/test split over the fully-annotated dataset.

    Returns:
      A (train_ids, eval_ids) pair. The eval ids are every embedding id in
      the database that was not selected for training.
    """
    pos_id_sets = {}
    eval_id_sets = {}
    for label in self.target_labels:
      pos_id_sets[label] = self.db.get_embeddings_by_label(
          label, interface.LabelType.POSITIVE, None
      )
      self.rng.shuffle(pos_id_sets[label])
      eval_id_sets[label] = pos_id_sets[label][: self.min_eval_examples]
    all_eval_ids = np.concatenate(tuple(eval_id_sets.values()), axis=0)

    # Now produce train sets of the desired size,
    # avoiding the selected eval examples.
    # NOTE(review): np.setdiff1d returns sorted ids, so the slices below take
    # the lowest remaining ids rather than a random subset -- confirm intended.
    train_id_sets = {}
    for label in self.target_labels:
      pos_set = np.setdiff1d(pos_id_sets[label], all_eval_ids)
      train_id_sets[label] = pos_set[: self.train_examples_per_class]
    if self.add_unlabeled_train_examples:
      unlabeled_ids = np.setdiff1d(
          self.db.get_embedding_ids(),
          np.concatenate(tuple(pos_id_sets.values()), axis=0),
      )
      # setdiff1d returns a new array, so the result has to be assigned for
      # the reserved eval ids to actually be excluded.
      unlabeled_ids = np.setdiff1d(unlabeled_ids, all_eval_ids)
      train_id_sets['UNLABELED'] = unlabeled_ids[
          : self.train_examples_per_class
      ]

    # The final eval set is the complement of all selected training id's.
    all_train_ids = np.concatenate(tuple(train_id_sets.values()), axis=0)
    eval_ids = np.setdiff1d(self.db.get_embedding_ids(), all_train_ids)
    return all_train_ids, eval_ids


def batched(iterable: Iterator[Any], n: int) -> Iterator[Any]:
# TODO(tomdenton): Use itertools.batched in Python 3.12+
# batched('ABCDEFG', 3) → ABC DEF G
Expand Down
Loading

0 comments on commit 26830db

Please sign in to comment.