Skip to content

Commit

Permalink
Merge pull request #7 from DianLiI/improve-allen-efficiency
Browse files Browse the repository at this point in the history
Improve AllenNLPProcessor's performance
  • Loading branch information
hunterhector authored Apr 5, 2021
2 parents cda5876 + b87fca0 commit 9cb5724
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions forte_wrapper/allennlp/allennlp_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,17 @@ def initialize(self, resources: Resources, configs: Config):
"that the entries of the same type as produced by "
"this processor will be overwritten if found.")
if configs.allow_parallel_entries:
logger.warning('Both `overwrite_entries` (whether to overwrite '
'the entries of the same type as produced by '
logger.warning('Both `overwrite_entries` (whether to overwrite'
' the entries of the same type as produced by '
'this processor) and '
'`allow_parallel_entries` (whether to allow '
'similar new entries when they already exist) '
'are True, all existing conflicting entries '
'will be deleted.')
else:
if not configs.allow_parallel_entries:
logger.warning('Both `overwrite_entries` (whether to overwrite '
'the entries of the same type as produced by '
logger.warning('Both `overwrite_entries` (whether to overwrite'
' the entries of the same type as produced by '
'this processor) and '
'`allow_parallel_entries` (whether to allow '
'similar new entries when they already exist) '
Expand Down Expand Up @@ -133,19 +133,22 @@ def default_configs(cls):
def _process(self, input_pack: DataPack):
# handle existing entries
self._process_existing_entries(input_pack)

for sentence in input_pack.get(Sentence):
result: Dict[str, List[str]] = {}
for key in self.predictor:
predicted_result = self.predictor[key].predict( # type: ignore
sentence=sentence.text)
sentences = [_ for _ in input_pack.get(Sentence)]
inputs = [{"sentence": s.text} for s in sentences]
results = {k: p.predict_batch_json(inputs)
for k, p in self.predictor.items()}
for i in range(len(sentences)):
result = {}
for key in self.predictor.keys():
if key == 'srl':
predicted_result = parse_allennlp_srl_results(
predicted_result['verbs'])
result.update(predicted_result)
result.update(
parse_allennlp_srl_results(results[key][i]["verbs"])
)
else:
result.update(results[key][i])
if "tokenize" in self.configs.processors:
# creating new tokens and dependencies
tokens = self._create_tokens(input_pack, sentence, result)
tokens = self._create_tokens(input_pack, sentences[i], result)
if "depparse" in self.configs.processors:
self._create_dependencies(input_pack, tokens, result)
if 'srl' in self.configs.processors:
Expand All @@ -159,8 +162,8 @@ def _process_existing_entries(self, input_pack):
if not self.configs.overwrite_entries:
if not self.configs.allow_parallel_entries:
raise ProcessorConfigError(
"Found existing entries, either `overwrite_entries` or "
"`allow_parallel_entries` should be True")
"Found existing entries, either `overwrite_entries` or"
" `allow_parallel_entries` should be True")
else:
# delete existing tokens and dependencies
for entry_type in (Token, Dependency):
Expand Down

0 comments on commit 9cb5724

Please sign in to comment.