diff --git a/preprocessing/create_pretraining_data.py b/preprocessing/create_pretraining_data.py
index 225d2e7..a935215 100644
--- a/preprocessing/create_pretraining_data.py
+++ b/preprocessing/create_pretraining_data.py
@@ -76,13 +76,11 @@ def write_instance_to_example_file(instances, tokenizer, max_seq_length,
   for inst_index, instance in enumerate(tqdm(instances)):
     input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
     input_mask = [1] * len(input_ids)
-    segment_ids = list(instance.segment_ids)
     assert len(input_ids) <= max_seq_length
 
     while len(input_ids) < max_seq_length:
       input_ids.append(0)
       input_mask.append(0)
-      segment_ids.append(0)
 
     assert len(input_ids) == max_seq_length
     assert len(input_mask) == max_seq_length
@@ -125,8 +123,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
   for input_file in input_files:
     print("creating instance from {}".format(input_file))
     with open(input_file, "r") as reader:
-      while True:
-        line = tokenization.convert_to_unicode(reader.readline())
+      for dirty_line in reader:
+        line = tokenization.convert_to_unicode(dirty_line)
         line = line.strip()
         tokens = tokenizer.tokenize(line)
         all_sequences.append(tokens)
@@ -137,8 +135,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
   instances = []
   for _ in range(dupe_factor):
     for sequence_index in range(len(all_sequences)):
-      instances.extend(
-          create_instances_from_document(
+      instances.append(
+          create_instance_from_sequence(
              all_sequences, sequence_index, max_seq_length, masked_lm_prob,
              max_predictions_per_seq, vocab_residues, rng))
@@ -146,29 +144,24 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
   return instances
 
 
-def create_instances_from_document(
+def create_instance_from_sequence(
     all_sequences, sequence_index, max_seq_length, masked_lm_prob,
     max_predictions_per_seq, vocab_residues, rng):
   """Creates `TrainingInstance`s for a single document."""
   sequence = all_sequences[sequence_index]
+
+  tokens = sequence[:max_seq_length]
+  tokens[0] = '[CLS]'
+  tokens[-1] = '[SEP]'
+  (tokens, masked_lm_positions,
+   masked_lm_labels) = create_masked_lm_predictions(
+       tokens, masked_lm_prob, max_predictions_per_seq, vocab_residues, rng)
+  instance = TrainingInstance(
+      tokens=tokens,
+      masked_lm_positions=masked_lm_positions,
+      masked_lm_labels=masked_lm_labels)
 
-  instances = []
-  i = 0
-  while i < len(sequence):
-    tokens = sequence[:max_seq_length]
-    tokens[0] = '[CLS]'
-    tokens[-1] = '[SEP]'
-    (tokens, masked_lm_positions,
-     masked_lm_labels) = create_masked_lm_predictions(
-         tokens, masked_lm_prob, max_predictions_per_seq, vocab_residues, rng)
-    instance = TrainingInstance(
-        tokens=tokens,
-        masked_lm_positions=masked_lm_positions,
-        masked_lm_labels=masked_lm_labels)
-    instances.append(instance)
-    i += 1
-
-  return instances
+  return instance
 
 
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
@@ -295,7 +288,8 @@ def main():
   args = parser.parse_args()
 
-  tokenizer = ProteoNeMoTokenizer(args.vocab_file, args.small_vocab_file, do_upper_case=args.do_upper_case, max_len=512)
+  tokenizer = ProteoNeMoTokenizer(args.vocab_file, args.small_vocab_file, do_upper_case=args.do_upper_case,
+                                  max_len=args.max_seq_length)
 
   input_files = []
   if os.path.isfile(args.input_file):
diff --git a/preprocessing/tokenization.py b/preprocessing/tokenization.py
index c80de82..be35e11 100644
--- a/preprocessing/tokenization.py
+++ b/preprocessing/tokenization.py
@@ -48,7 +48,7 @@ def convert_to_unicode(text):
 def load_vocab(vocab_file):
   """Loads a vocabulary file into a dictionary."""
   vocab = collections.OrderedDict()
-  index = 0
+  index = 1
   with open(vocab_file, "r", encoding="utf-8") as reader:
     while True:
       token = reader.readline()
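
Taken together, the changes above emit one training instance per protein sequence (truncate to `max_seq_length`, overwrite the first and last positions with `[CLS]`/`[SEP]`) and appear to reserve token id 0 for padding, since `load_vocab` now starts at `index = 1` while `write_instance_to_example_file` pads `input_ids` and `input_mask` with zeros. The following is a minimal, self-contained sketch of that flow under those assumptions; `toy_load_vocab` and `toy_instance` are illustrative stand-ins, not functions from this repository, and the masked-LM prediction step is omitted.

```python
# Illustrative sketch only: toy stand-ins, not the repository's code.
import collections

MAX_SEQ_LENGTH = 8

def toy_load_vocab(tokens):
  # Mirrors the load_vocab change: ids start at 1, so id 0 stays free for the
  # zero-padding appended to input_ids in write_instance_to_example_file.
  vocab = collections.OrderedDict()
  for index, token in enumerate(tokens, start=1):
    vocab[token] = index
  return vocab

def toy_instance(sequence_tokens, max_seq_length=MAX_SEQ_LENGTH):
  # One instance per sequence: truncate, then overwrite the first and last
  # positions with the special tokens, as create_instance_from_sequence now
  # does (masked-LM prediction omitted here).
  tokens = list(sequence_tokens)[:max_seq_length]
  tokens[0] = '[CLS]'
  tokens[-1] = '[SEP]'
  return tokens

vocab = toy_load_vocab(['[CLS]', '[SEP]', 'A', 'C', 'D', 'E'])
tokens = toy_instance("ACDEA")
input_ids = [vocab[t] for t in tokens]
input_mask = [1] * len(input_ids)
while len(input_ids) < MAX_SEQ_LENGTH:  # same zero-padding as the script
  input_ids.append(0)
  input_mask.append(0)

print(tokens)      # ['[CLS]', 'C', 'D', 'E', '[SEP]']
print(input_ids)   # [1, 4, 5, 6, 2, 0, 0, 0]
print(input_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```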