diff --git a/capreolus/extractor/lce_bertpassage.py b/capreolus/extractor/lce_bertpassage.py index 7045f3ce..cc58159f 100644 --- a/capreolus/extractor/lce_bertpassage.py +++ b/capreolus/extractor/lce_bertpassage.py @@ -48,13 +48,9 @@ def transpose_neg_input(neg_inp): label = sample["label"] features = [] - assert nneg == len(sample["negdocid"]), f"Received number of negative examples does not match config (where nneg={nneg})." - # negdoc = tf.transpose(negdoc, perm=[1, 0, 2]) - # negdoc = tf.cast(negdoc, tf.int64) - # negdoc_mask = tf.transpose(negdoc_mask, perm=[1, 0, 2]) - # negdoc_mask = tf.cast(negdoc_mask, tf.int64) - # negdoc_seg = tf.transpose(negdoc_seg, perm=[1, 0, 2]) - # negdoc_seg = tf.cast(negdoc_seg, tf.int64) + if nneg != len(sample["negdocid"]): + raise ValueError(f"Received number of negative examples does not match config (where nneg={nneg}).") + negdoc = transpose_neg_input(negdoc) negdoc_seg = transpose_neg_input(negdoc_seg) negdoc_mask = transpose_neg_input(negdoc_mask) @@ -155,10 +151,11 @@ def id2vec(self, qid, posid, negids=None, label=None): if negids is None: return data - assert nneg == len(negids), ( - f"Number of the given negative ids does not match nneg={nneg} as in {self.module_name}.config." - f"Are you sure nneg is set the same number in Sampler and {self.module_name}?" - ) + if nneg != len(negids): + raise ValueError( + f"Number of the given negative ids does not match nneg={nneg} as in {self.module_name}.config. " + f"Are you sure nneg is set the same number in Sampler and {self.module_name}?" + ) data["negdocid"] = [] data["neg_bert_input"] = [] diff --git a/capreolus/trainer/tensorflow.py b/capreolus/trainer/tensorflow.py index c9a3347b..c39e1a70 100644 --- a/capreolus/trainer/tensorflow.py +++ b/capreolus/trainer/tensorflow.py @@ -92,8 +92,6 @@ class TensorflowTrainer(Trainer): ConfigOption("decayiters", 3), ConfigOption("decaytype", None), ConfigOption("amp", False, "use automatic mixed precision"), - ConfigOption("disableposition", False, "Whether to disable the positional embedding"), - ConfigOption("disablesegment", False, "Whether to disable the segment embedding"), ] config_keys_not_in_path = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]