Skip to content

Commit

Permalink
Load data at the correct position when resuming from a checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s authored and Thai Chau Truong committed Mar 31, 2024
1 parent c570639 commit d2dc3f8
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 16 deletions.
26 changes: 24 additions & 2 deletions onmt/inputters/dynamic_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info={},
data_type="text",
bucket_size=2048,
bucket_size_init=-1,
Expand All @@ -144,6 +145,7 @@ def __init__(
self.transforms = transforms
self.vocabs = vocabs
self.corpora_info = corpora_info
self.resume_corpora_info = resume_corpora_info
self.task = task
self.init_iterators = False
self.batch_size = batch_size
Expand Down Expand Up @@ -171,7 +173,17 @@ def __init__(

@classmethod
def from_opt(
cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
cls,
corpora,
transforms,
vocabs,
opt,
task,
copy,
device,
resume_corpora_info={},
stride=1,
offset=0,
):
"""Initilize `DynamicDatasetIter` with options parsed from `opt`."""
corpora_info = {}
Expand Down Expand Up @@ -206,6 +218,7 @@ def from_opt(
opt.batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info=resume_corpora_info,
data_type=opt.data_type,
bucket_size=bucket_size,
bucket_size_init=bucket_size_init,
Expand Down Expand Up @@ -388,6 +401,7 @@ def build_dynamic_dataset_iter(
vocabs,
copy=False,
task=CorpusTask.TRAIN,
resume_corpora_info={},
stride=1,
offset=0,
src=None,
Expand All @@ -412,7 +426,14 @@ def build_dynamic_dataset_iter(
advance to avoid the GPU waiting during the refilling of the bucket.
"""
transforms = make_transforms(opt, transforms_cls, vocabs)
corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
corpora = get_corpora(
opt,
task,
src=src,
tgt=tgt,
align=align,
resume_corpora_info=resume_corpora_info,
)
if corpora is None:
assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
return None
Expand Down Expand Up @@ -442,6 +463,7 @@ def build_dynamic_dataset_iter(
vocabs,
opt,
task,
resume_corpora_info=resume_corpora_info,
copy=copy,
stride=stride,
offset=offset,
Expand Down
46 changes: 38 additions & 8 deletions onmt/inputters/text_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@ class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(
self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
self,
name,
src,
tgt,
align=None,
n_src_feats=0,
src_feats_defaults=None,
line_number_to_resume=0,
):
"""Initialize src & tgt side file path."""
self.id = name
Expand All @@ -108,6 +115,12 @@ def __init__(
self.align = align
self.n_src_feats = n_src_feats
self.src_feats_defaults = src_feats_defaults
self.line_number_to_resume = line_number_to_resume
self.can_read_file = False

def activate_reading_mode(self, line_number):
if line_number >= self.line_number_to_resume:
self.can_read_file = True

def load(self, offset=0, stride=1):
"""
Expand All @@ -116,7 +129,7 @@ def load(self, offset=0, stride=1):
`stride` example, starting from `offset`.
"""

def make_ex(sline, tline, align):
def make_ex(sline, tline, align, line_number):
sline, sfeats = parse_features(
sline,
n_feats=self.n_src_feats,
Expand All @@ -131,6 +144,7 @@ def make_ex(sline, tline, align):
"tgt": tline,
"src_original": sline,
"tgt_original": tline,
"cid_line_number": line_number,
}
if align is not None:
example["align"] = align
Expand All @@ -145,19 +159,25 @@ def make_ex(sline, tline, align):
for i, (sline, tline, align) in enumerate(
itertools.zip_longest(fs, ft, fa)
):
self.activate_reading_mode(line_index=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
yield make_ex(sline, tline, align)
yield make_ex(sline, tline, align, i)
else:
with exfile_open(self.src, mode="rb") as fs, exfile_open(
self.tgt, mode="rb"
) as ft, exfile_open(self.align, mode="rb") as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
self.activate_reading_mode(line_number=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
if tline is not None:
tline = tline.decode("utf-8")
if align is not None:
align = align.decode("utf-8")
yield make_ex(sline.decode("utf-8"), tline, align)
yield make_ex(sline.decode("utf-8"), tline, align, i)

def __str__(self):
cls_name = type(self).__name__
Expand All @@ -169,19 +189,25 @@ def __str__(self):
)


def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
def get_corpora(
opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None, resume_corpora_info={}
):
corpora_dict = {}
if task == CorpusTask.TRAIN:
for corpus_id, corpus_dict in opts.data.items():
if corpus_id != CorpusName.VALID:
if corpus_dict.get("path_txt", None) is None:
resume_line = 0
if corpus_id in resume_corpora_info:
resume_line = resume_corpora_info[corpus_id]
corpora_dict[corpus_id] = ParallelCorpus(
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults,
line_number_to_resume=resume_line,
)
else:
corpora_dict[corpus_id] = BlockwiseCorpus(
Expand Down Expand Up @@ -244,8 +270,6 @@ def _process(self, stream):
example["src_feats"] = [
feat.strip().split(" ") for feat in example["src_feats"]
]
line_number = i * self.stride + self.offset
example["cid_line_number"] = line_number
example["cid"] = self.cid
if "align" in example:
example["align"] = example["align"].strip().split(" ")
Expand All @@ -258,6 +282,7 @@ def _process(self, stream):
or ("align" in example and example["align"] == 0)
):
# empty example: skip
line_number = example["cid_line_number"]
empty_msg = f"Empty line in {self.cid}#{line_number}."
if self.skip_empty_level == "error":
raise IOError(empty_msg)
Expand All @@ -282,7 +307,12 @@ def __iter__(self):


def build_corpora_iters(
corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
corpora,
transforms,
corpora_info,
skip_empty_level="warning",
stride=1,
offset=0,
):
"""Return `ParallelCorpusIterator` for all corpora defined in opts."""
corpora_iters = dict()
Expand Down
92 changes: 90 additions & 2 deletions onmt/models/model_saver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
import torch
import re
import subprocess
from collections import deque
import onmt.utils
from onmt.utils.logging import logger
from onmt.inputters.inputter import vocabs_to_dict
from onmt.modules.lora import lora_state_dict


def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
def build_model_saver(
model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id
):
# _check_save_model_path
save_model_path = os.path.abspath(opt.save_model)
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
Expand All @@ -20,6 +24,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
optim,
opt.keep_checkpoint,
opt.save_format,
resume_corpora_info,
device_id,
)
return model_saver
Expand Down Expand Up @@ -81,6 +86,65 @@ def fix_key(s):
return checkpoint


def load_corpora_info(opts, checkpoint):
message_resume_from_beginning = (
"The training will resume from the beginning of each corpus."
)
# Check if resume_from_corpora is True
if not opts.resume_from_corpora:
logger.info(
"No resume from corpora is specified. " + message_resume_from_beginning
)
return {}

# Check if the corpus list from the last training
# and in the new training are identical.
checkpoint_corpora = checkpoint.get("corpus_info", None)
if checkpoint_corpora is None:
logger.info(
"Incoherent info: Some corpora in the last training "
+ "and in the new list do not match. "
+ message_resume_from_beginning
)
return {}

checkpoint_corpus_names = [name for name in checkpoint_corpora]
new_corpus_names = [name for name in opts.data]
if set(checkpoint_corpus_names) != set(new_corpus_names):
logger.info(
"Incoherent info: Some corpora in the last training "
+ "and in the new list do not match. "
+ message_resume_from_beginning
)
return {}

# For each corpus, check if the last line number to resume
# is smaller than or equal to the number of text lines.
message_incoherent_line_number = (
"Incoherent info: text line numbers "
+ "to resume in some corpora exceed their total numbers of lines. "
+ message_resume_from_beginning
)
for c_name in checkpoint_corpora:
number_of_text_lines = int(
subprocess.getoutput(
"wc -l " + opts.data[c_name]["path_src"] + " | awk '{print $1}'"
)
)
if checkpoint_corpora[c_name] > number_of_text_lines - 1:
logger.info(message_incoherent_line_number)
return {}

# To set the text lines to resume, we increase all text lines by 1
# (and return to the beginning if the end is reached)
checkpoint_corpora[c_name] = (
checkpoint_corpora[c_name] + 1
) % number_of_text_lines

logger.info("The training will resume from the saved text line in each corpus.")
return checkpoint_corpora


class ModelSaverBase(object):
"""Base class for model saving operations
Expand All @@ -98,6 +162,7 @@ def __init__(
optim,
keep_checkpoint=-1,
save_format="pytorch",
resume_corpora_info={},
device_id=0,
):
self.base_path = base_path
Expand All @@ -108,14 +173,35 @@ def __init__(
self.last_saved_step = None
self.keep_checkpoint = keep_checkpoint
self.save_format = save_format
self.corpus_info = resume_corpora_info
self.device_id = device_id

if keep_checkpoint > 0:
self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
if save_format == "safetensors":
self.model_queue = deque([], maxlen=keep_checkpoint)

def save(self, step, moving_average=None):
def update_corpus_info_from_batches(self, batches, distributed=False):
# Update the last text line of each corpus
if batches is not None:
# Gather corpus line numbers to save to checkpoints
batch_cids = sum([batch["cid"] for batch in batches], [])
batch_cid_line_numbers = sum(
[batch["cid_line_number"] for batch in batches], []
)
if distributed:
batch_cids = sum(onmt.utils.distributed.all_gather_list(batch_cids), [])
batch_cid_line_numbers = sum(
onmt.utils.distributed.all_gather_list(batch_cid_line_numbers), []
)
# Save the last processed line number of each corpus
new_corpus_info = {
c_name: cid_line_number
for c_name, cid_line_number in zip(batch_cids, batch_cid_line_numbers)
}
self.corpus_info = {**self.corpus_info, **new_corpus_info}

def save(self, step, moving_average=None, batches=None, distributed=False):
"""Main entry point for model saver
It wraps the `_save` method with checks and apply `keep_checkpoint`
Expand Down Expand Up @@ -266,6 +352,7 @@ def _save(self, step, model):
"vocab": vocabs_to_dict(self.vocabs),
"opt": self.model_opt,
"optim": self.optim.state_dict(),
"corpus_info": self.corpus_info,
}
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
Expand Down Expand Up @@ -355,6 +442,7 @@ def _st_save(self, step, model):
"vocab": vocabs_to_dict(self.vocabs),
"opt": self.model_opt,
"optim": self.optim.state_dict(),
"corpus_info": self.corpus_info,
}

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
Expand Down
7 changes: 7 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,13 @@ def _add_train_general_opts(parser):
help="If training from a checkpoint then this is the "
"path to the pretrained model's state_dict.",
)
group.add(
"--resume_from_corpora",
"-resume_from_corpora",
action="store_true",
help="If training from a checkpoint and this is set to True "
" then the data generator will resume from the last line of each corpora.",
)
group.add(
"--reset_optim",
"-reset_optim",
Expand Down
Loading

0 comments on commit d2dc3f8

Please sign in to comment.