Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
crystina-z committed Jan 9, 2022
1 parent 3e0c979 commit dc21935
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
19 changes: 8 additions & 11 deletions capreolus/extractor/lce_bertpassage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,9 @@ def transpose_neg_input(neg_inp):
label = sample["label"]
features = []

assert nneg == len(sample["negdocid"]), f"Received number of negative examples does not match config (where nneg={nneg})."
# negdoc = tf.transpose(negdoc, perm=[1, 0, 2])
# negdoc = tf.cast(negdoc, tf.int64)
# negdoc_mask = tf.transpose(negdoc_mask, perm=[1, 0, 2])
# negdoc_mask = tf.cast(negdoc_mask, tf.int64)
# negdoc_seg = tf.transpose(negdoc_seg, perm=[1, 0, 2])
# negdoc_seg = tf.cast(negdoc_seg, tf.int64)
if nneg != len(sample["negdocid"]):
raise ValueError(f"Received number of negative examples does not match config (where nneg={nneg}).")

negdoc = transpose_neg_input(negdoc)
negdoc_seg = transpose_neg_input(negdoc_seg)
negdoc_mask = transpose_neg_input(negdoc_mask)
Expand Down Expand Up @@ -155,10 +151,11 @@ def id2vec(self, qid, posid, negids=None, label=None):
if negids is None:
return data

assert nneg == len(negids), (
f"Number of the given negative ids does not match nneg={nneg} as in {self.module_name}.config."
f"Are you sure nneg is set the same number in Sampler and {self.module_name}?"
)
if nneg != len(negids):
raise ValueError(
f"Number of the given negative ids does not match nneg={nneg} as in {self.module_name}.config. "
f"Are you sure nneg is set the same number in Sampler and {self.module_name}?"
)

data["negdocid"] = []
data["neg_bert_input"] = []
Expand Down
2 changes: 0 additions & 2 deletions capreolus/trainer/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ class TensorflowTrainer(Trainer):
ConfigOption("decayiters", 3),
ConfigOption("decaytype", None),
ConfigOption("amp", False, "use automatic mixed precision"),
ConfigOption("disableposition", False, "Whether to disable the positional embedding"),
ConfigOption("disablesegment", False, "Whether to disable the segment embedding"),
]
config_keys_not_in_path = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]

Expand Down

0 comments on commit dc21935

Please sign in to comment.