Fixes for Pytorch CEDR LR and hgf output (#126)
* Fix Pytorch CEDR with BERT weights
* Fix Pytorch LR schedule after iters/steps change
* Enable TF similarity matrix padding
* Add smart_open dependency
* Change spacy dependency to <3.0
andrewyates authored Feb 1, 2021
1 parent 11a232d commit afa5b7e
Showing 7 changed files with 25 additions and 14 deletions.
6 changes: 5 additions & 1 deletion in capreolus/reranker/CEDRKNRM.py
@@ -139,7 +139,11 @@ def forward(self, bert_input, bert_mask, bert_segments):
bert_segments = bert_segments.view((batch_size * self.num_passages, self.maxseqlen))

# get BERT embeddings (including CLS) for each passage
- bert_output, all_layer_output = self.bert(bert_input, attention_mask=bert_mask, token_type_ids=bert_segments)
+ # TODO switch to hgf's ModelOutput after bumping transformers version
+ outputs = self.bert(bert_input, attention_mask=bert_mask, token_type_ids=bert_segments)
+ if self.config["pretrained"].startswith("bert-"):
+     outputs = (outputs[0], outputs[2])
+ bert_output, all_layer_output = outputs

# average CLS embeddings to create the CLS feature
cls = bert_output[:, 0, :]
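Background on the hgf output fix: older transformers releases return the BERT forward pass as a plain tuple rather than a ModelOutput, and bert-* checkpoints include a pooler output at index 1 that the other supported checkpoints apparently do not, so only bert-* needs re-indexing to recover (last_hidden_state, hidden_states). A minimal sketch of that unpacking with dummy tensors; the helper name and the non-bert layout are assumptions for illustration, not the Capreolus or transformers API:

```python
import torch

def unpack_bert_outputs(outputs, pretrained_name):
    """Return (last_hidden_state, all_layer_hidden_states) from tuple-style model outputs."""
    if pretrained_name.startswith("bert-"):
        # assumed bert-* layout: (last_hidden_state, pooler_output, hidden_states)
        return outputs[0], outputs[2]
    # assumed layout for other checkpoints: (last_hidden_state, hidden_states)
    return outputs[0], outputs[1]

# dummy tensors standing in for a real forward pass with output_hidden_states=True
last_hidden = torch.zeros(2, 8, 768)
pooler = torch.zeros(2, 768)
hidden_states = tuple(torch.zeros(2, 8, 768) for _ in range(13))

bert_output, all_layer_output = unpack_bert_outputs((last_hidden, pooler, hidden_states), "bert-base-uncased")
print(bert_output.shape, len(all_layer_output))  # torch.Size([2, 8, 768]) 13
```

Once the pinned transformers version is bumped, the TODO above suggests reading these fields by name from the ModelOutput (last_hidden_state, hidden_states) instead of by position.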
11 changes: 6 additions & 5 deletions in capreolus/reranker/common.py
@@ -66,14 +66,13 @@ def new_similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding
batch, qlen, dims = query_embed.shape
doclen = doc_embed.shape[1]

- # TODO apply mask for use in stuff other than KNRM
query_embed = tf.reshape(tf.nn.l2_normalize(query_embed, axis=-1), [batch, qlen, 1, dims])
- # query_padding = tf.reshape(tf.cast(query_tok != padding, query_embed.dtype), [batch, qlen, 1, 1])
- # query_embed = query_embed * query_padding
+ query_padding = tf.reshape(tf.cast(query_tok != padding, query_embed.dtype), [batch, qlen, 1, 1])
+ query_embed = query_embed * query_padding

doc_embed = tf.reshape(tf.nn.l2_normalize(doc_embed, axis=-1), [batch, 1, doclen, dims])
- # doc_padding = tf.reshape(tf.cast(doc_tok != padding, doc_embed.dtype), [batch, 1, doclen, 1])
- # doc_embed = doc_embed * doc_padding
+ doc_padding = tf.reshape(tf.cast(doc_tok != padding, doc_embed.dtype), [batch, 1, doclen, 1])
+ doc_embed = doc_embed * doc_padding

simmat = tf.reduce_sum(query_embed * doc_embed, axis=-1, keepdims=True)
return simmat
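Why the masking lines were re-enabled: without them, padded query and document positions contribute nonzero cosine similarities and leak into the KNRM kernel counts. A toy check of the padded similarity matrix with made-up shapes and token ids (an illustration of the logic above, not the Capreolus function itself):

```python
import tensorflow as tf

batch, qlen, doclen, dims, padding = 1, 3, 4, 8, 0
query_embed = tf.random.normal([batch, qlen, dims])
doc_embed = tf.random.normal([batch, doclen, dims])
query_tok = tf.constant([[101, 2054, 0]])      # last query position is padding
doc_tok = tf.constant([[101, 7592, 2088, 0]])  # last doc position is padding

q = tf.reshape(tf.nn.l2_normalize(query_embed, axis=-1), [batch, qlen, 1, dims])
q = q * tf.reshape(tf.cast(query_tok != padding, q.dtype), [batch, qlen, 1, 1])
d = tf.reshape(tf.nn.l2_normalize(doc_embed, axis=-1), [batch, 1, doclen, dims])
d = d * tf.reshape(tf.cast(doc_tok != padding, d.dtype), [batch, 1, doclen, 1])

simmat = tf.reduce_sum(q * d, axis=-1, keepdims=True)  # [batch, qlen, doclen, 1]
print(simmat[0, -1])     # row for the padded query token is all zeros
print(simmat[0, :, -1])  # column for the padded doc token is all zeros
```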
@@ -142,6 +141,8 @@ def forward(self, query_tok, doc_tok):
return simmat


+ # TODO replace this with newer ONIR version?
+ # https://github.com/Georgetown-IR-Lab/OpenNIR/blob/ca14dfa5e7cfef3fbbb35efbb4e7df0f1fbde590/onir/modules/interaction_matrix.py#L27
class StackedSimilarityMatrix(torch.nn.Module):
# based on SimmatModule from https://github.com/Georgetown-IR-Lab/cedr/blob/master/modeling_util.py
# which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license
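A rough sketch of what a stacked similarity matrix computes, based on the CEDR SimmatModule that the class credits rather than on the Capreolus implementation itself: one cosine-similarity matrix per representation layer, stacked along a channel dimension for the downstream kernels.

```python
import torch
import torch.nn.functional as F

def stacked_simmat(query_reps, doc_reps):
    """query_reps/doc_reps: per-layer lists of [batch, qlen, dim] / [batch, doclen, dim] tensors."""
    simmats = []
    for q, d in zip(query_reps, doc_reps):
        q = F.normalize(q, p=2, dim=-1)
        d = F.normalize(d, p=2, dim=-1)
        simmats.append(torch.bmm(q, d.transpose(1, 2)))  # [batch, qlen, doclen]
    return torch.stack(simmats, dim=1)  # [batch, layers, qlen, doclen]

q_layers = [torch.randn(2, 5, 16) for _ in range(3)]
d_layers = [torch.randn(2, 9, 16) for _ in range(3)]
print(stacked_simmat(q_layers, d_layers).shape)  # torch.Size([2, 3, 5, 9])
```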
6 changes: 4 additions & 2 deletions in capreolus/trainer/pytorch.py
@@ -186,8 +186,10 @@ def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output
self.amp_train_autocast = contextlib.nullcontext
self.scaler = None

- # REF-TODO how to handle interactions between fastforward and schedule?
- self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, self.lr_multiplier)
+ # REF-TODO how to handle interactions between fastforward and schedule? --> just save its state
+ self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+     self.optimizer, lambda epoch: self.lr_multiplier(step=epoch * self.n_batch_per_iter)
+ )

if self.config["softmaxloss"]:
self.loss = pair_softmax_loss
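Why the scheduler call changed: torch.optim.lr_scheduler.LambdaLR passes its own per-call counter (presumably stepped once per Capreolus iteration) to the lambda, while the trainer's multiplier is now defined over batches/steps, so the counter is scaled by the number of batches per iteration before being passed along. A self-contained sketch of the same pattern; lr_multiplier and n_batch_per_iter below are stand-ins, not the trainer's attributes:

```python
import torch

def lr_multiplier(step, warmup_steps=10):
    """Hypothetical warmup multiplier defined over training steps."""
    return min(1.0, (step + 1) / warmup_steps)

n_batch_per_iter = 4
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: lr_multiplier(step=epoch * n_batch_per_iter)
)

for iteration in range(5):
    optimizer.step()  # one iteration's training work would go here
    scheduler.step()  # increments the counter by 1, i.e. n_batch_per_iter steps' worth of warmup
    print(iteration, scheduler.get_last_lr())
```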
7 changes: 4 additions & 3 deletions in docs/reproduction/CEDR-KNRM.md
@@ -1,11 +1,11 @@
# Capreolus: Reranking robust04 with CEDR-KNRM
This page contains instructions for running CEDR-KNRM on the robust04 benchmark.

- [*CEDR: Contextualized Embeddings for Document Ranking*](https://arxiv.org/pdf/1904.07094.pd)f
+ [*CEDR: Contextualized Embeddings for Document Ranking*](https://arxiv.org/pdf/1904.07094.pdf).
Sean MacAvaney, Andrew Yates, Arman Cohan, and Nazli Goharian. SIGIR 2019.

## Setup
- Install Capreolus v0.2.6 or later. See the [installation guide](https://capreolus.ai/en/latest/installation.html) for help installing a release. To install from GitHub, see the [PARADE reproduction guide](https://github.com/capreolus-ir/capreolus/blob/master/docs/reproduction/PARADE.md).
+ Install Capreolus v0.2.6 or later. See the [installation guide](https://capreolus.ai/en/latest/installation.html) for help installing a release. To install from GitHub, see the [PARADE guide](https://github.com/capreolus-ir/capreolus/blob/master/docs/reproduction/PARADE.md).

## Running CEDR-KNRM

@@ -38,7 +38,8 @@ When using a less powerful GPU or disabling mixed precision (`reranker.trainer.a
3. Each command will take a few hours on a single V100 GPU. Per-fold metrics are displayed after each fold completes.
4. When the final fold completes, cross-validated metrics are also displayed.

- Note that the Tensorflow implementation has only been tested on TPUs.
+ Note that the Tensorflow implementation has primarily been tested on TPUs.


## Running BERT-KNRM, VanillaBERT, and other model variants

3 changes: 2 additions & 1 deletion in environment.yml
@@ -4,7 +4,7 @@ channels:
dependencies:
- python=3.7
- pandas
- spacy
- spacy<3.0
- numpy
- scipy
- matplotlib
@@ -48,3 +48,4 @@ dependencies:
- xxhash
- annoy
- fasteners
- smart_open
3 changes: 2 additions & 1 deletion in requirements.txt
@@ -31,7 +31,8 @@ Pillow
beautifulsoup4
lxml
scispacy
spacy
smart_open
spacy<3.0
pandas
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_lg-0.2.4.tar.gz
# deps that the pymagnitude package isn't pulling in:
3 changes: 2 additions & 1 deletion in setup.py
@@ -78,7 +78,8 @@ def get_version(rel_path):
"beautifulsoup4",
"lxml",
"scispacy",
"spacy",
"smart_open",
"spacy<3.0",
"pandas",
],
classifiers=["Programming Language :: Python :: 3", "Operating System :: OS Independent"],
