Skip to content

Commit

Permalink
Add custom vocab file for ce (PaddlePaddle#963)
Browse files: browse the repository at this point in the history
  • Loading branch information
FrostML authored Sep 7, 2021
1 parent 6ce057a commit 1a6729e
Show file tree
Hide file tree
Showing 14 changed files with 261 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ def parse_args():
default="./output/",
type=str,
help="The path to save logs when profile is enabled. ")
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -222,6 +244,10 @@ def do_inference(args):
args.inference_model_dir = ARGS.model_dir
args.test_file = ARGS.test_file
args.save_log_path = ARGS.save_log_path
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

if args.profile:
Expand Down
26 changes: 26 additions & 0 deletions examples/machine_translation/transformer/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ def parse_args():
action="store_true",
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
)
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -87,6 +109,10 @@ def do_export(args):
with open(yaml_file, 'rt') as f:
args = AttrDict(yaml.safe_load(f))
args.benchmark = ARGS.benchmark
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_export(args)
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ def parse_args():
action="store_true",
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
)
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -191,6 +213,10 @@ def do_predict(args):
if ARGS.batch_size:
args.infer_batch_size = ARGS.batch_size
args.test_file = ARGS.test_file
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_predict(args)
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ def parse_args():
action="store_true",
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
)
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -133,6 +155,10 @@ def do_predict(args):
args.topk = ARGS.topk
args.topp = ARGS.topp
args.benchmark = ARGS.benchmark
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_predict(args)
26 changes: 26 additions & 0 deletions examples/machine_translation/transformer/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ def parse_args():
"--without_ft",
action="store_true",
help="Whether to use Faster Transformer to do predict. ")
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -127,6 +149,10 @@ def do_predict(args):
args.benchmark = ARGS.benchmark
args.test_file = ARGS.test_file
args.without_ft = ARGS.without_ft
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_predict(args)
31 changes: 25 additions & 6 deletions examples/machine_translation/transformer/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ def create_data_loader(args, places=None):
raise ValueError(
"--train_file and --dev_file must be both or neither set. ")

if not args.benchmark:
if args.vocab_file is not None:
src_vocab = Vocab.load_vocabulary(
filepath=args.vocab_file,
unk_token=args.unk_token,
bos_token=args.bos_token,
eos_token=args.eos_token)
elif not args.benchmark:
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["bpe"])
else:
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["benchmark"])
Expand Down Expand Up @@ -109,7 +115,13 @@ def create_infer_loader(args):
else:
dataset = load_dataset('wmt14ende', splits=('test'))

if not args.benchmark:
if args.vocab_file is not None:
src_vocab = Vocab.load_vocabulary(
filepath=args.vocab_file,
unk_token=args.unk_token,
bos_token=args.bos_token,
eos_token=args.eos_token)
elif not args.benchmark:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
else:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
Expand Down Expand Up @@ -151,11 +163,18 @@ def convert_samples(sample):


def adapt_vocab_size(args):
dataset = load_dataset('wmt14ende', splits=('test'))
if not args.benchmark:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
if args.vocab_file is not None:
src_vocab = Vocab.load_vocabulary(
filepath=args.vocab_file,
unk_token=args.unk_token,
bos_token=args.bos_token,
eos_token=args.eos_token)
else:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
dataset = load_dataset('wmt14ende', splits=('test'))
if not args.benchmark:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
else:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
trg_vocab = src_vocab

padding_vocab = (
Expand Down
26 changes: 26 additions & 0 deletions examples/machine_translation/transformer/static/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ def parse_args():
type=str,
help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
)
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -146,6 +168,10 @@ def do_predict(args):
args = AttrDict(yaml.safe_load(f))
args.benchmark = ARGS.benchmark
args.test_file = ARGS.test_file
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_predict(args)
26 changes: 26 additions & 0 deletions examples/machine_translation/transformer/static/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ def parse_args():
type=str,
help="The files for validation, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to do validation. "
)
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -299,6 +321,10 @@ def do_train(args):
args.max_iter = ARGS.max_iter
args.train_file = ARGS.train_file
args.dev_file = ARGS.dev_file
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_train(args)
26 changes: 26 additions & 0 deletions examples/machine_translation/transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ def parse_args():
type=str,
help="The files for validation, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to do validation. "
)
parser.add_argument(
"--vocab_file",
default=None,
type=str,
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
)
parser.add_argument(
"--unk_token",
default=None,
type=str,
help="The unknown token. It should be provided when use custom vocab_file. "
)
parser.add_argument(
"--bos_token",
default=None,
type=str,
help="The bos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
"--eos_token",
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -270,6 +292,10 @@ def do_train(args):
args.max_iter = ARGS.max_iter
args.train_file = ARGS.train_file
args.dev_file = ARGS.dev_file
args.vocab_file = ARGS.vocab_file
args.unk_token = ARGS.unk_token
args.bos_token = ARGS.bos_token
args.eos_token = ARGS.eos_token
pprint(args)

do_train(args)
Loading

0 comments on commit 1a6729e

Please sign in to comment.