-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Pytorch MaxP Feature/ptmaxp #184
Merged
andrewyates
merged 51 commits into
capreolus-ir:master
from
crystina-z:feature/eval+ptmaxp
Aug 6, 2022
Merged
Changes from all commits
Commits
Show all changes
51 commits
Select commit
Hold shift + click to select a range
5f6ab05
first version of benchmark.eval with ir-measures
crystina-z 018e5ca
benchmark.eval add relevance level support
crystina-z b0a4502
minor fix
crystina-z b25bbd4
remove msmarco-eval
crystina-z 90920e9
clean
crystina-z 2691d76
change all measures into str repr to avoid black problem
crystina-z ee3a0ef
skip evaluation if there is no matching qids
crystina-z f24bd3d
speed up training data prep - use set rather than list for train-qids…
crystina-z f915e45
add pt-maxp (train 30k + rerank top100: MRR@10=0.329)
crystina-z edace93
adapt config msmarco for pt monobert
crystina-z 8cf1ef4
remove tqdm
crystina-z bd3d3aa
add decay into msmarco config
crystina-z 2d41e28
fix import
crystina-z 305ff92
add notes to ptmaxp
crystina-z 16d6bdc
add shape for CE loss
crystina-z 9c796c4
change sampling logic of pairsampler - sample one pos and neg at once…
crystina-z 9d891d4
shuffle loaded tfrecord dataset
crystina-z 18c144e
MSMARCO reproductino logs - nima
nimasadri11 ce4444d
Merge pull request #1 from nimasadri11/master
crystina-z ce392ac
tf amp: use both / None to align with pt
crystina-z 23dcb3f
ms marco prepro doc; MRR@10=0.352 for pt-maxp; MRR@10=0.354 for tf-ma…
crystina-z addcb98
merge
crystina-z 133de84
cross entropy; use avg rather than sum
crystina-z 24cee86
support firstp, sump, avgp (same score on msp-v1)
crystina-z b5e7448
config for pt-maxp (rob04)
crystina-z 9137bec
support eval dev and external runfile using external ckpt (dir)
crystina-z e788e9a
Update repro log for MS MARCO passage ranking task
leungjch 1c570c3
Merge pull request #2 from leungjch/justin/update-repro-oct-19
crystina-z 5d9fe65
Update msmarco reproduction log
edanerg c1bce9b
Fix markdown
edanerg 65f0117
Merge branch 'feature/eval+ptmaxp' of github.com:crystina-z/capreolus…
crystina-z b730b98
add training flag to id2vec() to control different data format during…
crystina-z f2039ac
cleanup pt-maxp; mRR@10=0.352
crystina-z 581ac27
Merge pull request #3 from AlexWang000/feature/eval+ptmaxp
crystina-z 78d54be
revert the files that involving changing evaluation s.t. the PR isn't…
crystina-z 7a7de77
merge with master
crystina-z 3db0ff9
clean
crystina-z a87bfe7
adapt lce-passage extractor to the new extractor framework
crystina-z ef0f73d
make default msmarco-lce config a "small" version
crystina-z 2edbb47
update repro doc
crystina-z 38407df
update config msmarco
crystina-z 10e0dc6
clean
crystina-z 7a1ec64
first attmp to solve issue when warmup==epoch==1
crystina-z c263868
allow extrector to pad queries to the specified length
crystina-z db5e1ee
newline at the end of file
crystina-z ae536a5
black
crystina-z ea7e04a
dead code
crystina-z cdd90f3
bugfix
crystina-z 95fd1d4
change the id2vec test case; so that the testing n-passage is 1
crystina-z 30f3096
revert quick.md
crystina-z db0e405
for birch extractor; move the create_tf_train_feature and parse_tf_tr…
crystina-z File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
|
||
from capreolus import get_logger | ||
from capreolus.utils.exceptions import MissingDocError | ||
from . import Extractor | ||
from .bertpassage import BertPassage | ||
from .common import MultipleTrainingPassagesMixin | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
@Extractor.register | ||
class BirchBertPassage(MultipleTrainingPassagesMixin, BertPassage): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inherit the |
||
module_name = "birchbertpassage" | ||
|
||
config_spec = BertPassage.config_spec | ||
|
||
def id2vec(self, qid, posid, negid=None, label=None, **kwargs): | ||
""" | ||
See parent class for docstring | ||
""" | ||
assert label is not None | ||
maxseqlen = self.config["maxseqlen"] | ||
numpassages = self.config["numpassages"] | ||
|
||
query_toks = self.qid2toks[qid] | ||
pos_bert_inputs, pos_bert_masks, pos_bert_segs = [], [], [] | ||
|
||
# N.B: The passages in self.docid2passages are not bert tokenized | ||
pos_passages = self._get_passages(posid) | ||
for tokenized_passage in pos_passages: | ||
inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage) | ||
pos_bert_inputs.append(inp) | ||
pos_bert_masks.append(mask) | ||
pos_bert_segs.append(seg) | ||
|
||
# TODO: Rename the posdoc key in the below dict to 'pos_bert_input' | ||
data = { | ||
"qid": qid, | ||
"posdocid": posid, | ||
"pos_bert_input": np.array(pos_bert_inputs, dtype=np.long), | ||
"pos_mask": np.array(pos_bert_masks, dtype=np.long), | ||
"pos_seg": np.array(pos_bert_segs, dtype=np.long), | ||
"negdocid": "", | ||
"neg_bert_input": np.zeros((numpassages, maxseqlen), dtype=np.long), | ||
"neg_mask": np.zeros((numpassages, maxseqlen), dtype=np.long), | ||
"neg_seg": np.zeros((numpassages, maxseqlen), dtype=np.long), | ||
"label": np.repeat(np.array([label], dtype=np.float32), numpassages, 0), | ||
} | ||
|
||
if not negid: | ||
return data | ||
|
||
neg_bert_inputs, neg_bert_masks, neg_bert_segs = [], [], [] | ||
neg_passages = self._get_passages(negid) | ||
|
||
for tokenized_passage in neg_passages: | ||
inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage) | ||
neg_bert_inputs.append(inp) | ||
neg_bert_masks.append(mask) | ||
neg_bert_segs.append(seg) | ||
|
||
if not neg_bert_inputs: | ||
raise MissingDocError(qid, negid) | ||
|
||
data["negdocid"] = negid | ||
data["neg_bert_input"] = np.array(neg_bert_inputs, dtype=np.long) | ||
data["neg_mask"] = np.array(neg_bert_masks, dtype=np.long) | ||
data["neg_seg"] = np.array(neg_bert_segs, dtype=np.long) | ||
return data |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Explicitly for training, this function randomly select one passage from the
n-passages
, this is done in extractor now so thatpytorch
andtensorflow
trainer can both use it.