diff --git a/forte_wrapper/allennlp/allennlp_processors.py b/forte_wrapper/allennlp/allennlp_processors.py
index 2e9f2d6..c49198b 100644
--- a/forte_wrapper/allennlp/allennlp_processors.py
+++ b/forte_wrapper/allennlp/allennlp_processors.py
@@ -76,8 +76,8 @@ def initialize(self, resources: Resources, configs: Config):
                            "that the entries of the same type as produced by "
                            "this processor will be overwritten if found.")
             if configs.allow_parallel_entries:
-                logger.warning('Both `overwrite_entries` (whether to overwrite '
-                               'the entries of the same type as produced by '
+                logger.warning('Both `overwrite_entries` (whether to overwrite'
+                               ' the entries of the same type as produced by '
                                'this processor) and '
                                '`allow_parallel_entries` (whether to allow '
                                'similar new entries when they already exist) '
@@ -85,8 +85,8 @@ def initialize(self, resources: Resources, configs: Config):
                                'will be deleted.')
         else:
             if not configs.allow_parallel_entries:
-                logger.warning('Both `overwrite_entries` (whether to overwrite '
-                               'the entries of the same type as produced by '
+                logger.warning('Both `overwrite_entries` (whether to overwrite'
+                               ' the entries of the same type as produced by '
                                'this processor) and '
                                '`allow_parallel_entries` (whether to allow '
                                'similar new entries when they already exist) '
@@ -133,19 +133,22 @@ def default_configs(cls):
     def _process(self, input_pack: DataPack):
         # handle existing entries
         self._process_existing_entries(input_pack)
-
-        for sentence in input_pack.get(Sentence):
-            result: Dict[str, List[str]] = {}
-            for key in self.predictor:
-                predicted_result = self.predictor[key].predict(  # type: ignore
-                    sentence=sentence.text)
+        sentences = [_ for _ in input_pack.get(Sentence)]
+        inputs = [{"sentence": s.text} for s in sentences]
+        results = {k: p.predict_batch_json(inputs)
+                   for k, p in self.predictor.items()}
+        for i in range(len(sentences)):
+            result = {}
+            for key in self.predictor.keys():
                 if key == 'srl':
-                    predicted_result = parse_allennlp_srl_results(
-                        predicted_result['verbs'])
-                result.update(predicted_result)
+                    result.update(
+                        parse_allennlp_srl_results(results[key][i]["verbs"])
+                    )
+                else:
+                    result.update(results[key][i])
             if "tokenize" in self.configs.processors:
                 # creating new tokens and dependencies
-                tokens = self._create_tokens(input_pack, sentence, result)
+                tokens = self._create_tokens(input_pack, sentences[i], result)
                 if "depparse" in self.configs.processors:
                     self._create_dependencies(input_pack, tokens, result)
                 if 'srl' in self.configs.processors:
@@ -159,8 +162,8 @@ def _process_existing_entries(self, input_pack):
             if not self.configs.overwrite_entries:
                 if not self.configs.allow_parallel_entries:
                     raise ProcessorConfigError(
-                        "Found existing entries, either `overwrite_entries` or "
-                        "`allow_parallel_entries` should be True")
+                        "Found existing entries, either `overwrite_entries` or"
+                        " `allow_parallel_entries` should be True")
             else:
                 # delete existing tokens and dependencies
                 for entry_type in (Token, Dependency):
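
Note: the core change in `_process` above swaps one `predict` call per sentence for a single batched `predict_batch_json` call per predictor, so all sentences in the DataPack go through each model in one batch. A minimal sketch of that pattern, assuming AllenNLP is installed; the archive path `model.tar.gz` is a placeholder, and any predictor that accepts `{"sentence": ...}` JSON inputs (e.g. the SRL or dependency predictors) behaves the same way:

    from allennlp.predictors.predictor import Predictor

    # Placeholder path; substitute a real AllenNLP model archive.
    predictor = Predictor.from_path("model.tar.gz")

    sentences = ["Forte wraps AllenNLP models.",
                 "Batching amortizes per-call overhead."]

    # Old approach: one forward pass per sentence.
    per_sentence = [predictor.predict_json({"sentence": s}) for s in sentences]

    # New approach: one batched call; results come back in input order,
    # which is what lets the diff index them as results[key][i].
    batched = predictor.predict_batch_json([{"sentence": s} for s in sentences])

    assert len(batched) == len(sentences)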