Support PyTorch MaxP (#184)
* added PyTorch MaxP
* updated the output of the bertpassage id2vec function so that it is compatible with both tf-maxp and pt-maxp
* updated the other extractors accordingly
* updated the test cases and repro docs

Co-authored-by: Nima Sadri <[email protected]>
Co-authored-by: Justin <[email protected]>
Co-authored-by: Yuetong Wang <[email protected]>
4 people authored Aug 6, 2022
1 parent a568304 commit 5946640
Showing 22 changed files with 577 additions and 519 deletions.
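For orientation: MaxP scores a document by the maximum of its passage-level scores. A minimal PyTorch sketch of that aggregation (illustrative only; not the repo's reranker code):

import torch

# passage_scores: [batch, num_passages] relevance scores from a BERT reranker
passage_scores = torch.tensor([[0.2, 0.9, 0.4],
                               [0.7, 0.1, 0.3]])
doc_scores = passage_scores.max(dim=1).values  # MaxP: the best passage represents the doc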
2 changes: 1 addition & 1 deletion capreolus/extractor/__init__.py
@@ -68,7 +68,7 @@ def _build_vocab(self, qids, docids, topics):
def build_from_benchmark(self, *args, **kwargs):
raise NotImplementedError

-    def id2vec(self, qid, posdocid, negdocid=None, label=None):
+    def id2vec(self, qid, posdocid, negdocid=None, label=None, *args, **kwargs):
"""
Creates a feature from the (qid, docid) pair.
If negdocid is supplied, that's also included in the feature (needed for training with pairwise hinge loss)
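For reference, a minimal PyTorch sketch of the pairwise hinge loss mentioned in the docstring above (illustrative; capreolus's trainers define the actual loss):

import torch

def pairwise_hinge(pos_score: torch.Tensor, neg_score: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    # Penalize the negative doc whenever it scores within `margin` of the positive doc.
    return torch.clamp(margin - pos_score + neg_score, min=0).mean()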
2 changes: 1 addition & 1 deletion capreolus/extractor/bagofwords.py
@@ -112,7 +112,7 @@ def preprocess(self, qids, docids, topics):

self._build_vocab(qids, docids, topics)

-    def id2vec(self, q_id, posdoc_id, negdoc_id=None, **kwargs):
+    def id2vec(self, q_id, posdoc_id, negdoc_id=None, *args, **kwargs):
query_toks = self.qid2toks[q_id]
posdoc_toks = self.docid2toks.get(posdoc_id)

164 changes: 62 additions & 102 deletions capreolus/extractor/bertpassage.py
@@ -11,11 +11,13 @@
from capreolus.utils.exceptions import MissingDocError
from capreolus.tokenizer.punkt import PunktTokenizer

+from .common import SingleTrainingPassagesMixin

logger = get_logger(__name__)


@Extractor.register
-class BertPassage(Extractor):
+class BertPassage(Extractor, SingleTrainingPassagesMixin):
"""
    Extracts passages from the document to be later consumed by a BERT-based model.
    Does NOT use all the passages. The first passage is always used. Use the `prob` config to control the probability
@@ -37,6 +39,7 @@ class BertPassage(Extractor):
config_spec = [
ConfigOption("maxseqlen", 256, "Maximum input length (query+document)"),
ConfigOption("maxqlen", 20, "Maximum query length"),
ConfigOption("padq", False, "Always pad queries to maxqlen"),
ConfigOption("usecache", False, "Should the extracted features be cached?"),
ConfigOption("passagelen", 150, "Length of the extracted passage"),
ConfigOption("stride", 100, "Stride"),
@@ -85,60 +88,6 @@ def get_tf_feature_description(self):

return feature_description

-    def create_tf_train_feature(self, sample):
-        """
-        Returns a set of features from a doc.
-        Of the num_passages passages that are present in a document, we use only a subset of it.
-        params:
-        sample - A dict where each entry has the shape [batch_size, num_passages, maxseqlen]
-        Returns a list of features. Each feature is a dict, and each value in the dict has the shape [batch_size, maxseqlen].
-        Yes, the output shape is different to the input shape because we sample from the passages.
-        """
-        num_passages = self.config["numpassages"]
-
-        def _bytes_feature(value):
-            """Returns a bytes_list from a string / byte. Our features are multi-dimensional tensors."""
-            if isinstance(value, type(tf.constant(0))):  # if value ist tensor
-                value = value.numpy()  # get value of tensor
-            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-
-        posdoc, negdoc, negdoc_id = sample["pos_bert_input"], sample["neg_bert_input"], sample["negdocid"]
-        posdoc_mask, posdoc_seg, negdoc_mask, negdoc_seg = (
-            sample["pos_mask"],
-            sample["pos_seg"],
-            sample["neg_mask"],
-            sample["neg_seg"],
-        )
-        label = sample["label"]
-        features = []
-
-        for i in range(num_passages):
-            # Always use the first passage, then sample from the remaining passages
-            if i > 0 and self.rng.random() > self.config["prob"]:
-                continue
-
-            bert_input_line = posdoc[i]
-            bert_input_line = " ".join(self.tokenizer.bert_tokenizer.convert_ids_to_tokens(list(bert_input_line)))
-            passage = bert_input_line.split(self.sep_tok)[-2]
-
-            # Ignore empty passages as well
-            if passage.strip() == self.pad_tok:
-                continue
-
-            feature = {
-                "pos_bert_input": _bytes_feature(tf.io.serialize_tensor(posdoc[i])),
-                "pos_mask": _bytes_feature(tf.io.serialize_tensor(posdoc_mask[i])),
-                "pos_seg": _bytes_feature(tf.io.serialize_tensor(posdoc_seg[i])),
-                "neg_bert_input": _bytes_feature(tf.io.serialize_tensor(negdoc[i])),
-                "neg_mask": _bytes_feature(tf.io.serialize_tensor(negdoc_mask[i])),
-                "neg_seg": _bytes_feature(tf.io.serialize_tensor(negdoc_seg[i])),
-                "label": _bytes_feature(tf.io.serialize_tensor(label[i])),
-            }
-            features.append(feature)
-
-        return features
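For readers unfamiliar with the serialization pattern in the removed function above, a self-contained round-trip sketch (toy tensor and feature name; not the repo's exact pipeline):

import tensorflow as tf

def bytes_feature(value):
    # Serialize a tensor into a single bytes feature.
    serialized = tf.io.serialize_tensor(value).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized]))

passage = tf.zeros([256], dtype=tf.int64)  # toy "pos_bert_input" row
example = tf.train.Example(features=tf.train.Features(feature={"pos_bert_input": bytes_feature(passage)}))
parsed = tf.io.parse_single_example(
    example.SerializeToString(), {"pos_bert_input": tf.io.FixedLenFeature([], tf.string)}
)
restored = tf.io.parse_tensor(parsed["pos_bert_input"], tf.int64)  # recovers the original tensor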

def create_tf_dev_feature(self, sample):
"""
Unlike the train feature, the dev set uses all passages. Both the input and the output are dicts with the shape
@@ -171,13 +120,13 @@ def _bytes_feature(value):

return [feature]

-    def parse_tf_train_example(self, example_proto):
+    def parse_tf_dev_example(self, example_proto):
feature_description = self.get_tf_feature_description()
parsed_example = tf.io.parse_example(example_proto, feature_description)

def parse_tensor_as_int(x):
parsed_tensor = tf.io.parse_tensor(x, tf.int64)
parsed_tensor.set_shape([self.config["maxseqlen"]])
parsed_tensor.set_shape([self.config["numpassages"], self.config["maxseqlen"]])

return parsed_tensor

@@ -197,31 +146,31 @@ def parse_label_tensor(x):

return (pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label

-    def parse_tf_dev_example(self, example_proto):
-        feature_description = self.get_tf_feature_description()
-        parsed_example = tf.io.parse_example(example_proto, feature_description)
-
-        def parse_tensor_as_int(x):
-            parsed_tensor = tf.io.parse_tensor(x, tf.int64)
-            parsed_tensor.set_shape([self.config["numpassages"], self.config["maxseqlen"]])
-
-            return parsed_tensor
-
-        def parse_label_tensor(x):
-            parsed_tensor = tf.io.parse_tensor(x, tf.float32)
-            parsed_tensor.set_shape([self.config["numpassages"], 2])
-
-            return parsed_tensor
-
-        pos_bert_input = tf.map_fn(parse_tensor_as_int, parsed_example["pos_bert_input"], dtype=tf.int64)
-        pos_mask = tf.map_fn(parse_tensor_as_int, parsed_example["pos_mask"], dtype=tf.int64)
-        pos_seg = tf.map_fn(parse_tensor_as_int, parsed_example["pos_seg"], dtype=tf.int64)
-        neg_bert_input = tf.map_fn(parse_tensor_as_int, parsed_example["neg_bert_input"], dtype=tf.int64)
-        neg_mask = tf.map_fn(parse_tensor_as_int, parsed_example["neg_mask"], dtype=tf.int64)
-        neg_seg = tf.map_fn(parse_tensor_as_int, parsed_example["neg_seg"], dtype=tf.int64)
-        label = tf.map_fn(parse_label_tensor, parsed_example["label"], dtype=tf.float32)
-
-        return (pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label
+    def _filter_inputs(self, bert_inputs, bert_masks, bert_segs, n_valid_psg):
+        """Preserve only one passage from all available passages."""
+        assert n_valid_psg <= len(
+            bert_inputs
+        ), f"Passages only have {len(bert_inputs)} entries, but got {n_valid_psg} valid passages."
+        valid_indexes = list(range(0, n_valid_psg))
+        if len(valid_indexes) == 0:
+            valid_indexes = [0]
+        random_i = self.rng.choice(valid_indexes)
+        return list(map(lambda arr: arr[random_i], [bert_inputs, bert_masks, bert_segs]))
+
+    def _encode_inputs(self, query_toks, passages):
+        """Convert the query and passages into BERT inputs, mask, segments."""
+        bert_inputs, bert_masks, bert_segs = [], [], []
+        n_valid_psg = 0
+        for tokenized_passage in passages:
+            if tokenized_passage != [self.pad_tok]:  # the passage is not pure padding
+                n_valid_psg += 1
+
+            inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
+            bert_inputs.append(inp)
+            bert_masks.append(mask)
+            bert_segs.append(seg)
+
+        return bert_inputs, bert_masks, bert_segs, n_valid_psg
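A standalone sketch of what this training-time filtering does (illustrative names; any rng with a choice() method works, e.g. Python's random module):

import random

def filter_inputs(bert_inputs, bert_masks, bert_segs, n_valid_psg, rng=random):
    # Pick one non-padding passage at random; fall back to index 0 if none are valid.
    valid_indexes = list(range(n_valid_psg)) or [0]
    i = rng.choice(valid_indexes)
    return bert_inputs[i], bert_masks[i], bert_segs[i]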

def _get_passages(self, docid):
doc = self.index.get_doc(docid)
@@ -321,60 +270,71 @@ def _prepare_bert_input(self, query_toks, psg_toks):
if len(query_toks) > maxqlen:
logger.warning(f"Truncating query from {len(query_toks)} to {maxqlen}")
query_toks = query_toks[:maxqlen]
+        else:  # len(query_toks) <= maxqlen; decide whether to pad the query
+            if self.config["padq"]:
+                query_toks = padlist(query_toks, padlen=maxqlen, pad_token=self.pad_tok)
psg_toks = psg_toks[: maxseqlen - len(query_toks) - 3]

psg_toks = " ".join(psg_toks).split() # in case that psg_toks is np.array
input_line = [self.cls_tok] + query_toks + [self.sep_tok] + psg_toks + [self.sep_tok]
padded_input_line = padlist(input_line, padlen=maxseqlen, pad_token=self.pad_tok)
inp = self.tokenizer.convert_tokens_to_ids(padded_input_line)
-        mask = [1] * len(input_line) + [0] * (len(padded_input_line) - len(input_line))
+        mask = [1 if tok != self.pad_tok else 0 for tok in input_line] + [0] * (len(padded_input_line) - len(input_line))
seg = [0] * (len(query_toks) + 2) + [1] * (len(padded_input_line) - len(query_toks) - 2)
return inp, mask, seg
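To make the layout concrete, a toy illustration of the three sequences built above, assuming maxseqlen=12, a 3-token query, and a 4-token passage:

# tokens: [CLS] q1 q2 q3 [SEP] p1 p2 p3 p4 [SEP] [PAD] [PAD]
# mask:     1   1  1  1    1   1  1  1  1    1     0     0
# seg:      0   0  0  0    0   1  1  1  1    1     1     1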

-    def id2vec(self, qid, posid, negid=None, label=None):
+    def id2vec(self, qid, posid, negid=None, label=None, *args, **kwargs):
"""
See parent class for docstring
"""
+        training = kwargs.get("training", True)  # defaults to training

assert label is not None
maxseqlen = self.config["maxseqlen"]
numpassages = self.config["numpassages"]

query_toks = self.qid2toks[qid]
-        pos_bert_inputs, pos_bert_masks, pos_bert_segs = [], [], []

        # N.B: The passages in self.docid2passages are not bert tokenized
        pos_passages = self._get_passages(posid)
-        for tokenized_passage in pos_passages:
-            inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
-            pos_bert_inputs.append(inp)
-            pos_bert_masks.append(mask)
-            pos_bert_segs.append(seg)
+        pos_bert_inputs, pos_bert_masks, pos_bert_segs, n_valid_psg = self._encode_inputs(query_toks, pos_passages)
+        if training:
+            pos_bert_inputs, pos_bert_masks, pos_bert_segs = self._filter_inputs(
+                pos_bert_inputs, pos_bert_masks, pos_bert_segs, n_valid_psg
+            )
+        else:
+            assert len(pos_bert_inputs) == numpassages
+
+        pos_bert_inputs, pos_bert_masks, pos_bert_segs = map(
+            lambda lst: np.array(lst, dtype=np.long), [pos_bert_inputs, pos_bert_masks, pos_bert_segs]
+        )

        # TODO: Rename the posdoc key in the below dict to 'pos_bert_input'
        data = {
            "qid": qid,
            "posdocid": posid,
-            "pos_bert_input": np.array(pos_bert_inputs, dtype=np.long),
-            "pos_mask": np.array(pos_bert_masks, dtype=np.long),
-            "pos_seg": np.array(pos_bert_segs, dtype=np.long),
+            "pos_bert_input": pos_bert_inputs,
+            "pos_mask": pos_bert_masks,
+            "pos_seg": pos_bert_segs,
            "negdocid": "",
-            "neg_bert_input": np.zeros((numpassages, maxseqlen), dtype=np.long),
-            "neg_mask": np.zeros((numpassages, maxseqlen), dtype=np.long),
-            "neg_seg": np.zeros((numpassages, maxseqlen), dtype=np.long),
-            "label": np.repeat(np.array([label], dtype=np.float32), numpassages, 0),
+            "neg_bert_input": np.zeros_like(pos_bert_inputs, dtype=np.long),
+            "neg_mask": np.zeros_like(pos_bert_masks, dtype=np.long),
+            "neg_seg": np.zeros_like(pos_bert_segs, dtype=np.long),
+            "label": np.array(label, dtype=np.float32),
+            # ^^^ the label's shape is left unchanged, since the per-passage copy is only needed during training
        }

if not negid:
return data

-        neg_bert_inputs, neg_bert_masks, neg_bert_segs = [], [], []
        neg_passages = self._get_passages(negid)

-        for tokenized_passage in neg_passages:
-            inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
-            neg_bert_inputs.append(inp)
-            neg_bert_masks.append(mask)
-            neg_bert_segs.append(seg)
+        neg_bert_inputs, neg_bert_masks, neg_bert_segs, n_valid_psg = self._encode_inputs(query_toks, neg_passages)
+        if training:
+            neg_bert_inputs, neg_bert_masks, neg_bert_segs = self._filter_inputs(
+                neg_bert_inputs, neg_bert_masks, neg_bert_segs, n_valid_psg
+            )
+        else:
+            assert len(neg_bert_inputs) == numpassages

if not neg_bert_inputs:
raise MissingDocError(qid, negid)
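The net effect: one extractor serves both backends. A hedged usage sketch (variable names assumed; shapes follow the code above):

# Training (pt-maxp): one passage is sampled per document.
feat = extractor.id2vec(qid, posid, negid, label=[1, 0], training=True)
# feat["pos_bert_input"].shape == (maxseqlen,)

# Evaluation / tf-maxp-style reranking: all passages are kept.
feat = extractor.id2vec(qid, posid, negid, label=[1, 0], training=False)
# feat["pos_bert_input"].shape == (numpassages, maxseqlen)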
2 changes: 1 addition & 1 deletion capreolus/extractor/berttext.py
@@ -117,7 +117,7 @@ def preprocess(self, qids, docids, topics):

self._build_vocab(qids, docids, topics)

-    def id2vec(self, qid, posid, negid=None):
+    def id2vec(self, qid, posid, negid=None, *args, **kwargs):
tokenizer = self.tokenizer
qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"]

71 changes: 71 additions & 0 deletions capreolus/extractor/birch_bertpassage.py
@@ -0,0 +1,71 @@
import tensorflow as tf
import numpy as np

from capreolus import get_logger
from capreolus.utils.exceptions import MissingDocError
from . import Extractor
from .bertpassage import BertPassage
from .common import MultipleTrainingPassagesMixin

logger = get_logger(__name__)


@Extractor.register
class BirchBertPassage(MultipleTrainingPassagesMixin, BertPassage):
module_name = "birchbertpassage"

config_spec = BertPassage.config_spec

def id2vec(self, qid, posid, negid=None, label=None, **kwargs):
"""
See parent class for docstring
"""
assert label is not None
maxseqlen = self.config["maxseqlen"]
numpassages = self.config["numpassages"]

query_toks = self.qid2toks[qid]
pos_bert_inputs, pos_bert_masks, pos_bert_segs = [], [], []

# N.B: The passages in self.docid2passages are not bert tokenized
pos_passages = self._get_passages(posid)
for tokenized_passage in pos_passages:
inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
pos_bert_inputs.append(inp)
pos_bert_masks.append(mask)
pos_bert_segs.append(seg)

# TODO: Rename the posdoc key in the below dict to 'pos_bert_input'
data = {
"qid": qid,
"posdocid": posid,
"pos_bert_input": np.array(pos_bert_inputs, dtype=np.long),
"pos_mask": np.array(pos_bert_masks, dtype=np.long),
"pos_seg": np.array(pos_bert_segs, dtype=np.long),
"negdocid": "",
"neg_bert_input": np.zeros((numpassages, maxseqlen), dtype=np.long),
"neg_mask": np.zeros((numpassages, maxseqlen), dtype=np.long),
"neg_seg": np.zeros((numpassages, maxseqlen), dtype=np.long),
"label": np.repeat(np.array([label], dtype=np.float32), numpassages, 0),
}

if not negid:
return data

neg_bert_inputs, neg_bert_masks, neg_bert_segs = [], [], []
neg_passages = self._get_passages(negid)

for tokenized_passage in neg_passages:
inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
neg_bert_inputs.append(inp)
neg_bert_masks.append(mask)
neg_bert_segs.append(seg)

if not neg_bert_inputs:
raise MissingDocError(qid, negid)

data["negdocid"] = negid
data["neg_bert_input"] = np.array(neg_bert_inputs, dtype=np.long)
data["neg_mask"] = np.array(neg_bert_masks, dtype=np.long)
data["neg_seg"] = np.array(neg_bert_segs, dtype=np.long)
return data
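For contrast with MaxP's single sampled passage, Birch keeps every passage at training time (via MultipleTrainingPassagesMixin) and aggregates the top-k passage scores per document. A hedged PyTorch sketch of that aggregation (k and weighting are model choices; see the Birch paper for the exact scheme):

import torch

passage_scores = torch.rand(2, 8)              # [batch, num_passages] from a BERT reranker
topk_scores = passage_scores.topk(k=3, dim=1).values
doc_scores = topk_scores.sum(dim=1)            # Birch-style: combine the best k passages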
[Diffs for the remaining 17 changed files not shown.]