Skip to content

Commit

Permalink
Fixing Edge-Probing after multi-GPU release (#1025)
Browse files Browse the repository at this point in the history
* fixing update_metrics for EdgeProbing

* Throwing error on multi-GPU

* Fixing multi-GPU error where model weights and inputs were on different GPUs

* remove exception on multi-GPU

* remove unbind_predictions()

* move unbind_predictions into edge probing task handle_preds method

* update comments and docstrings

Co-authored-by: Yada Pruksachatkun <[email protected]>
Co-authored-by: Phil Yeres <[email protected]>
  • Loading branch information
3 people authored Mar 10, 2020
1 parent 57ea962 commit 333fc54
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 40 deletions.
8 changes: 7 additions & 1 deletion jiant/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,13 @@ def forward(self, task, batch, predict=False):
# Just get embeddings and invoke task module.
word_embs_in_context, sent_mask = self.sent_encoder(batch["input1"], task)
module = getattr(self, "%s_mdl" % task.name)
out = module.forward(batch, word_embs_in_context, sent_mask, task, predict)
out = module.forward(
batch=batch,
word_embs_in_context=word_embs_in_context,
sent_mask=sent_mask,
task=task,
predict=predict,
)
elif isinstance(task, SequenceGenerationTask):
out = self._seq_gen_forward(batch, task, predict)
elif isinstance(task, (MultiRCTask, ReCoRDTask)):
Expand Down
46 changes: 7 additions & 39 deletions jiant/modules/edge_probing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Implementation of edge probing module.
from typing import Dict

from typing import Dict, Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -80,19 +78,19 @@ def __init__(self, task, d_inp: int, task_params):
if self.is_symmetric or self.single_sided:
# Use None as dummy padding for readability,
# so that we can index projs[1] and projs[2]
self.projs = [None, self.proj1, self.proj1]
self.projs = nn.ModuleList([None, self.proj1, self.proj1])
else:
# Separate params for span2
self.proj2 = self._make_cnn_layer(d_inp)
self.projs = [None, self.proj1, self.proj2]
self.projs = nn.ModuleList([None, self.proj1, self.proj2])

# Span extractor, shared for both span1 and span2.
self.span_extractor1 = self._make_span_extractor()
if self.is_symmetric or self.single_sided:
self.span_extractors = [None, self.span_extractor1, self.span_extractor1]
self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor1])
else:
self.span_extractor2 = self._make_span_extractor()
self.span_extractors = [None, self.span_extractor1, self.span_extractor2]
self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor2])

# Classifier gets concatenated projections of span1, span2
clf_input_dim = self.span_extractors[1].get_output_dim()
Expand Down Expand Up @@ -131,11 +129,9 @@ def forward(
"""
out = {}

batch_size = word_embs_in_context.shape[0]
out["n_inputs"] = batch_size

# Apply projection CNN layer for each span.
word_embs_in_context_t = word_embs_in_context.transpose(1, 2) # needed for CNN layer

se_proj1 = self.projs[1](word_embs_in_context_t).transpose(2, 1).contiguous()
if not self.single_sided:
se_proj2 = self.projs[2](word_embs_in_context_t).transpose(2, 1).contiguous()
Expand Down Expand Up @@ -169,28 +165,10 @@ def forward(
out["loss"] = self.compute_loss(logits[span_mask], batch["labels"][span_mask], task)

if predict:
# Return preds as a list.
preds = self.get_predictions(logits)
out["preds"] = list(self.unbind_predictions(preds, span_mask))
out["preds"] = self.get_predictions(logits)

return out

def unbind_predictions(self, preds: torch.Tensor, masks: torch.Tensor) -> Iterable[np.ndarray]:
    """Lazily unpack batched predictions into variable-length numpy arrays.

    Args:
        preds: [batch_size, num_targets, ...] tensor of predictions.
        masks: [batch_size, num_targets] boolean tensor; True marks targets to keep.
    Yields:
        One np.ndarray per batch row, holding only that row's unmasked predictions.
    """
    preds_cpu = preds.detach().cpu()
    masks_cpu = masks.detach().cpu()
    # Iterating a tensor walks dim 0, one row per batch example.
    for row_pred, row_mask in zip(preds_cpu, masks_cpu):
        # Boolean indexing drops the masked (padding) targets for this row.
        yield row_pred[row_mask].numpy()

def get_predictions(self, logits: torch.Tensor):
"""Return class probabilities, same shape as logits.
Expand Down Expand Up @@ -218,16 +196,6 @@ def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, task: EdgePro
Returns:
loss: scalar Tensor
"""
binary_preds = logits.ge(0).long() # {0,1}

# Matthews coefficient and accuracy computed on {0,1} labels.
task.mcc_scorer(binary_preds, labels.long())
task.acc_scorer(binary_preds, labels.long())

# F1Measure() expects [total_num_targets, n_classes, 2]
# to compute binarized F1.
binary_scores = torch.stack([-1 * logits, logits], dim=2)
task.f1_scorer(binary_scores, labels)

if self.loss_type == "sigmoid":
return F.binary_cross_entropy(torch.sigmoid(logits), labels.float())
Expand Down
40 changes: 40 additions & 0 deletions jiant/tasks/edge_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import logging as log
import os
import torch
from typing import Dict, Iterable, List, Sequence, Type

# Fields for instance processing
Expand Down Expand Up @@ -159,6 +160,45 @@ def load_data(self):
iters_by_split[split] = iter
self._iters_by_split = iters_by_split

def update_metrics(self, out, batch):
    """Feed one batch of model outputs into this task's scorers.

    Only valid targets are scored: a span1 start index of -1 marks padding,
    so those positions are masked out of both logits and labels first.
    """
    valid = batch["span1s"][:, :, 0] != -1
    masked_logits = out["logits"][valid]
    masked_labels = batch["labels"][valid]

    # Threshold logits at 0 (i.e. sigmoid at 0.5) to get {0,1} predictions.
    hard_preds = masked_logits.ge(0).long()

    # Matthews coefficient and accuracy are computed on the {0,1} labels.
    self.mcc_scorer(hard_preds, masked_labels.long())
    self.acc_scorer(hard_preds, masked_labels.long())

    # F1Measure() expects [total_num_targets, n_classes, 2] to compute binarized F1.
    self.f1_scorer(torch.stack([-1 * masked_logits, masked_logits], dim=2), masked_labels)

def handle_preds(self, preds, batch):
    """Unpack batched predictions into a list of per-example numpy arrays.

    Parameters
    ----------
    preds : torch.Tensor
        [batch_size, num_targets, ...] tensor of predictions.
    batch : dict
        Must contain "span1s" with shape [batch_size, num_targets, ...];
        a span1 start index of -1 marks a padded (invalid) target.

    Returns
    -------
    list[np.ndarray]
        One array per example, containing only its non-masked predictions.
    """
    keep = (batch["span1s"][:, :, 0] != -1).detach().cpu()
    preds_cpu = preds.detach().cpu()
    # Iterating a tensor walks dim 0; boolean indexing drops padded targets.
    return [row[sel].numpy() for row, sel in zip(preds_cpu, keep)]

def get_split_text(self, split: str):
""" Get split text as iterable of records.
Expand Down

0 comments on commit 333fc54

Please sign in to comment.