From fbe8bcf99862e92c0ad0833dca4f2bf022551a44 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 31 Oct 2023 11:01:36 +0100 Subject: [PATCH] Creating new bool tensor straight away instead of regular clone --- paper/src/ents_heads_extraction.py | 1 - .../relationextraction/multi2oie/other/bio.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/paper/src/ents_heads_extraction.py b/paper/src/ents_heads_extraction.py index 708c911..f3e05a8 100644 --- a/paper/src/ents_heads_extraction.py +++ b/paper/src/ents_heads_extraction.py @@ -1,7 +1,6 @@ """Pipeline for headwords/entities extractions and frequency count.""" import spacy -from relationextraction import SpacyRelationExtractor # noqa from conspiracies.docprocessing.headwordextraction import contains_ents diff --git a/src/conspiracies/docprocessing/relationextraction/multi2oie/other/bio.py b/src/conspiracies/docprocessing/relationextraction/multi2oie/other/bio.py index df9c037..98e6104 100644 --- a/src/conspiracies/docprocessing/relationextraction/multi2oie/other/bio.py +++ b/src/conspiracies/docprocessing/relationextraction/multi2oie/other/bio.py @@ -32,10 +32,10 @@ def get_pred_mask(tensor): where B is the batch size, L is the sequence length. :return: masked binary tensor with the same shape. """ - res = tensor.clone() - res[tensor == pred_tag2idx["O"]] = 1 - res[tensor != pred_tag2idx["O"]] = 0 - return res.bool() + res = tensor.bool() + res[tensor == pred_tag2idx["O"]] = True + res[tensor != pred_tag2idx["O"]] = False + return res def filter_pred_tags(pred_tags, tokens):