diff --git a/jiant/models.py b/jiant/models.py
index fe9d1616b..53d6e51b2 100644
--- a/jiant/models.py
+++ b/jiant/models.py
@@ -868,7 +868,13 @@ def forward(self, task, batch, predict=False):
             # Just get embeddings and invoke task module.
             word_embs_in_context, sent_mask = self.sent_encoder(batch["input1"], task)
             module = getattr(self, "%s_mdl" % task.name)
-            out = module.forward(batch, word_embs_in_context, sent_mask, task, predict)
+            out = module.forward(
+                batch=batch,
+                word_embs_in_context=word_embs_in_context,
+                sent_mask=sent_mask,
+                task=task,
+                predict=predict,
+            )
         elif isinstance(task, SequenceGenerationTask):
             out = self._seq_gen_forward(batch, task, predict)
         elif isinstance(task, (MultiRCTask, ReCoRDTask)):
diff --git a/jiant/modules/edge_probing.py b/jiant/modules/edge_probing.py
index 64208fb7b..fce4606bd 100644
--- a/jiant/modules/edge_probing.py
+++ b/jiant/modules/edge_probing.py
@@ -1,8 +1,6 @@
 # Implementation of edge probing module.
+from typing import Dict
 
-from typing import Dict, Iterable
-
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -80,19 +78,19 @@ def __init__(self, task, d_inp: int, task_params):
         if self.is_symmetric or self.single_sided:
             # Use None as dummy padding for readability,
             # so that we can index projs[1] and projs[2]
-            self.projs = [None, self.proj1, self.proj1]
+            self.projs = nn.ModuleList([None, self.proj1, self.proj1])
         else:
             # Separate params for span2
             self.proj2 = self._make_cnn_layer(d_inp)
-            self.projs = [None, self.proj1, self.proj2]
+            self.projs = nn.ModuleList([None, self.proj1, self.proj2])
 
         # Span extractor, shared for both span1 and span2.
         self.span_extractor1 = self._make_span_extractor()
         if self.is_symmetric or self.single_sided:
-            self.span_extractors = [None, self.span_extractor1, self.span_extractor1]
+            self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor1])
         else:
             self.span_extractor2 = self._make_span_extractor()
-            self.span_extractors = [None, self.span_extractor1, self.span_extractor2]
+            self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor2])
 
         # Classifier gets concatenated projections of span1, span2
         clf_input_dim = self.span_extractors[1].get_output_dim()
@@ -131,11 +129,9 @@ def forward(
         """
         out = {}
 
-        batch_size = word_embs_in_context.shape[0]
-        out["n_inputs"] = batch_size
-
         # Apply projection CNN layer for each span.
         word_embs_in_context_t = word_embs_in_context.transpose(1, 2)  # needed for CNN layer
+
         se_proj1 = self.projs[1](word_embs_in_context_t).transpose(2, 1).contiguous()
         if not self.single_sided:
             se_proj2 = self.projs[2](word_embs_in_context_t).transpose(2, 1).contiguous()
@@ -169,28 +165,10 @@ def forward(
             out["loss"] = self.compute_loss(logits[span_mask], batch["labels"][span_mask], task)
 
         if predict:
-            # Return preds as a list.
-            preds = self.get_predictions(logits)
-            out["preds"] = list(self.unbind_predictions(preds, span_mask))
+            out["preds"] = self.get_predictions(logits)
 
         return out
 
-    def unbind_predictions(self, preds: torch.Tensor, masks: torch.Tensor) -> Iterable[np.ndarray]:
-        """ Unpack preds to varying-length numpy arrays.
-
-        Args:
-            preds: [batch_size, num_targets, ...]
-            masks: [batch_size, num_targets] boolean mask
-
-        Yields:
-            np.ndarray for each row of preds, selected by the corresponding row
-            of span_mask.
-        """
-        preds = preds.detach().cpu()
-        masks = masks.detach().cpu()
-        for pred, mask in zip(torch.unbind(preds, dim=0), torch.unbind(masks, dim=0)):
-            yield pred[mask].numpy()  # only non-masked predictions
-
     def get_predictions(self, logits: torch.Tensor):
         """Return class probabilities, same shape as logits.
 
@@ -218,16 +196,6 @@ def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, task: EdgePro
         Returns:
             loss: scalar Tensor
         """
-        binary_preds = logits.ge(0).long()  # {0,1}
-
-        # Matthews coefficient and accuracy computed on {0,1} labels.
-        task.mcc_scorer(binary_preds, labels.long())
-        task.acc_scorer(binary_preds, labels.long())
-
-        # F1Measure() expects [total_num_targets, n_classes, 2]
-        # to compute binarized F1.
-        binary_scores = torch.stack([-1 * logits, logits], dim=2)
-        task.f1_scorer(binary_scores, labels)
 
         if self.loss_type == "sigmoid":
             return F.binary_cross_entropy(torch.sigmoid(logits), labels.float())
diff --git a/jiant/tasks/edge_probing.py b/jiant/tasks/edge_probing.py
index 6821441ef..e81cd940a 100644
--- a/jiant/tasks/edge_probing.py
+++ b/jiant/tasks/edge_probing.py
@@ -3,6 +3,7 @@
 import itertools
 import logging as log
 import os
+import torch
 from typing import Dict, Iterable, List, Sequence, Type
 
 # Fields for instance processing
@@ -159,6 +160,45 @@ def load_data(self):
             iters_by_split[split] = iter
         self._iters_by_split = iters_by_split
 
+    def update_metrics(self, out, batch):
+        span_mask = batch["span1s"][:, :, 0] != -1
+        logits = out["logits"][span_mask]
+        labels = batch["labels"][span_mask]
+
+        binary_preds = logits.ge(0).long()  # {0,1}
+
+        # Matthews coefficient and accuracy computed on {0,1} labels.
+        self.mcc_scorer(binary_preds, labels.long())
+        self.acc_scorer(binary_preds, labels.long())
+
+        # F1Measure() expects [total_num_targets, n_classes, 2]
+        # to compute binarized F1.
+        binary_scores = torch.stack([-1 * logits, logits], dim=2)
+        self.f1_scorer(binary_scores, labels)
+
+    def handle_preds(self, preds, batch):
+        """Unpack preds into varying-length numpy arrays and return the non-masked preds in a list.
+
+        Parameters
+        ----------
+        preds : [batch_size, num_targets, ...]
+        batch : dict
+            dict with key "span1s" having val w/ bool Tensor dim [batch_size, num_targets, ...].
+
+        Returns
+        -------
+        non_masked_preds : list[np.ndarray]
+            list of pred np.ndarray selected by the corresponding row of span_mask.
+
+        """
+        masks = batch["span1s"][:, :, 0] != -1
+        preds = preds.detach().cpu()
+        masks = masks.detach().cpu()
+        non_masked_preds = []
+        for pred, mask in zip(torch.unbind(preds, dim=0), torch.unbind(masks, dim=0)):
+            non_masked_preds.append(pred[mask].numpy())  # only non-masked predictions
+        return non_masked_preds
+
     def get_split_text(self, split: str):
         """ Get split text as iterable of records.
 