Skip to content

Commit

Permalink
Transformer input id supports int32 & mv beam search v2 (PaddlePaddle…
Browse files Browse the repository at this point in the history
…#905)

* support int32

* mv beam search v2
  • Loading branch information
FrostML authored Aug 20, 2021
1 parent 2ae7a87 commit 4075ea6
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 220 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ random_seed: None
output_file: "predict.txt"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# The data type of input ids.
input_dtype: "int64"

# Device to use.
device: "gpu"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ random_seed: None
output_file: "predict.txt"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# The data type of input ids.
input_dtype: "int64"

# Device to use.
device: "gpu"
Expand Down
24 changes: 18 additions & 6 deletions examples/machine_translation/transformer/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def convert_samples(sample):
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.bos_idx,
pad_seq=args.pad_seq),
pad_seq=args.pad_seq,
dtype=args.input_dtype),
num_workers=0)
data_loaders[i] = (data_loader)
return data_loaders
Expand Down Expand Up @@ -142,7 +143,8 @@ def convert_samples(sample):
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.bos_idx,
pad_seq=args.pad_seq),
pad_seq=args.pad_seq,
dtype=args.input_dtype),
num_workers=0,
return_list=True)
return data_loader, trg_vocab.to_tokens
Expand All @@ -163,11 +165,16 @@ def adapt_vocab_size(args):
args.trg_vocab_size = padding_vocab(len(trg_vocab))


def prepare_train_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
def prepare_train_input(insts,
bos_idx,
eos_idx,
pad_idx,
pad_seq=1,
dtype="int64"):
"""
Put all padded data needed by training into a list.
"""
word_pad = Pad(pad_idx, dtype="int64")
word_pad = Pad(pad_idx, dtype=dtype)
src_max_len = (
max([len(inst[0]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
trg_max_len = (
Expand All @@ -190,11 +197,16 @@ def prepare_train_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
return data_inputs


def prepare_infer_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
def prepare_infer_input(insts,
bos_idx,
eos_idx,
pad_idx,
pad_seq=1,
dtype="int64"):
"""
Put all padded data needed by beam search decoder into a list.
"""
word_pad = Pad(pad_idx, dtype="int64")
word_pad = Pad(pad_idx, dtype=dtype)
src_max_len = (
max([len(inst[0]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
src_word = word_pad([
Expand Down
2 changes: 1 addition & 1 deletion examples/machine_translation/transformer/static/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def do_predict(args):
startup_program = paddle.static.Program()
with paddle.static.program_guard(test_program, startup_program):
src_word = paddle.static.data(
name="src_word", shape=[None, None], dtype="int64")
name="src_word", shape=[None, None], dtype=args.input_dtype)

# Define model
transformer = InferTransformerModel(
Expand Down
6 changes: 3 additions & 3 deletions examples/machine_translation/transformer/static/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def do_train(args):
startup_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
src_word = paddle.static.data(
name="src_word", shape=[None, None], dtype="int64")
name="src_word", shape=[None, None], dtype=args.input_dtype)
trg_word = paddle.static.data(
name="trg_word", shape=[None, None], dtype="int64")
name="trg_word", shape=[None, None], dtype=args.input_dtype)
lbl_word = paddle.static.data(
name="lbl_word", shape=[None, None, 1], dtype="int64")
name="lbl_word", shape=[None, None, 1], dtype=args.input_dtype)

# Define model
transformer = TransformerModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def forward(self, src_word):
src_word == self.bos_id,
dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
src_pos = paddle.cast(
src_word != self.bos_id, dtype="int64") * paddle.arange(
src_word != self.bos_id, dtype=src_word.dtype) * paddle.arange(
start=0, end=src_max_len)

# Run encoder
Expand Down
Loading

0 comments on commit 4075ea6

Please sign in to comment.