Skip to content

Commit

Permalink
black and flake my friends
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Mar 12, 2024
1 parent 2354c30 commit c07da3e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
17 changes: 10 additions & 7 deletions onmt/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,16 @@ def _repr_args(self):
def make_transforms(opts, transforms_cls, vocabs):
"""Build transforms in `transforms_cls` with vocab of `fields`."""
transforms = {}
for name, transform_cls in transforms_cls.items():
if transform_cls.require_vocab() and vocabs is None:
logger.warning(f"{transform_cls.__name__} require vocab to apply, skip it.")
continue
transform_obj = transform_cls(opts)
transform_obj.warm_up(vocabs)
transforms[name] = transform_obj
if transforms_cls:
for name, transform_cls in transforms_cls.items():
if transform_cls.require_vocab() and vocabs is None:
logger.warning(
f"{transform_cls.__name__} require vocab to apply, skip it."
)
continue
transform_obj = transform_cls(opts)
transform_obj.warm_up(vocabs)
transforms[name] = transform_obj
return transforms


Expand Down
6 changes: 4 additions & 2 deletions onmt/utils/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from onmt.opts import translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
from onmt.transforms import get_transforms_cls


class ScoringPreparator:
Expand All @@ -19,6 +19,8 @@ def __init__(self, vocabs, opt):
if self.opt.dump_preds is not None:
if not os.path.exists(self.opt.dump_preds):
os.makedirs(self.opt.dump_preds)
self.transforms = None
self.transforms_cls = None

def warm_up(self, transforms):
self.transforms = transforms
Expand Down Expand Up @@ -78,7 +80,7 @@ def translate(self, model, gpu_rank, step):

# Reinstantiate the validation iterator

#transforms_cls = get_transforms_cls(model_opt._all_transform)
# transforms_cls = get_transforms_cls(model_opt._all_transform)
model_opt.num_workers = 0
model_opt.tgt = None

Expand Down

0 comments on commit c07da3e

Please sign in to comment.