diff --git a/index.html b/index.html index b1f3299..db377fb 100755 --- a/index.html +++ b/index.html @@ -498,10 +498,10 @@
Red Coast (redco) is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.
-RedCoast supports Large Models + Complex Algorithms, in a lightweight and user-friendly manner:
+RedCoast supports Large Models + Complex Algorithms, in a lightweight and user-friendly way:
With RedCoast, to define a ML pipeline, only three functions are needed:
Red Coast (redco) is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.
RedCoast supports Large Models + Complex Algorithms, in a lightweight and user-friendly manner:
With RedCoast, to define a ML pipeline, only three functions are needed:
Redco automates all the rest of the pipeline execution, such as data and model parallelism, multi-host related processing, distributed checkpointing, randomness control, logging, etc.
"},{"location":"deployer/","title":"Deployer","text":""},{"location":"deployer/#redco.deployers.deployer.Deployer","title":"Deployer
","text":"Handles low-level operations to support Trainer and Predictor, e.g., automatic data/model parallelism, distributed checkpointing, data processing, logging, randomness controlling, etc.
Attributes:
workdir (str): Working directory for saving checkpoints and logs.
mesh (jax Mesh): Mesh used for model sharding.
Source code in redco/deployers/deployer.py
class Deployer:\n \"\"\" Handles low-level operations to support Trainer and Predictor,\n e.g., automatic data/model parallelism, distributed checkpointing,\n data processing, logging, randomness controlling, etc.\n\n Attributes:\n workdir (str): Working directory for saving checkpoints and logs.\n mesh (jax Mesh): Mesh used for model sharding.\n \"\"\"\n def __init__(self,\n jax_seed,\n n_model_shards=1,\n verbose=True,\n workdir=None,\n n_processes=None,\n host0_address=None,\n host0_port=None,\n process_id=None,\n n_local_devices=None,\n run_tensorboard=False,\n wandb_init_kwargs=None):\n \"\"\" Initializes a Deployer.\n\n Args:\n jax_seed (int): Seed for random number generation.\n n_model_shards (int): Number of shards for running large model.\n verbose (bool): Whether to enable verbose logging.\n workdir (str): Directory for saving logs and checkpoints.\n n_processes (int): For multi-host, number of processes/nodes.\n host0_address (str): For multi-host, address of the host0.\n host0_port (int): For multi-host, port of the host0.\n process_id (int): For multi-host, index of the current process.\n n_local_devices (int): For multi-host, number of local devices.\n run_tensorboard (bool): Whether to enable TensorBoard logging.\n wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n \"\"\"\n if n_processes is None:\n if 'SLURM_JOB_NUM_NODES' in os.environ:\n n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n process_id = int(os.environ['SLURM_NODEID'])\n else:\n n_processes = 1\n\n if n_processes > 1:\n local_device_ids = None if n_local_devices is None \\\n else list(range(n_local_devices))\n\n if host0_port is None:\n host0_port = DEFAULT_HOST0_PORT\n\n jax.distributed.initialize(\n coordinator_address=f'{host0_address}:{host0_port}',\n num_processes=n_processes,\n process_id=process_id,\n local_device_ids=local_device_ids)\n\n if workdir is not None:\n os.makedirs(workdir, exist_ok=True)\n\n self._verbose = verbose\n self._workdir = workdir\n self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n if wandb_init_kwargs is not None and jax.process_index() == 0:\n import wandb\n wandb.init(**wandb_init_kwargs)\n self._wandb_log_fn = wandb.log\n else:\n self._wandb_log_fn = None\n\n if run_tensorboard and jax.process_index() == 0:\n from flax.metrics import tensorboard\n self._summary_writer = tensorboard.SummaryWriter(workdir)\n else:\n self._summary_writer = None\n\n self.log_info(\n f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n self._rng = jax.random.PRNGKey(seed=jax_seed)\n self._mesh = get_mesh(n_model_shards=n_model_shards)\n self._checkpointer = ocp.PyTreeCheckpointer()\n\n def get_local_global_micro_batch_size(self, per_device_batch_size):\n \"\"\"Get local/global micro batch sizes based on per-device batch size.\"\"\"\n if self._mesh is None:\n local_micro_batch_size = \\\n per_device_batch_size * jax.local_device_count()\n global_micro_batch_size = \\\n local_micro_batch_size * jax.process_count()\n else:\n global_micro_batch_size = local_micro_batch_size = \\\n per_device_batch_size * self._mesh.shape['dp']\n\n return local_micro_batch_size, global_micro_batch_size\n\n def get_accumulate_grad_batches(\n self, global_batch_size, per_device_batch_size):\n \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n assert global_batch_size % global_micro_batch_size == 0\n accumulate_grad_batches = 
global_batch_size // global_micro_batch_size\n\n return accumulate_grad_batches\n\n def get_model_input_batches(self,\n examples,\n per_device_batch_size,\n collate_fn,\n shuffle,\n shuffle_rng,\n desc,\n is_train=False,\n accumulate_grad_batches=None):\n \"\"\"Prepares model input batches from examples.\n\n Args:\n examples (list): List of input examples.\n per_device_batch_size (int): Batch size per device.\n collate_fn (Callable): Function to collate the examples.\n shuffle (bool): Whether to shuffle the examples.\n shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n desc (str): Description in the progress bar.\n is_train (bool): Whether the data is for training.\n accumulate_grad_batches (int): gradient accumulation batches.\n\n Returns:\n (generator): A python generator of batched model inputs.\n \"\"\"\n local_micro_batch_size, global_micro_batch_size = \\\n self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n\n examples = get_host_examples(\n examples=examples,\n global_micro_batch_size=global_micro_batch_size,\n shuffle=shuffle,\n shuffle_rng=shuffle_rng,\n mesh=self._mesh)\n\n if not is_train:\n desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n elif accumulate_grad_batches is None:\n desc = \\\n f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n else:\n desc = (f'{desc} ('\n f'global_micro_batch_size = {global_micro_batch_size}, '\n f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n return get_data_batches(\n examples=examples,\n batch_size=local_micro_batch_size,\n collate_fn=collate_fn,\n mesh=self._mesh,\n desc=desc,\n verbose=self._verbose)\n\n def get_lr_schedule_fn(self,\n train_size,\n per_device_batch_size,\n n_epochs,\n learning_rate,\n schedule_type='linear',\n warmup_ratio=0.,\n warmup_steps=None,\n init_learning_rate=0.,\n end_learning_rate=0.):\n \"\"\"Creates a learning rate schedule function.\n\n Args:\n train_size (int): Number of training examples per epoch.\n per_device_batch_size (int): Batch size per device.\n n_epochs (int): Number of epochs.\n learning_rate (float): Peak learning rate.\n schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n warmup_ratio (float): Ratio of lr warmup.\n warmup_steps (int): Number of warmup steps.\n init_learning_rate (float): Initial learning rate before warmup.\n end_learning_rate (float): End learning rate for the schedule.\n\n Returns:\n (Callable): A lr schedule function, step -> learning rate.\n \"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n if warmup_steps is None:\n warmup_steps = int(total_train_steps * warmup_ratio)\n\n return get_lr_schedule_fn(\n schedule_type=schedule_type,\n total_train_steps=total_train_steps,\n warmup_steps=warmup_steps,\n init_learning_rate=init_learning_rate,\n learning_rate=learning_rate,\n end_learning_rate=end_learning_rate)\n\n def get_sharding_rules(self, params_shape_or_params):\n \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n if self._mesh is None:\n return None\n else:\n sharding_rules = get_sharding_rules(\n params_shape_or_params=params_shape_or_params,\n n_model_shards=self._mesh.shape['mp'])\n return sharding_rules\n\n def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n return get_params_spec(\n 
params_shape_or_params=params_shape_or_params,\n params_sharding_rules=params_sharding_rules)\n\n def get_opt_state_spec(\n self, params_shape_or_params, params_spec, optimizer):\n \"\"\"Get optimizer state specs\"\"\"\n return get_opt_state_spec(\n params_shape_or_params=params_shape_or_params,\n params_spec=params_spec,\n optimizer=optimizer)\n\n def shard_params(self, params, params_spec, desc='params'):\n \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n self.log_info(info=f'Sharding {desc} ...')\n return shard_params(\n mesh=self._mesh, params=params, params_spec=params_spec)\n\n def run_model_step(self, step_fn, input_args):\n \"\"\"Executes a model step function with the provided inputs.\"\"\"\n if self._mesh is None:\n return step_fn(*input_args)\n else:\n with self._mesh:\n return step_fn(*input_args)\n\n def gen_rng(self):\n \"\"\"Get a new random number generator key and update the random state.\"\"\"\n self._rng, new_rng = jax.random.split(self._rng)\n return new_rng\n\n def log_info(self, info, title=None, step=None):\n \"\"\"Logs a messages\"\"\"\n log_info(\n info=info,\n title=title,\n logger=self._logger,\n summary_writer=self._summary_writer,\n step=step)\n\n def log_metrics(self, metrics, step):\n \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n if self._summary_writer is not None:\n for metric_name, value in metrics.items():\n self._summary_writer.scalar(metric_name, value, step=step)\n\n if self._wandb_log_fn is not None:\n self._wandb_log_fn(metrics, step)\n\n def save_outputs(self, outputs, desc, step):\n \"\"\"Saves model outputs to workdir.\"\"\"\n if self._workdir is not None and jax.process_index() == 0:\n save_outputs(\n workdir=self._workdir,\n outputs=outputs,\n desc=desc,\n step=step,\n logger=self._logger,\n summary_writer=self._summary_writer)\n\n def save_ckpt(\n self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n \"\"\"Saves a checkpoint to the specified directory.\n\n Args:\n ckpt_dir (str): Directory to save the checkpoint.\n params (dict): Model parameters.\n opt_state (dict): Optimizer state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n **kwargs (dict): Additional information to be saved into\n info.json, e.g., current training step, epoch index, etc.\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n save_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n rng=self._rng,\n **kwargs)\n self.log_info(f'Ckpt saved into {ckpt_dir}')\n\n def load_params_shape(self, ckpt_dir):\n \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n return load_params_shape(ckpt_dir=ckpt_dir)\n\n def load_ckpt(self,\n ckpt_dir,\n params_sharding_rules=None,\n optimizer=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=False):\n \"\"\"Loads a checkpoint from the specified directory.\n\n Args:\n ckpt_dir (str): Directory of the checkpoint.\n params_sharding_rules (list[tuple]): Sharding rules for parameters.\n optimizer (optax optimizer): Optimizer for loading opt_state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n load_params (bool): Whether to load the parameters.\n load_opt_state (bool): Whether to load the optimizer state.\n update_rng (bool): if updating the random state of the deployer.\n\n Returns:\n (tuple): A tuple with the loaded checkpoint (in a dict with\n `\"params\"` 
and `\"opt_state\"`) and additional information (in a\n dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n specs = {}\n if self._mesh is not None:\n if params_sharding_rules is None:\n params_sharding_rules = self.get_sharding_rules(\n params_shape_or_params=params_shape)\n\n specs['params'] = self.get_params_spec(\n params_shape_or_params=params_shape,\n params_sharding_rules=params_sharding_rules)\n if optimizer is not None:\n specs['opt_state'] = self.get_opt_state_spec(\n params_shape_or_params=params_shape,\n params_spec=specs['params'],\n optimizer=optimizer)\n\n ckpt, info = load_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params_shape_or_params=params_shape,\n optimizer=optimizer,\n float_dtype=float_dtype,\n mesh=self._mesh,\n specs=specs,\n load_params=load_params,\n load_opt_state=load_opt_state)\n\n for key, value in info.items():\n if not update_rng and key == 'rng':\n continue\n self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n if update_rng:\n self._rng = info['rng']\n self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n return ckpt, info\n\n def load_last_ckpt(self,\n optimizer=None,\n params_sharding_rules=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=True):\n \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n See load_ckpt() for the explanation of arguments.\n \"\"\"\n try:\n last_ckpt_name = open(\n f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n except:\n self.log_info(\n f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n f'no ckpt loaded.')\n return None, None\n\n return self.load_ckpt(\n ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n optimizer=optimizer,\n float_dtype=float_dtype,\n params_sharding_rules=params_sharding_rules,\n load_params=load_params,\n load_opt_state=load_opt_state,\n update_rng=update_rng)\n\n @property\n def mesh(self):\n \"\"\"Returns the mesh for model sharding\"\"\"\n return self._mesh\n\n @property\n def workdir(self):\n \"\"\"Returns the work directory.\"\"\"\n return self._workdir\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.mesh","title":"mesh
property
","text":"Returns the mesh for model sharding
"},{"location":"deployer/#redco.deployers.deployer.Deployer.workdir","title":"workdir
property
","text":"Returns the work directory.
"},{"location":"deployer/#redco.deployers.deployer.Deployer.__init__","title":"__init__(jax_seed, n_model_shards=1, verbose=True, workdir=None, n_processes=None, host0_address=None, host0_port=None, process_id=None, n_local_devices=None, run_tensorboard=False, wandb_init_kwargs=None)
","text":"Initializes a Deployer.
Parameters:
jax_seed (int): Seed for random number generation. Required.
n_model_shards (int): Number of shards for running large model. Default: 1
verbose (bool): Whether to enable verbose logging. Default: True
workdir (str): Directory for saving logs and checkpoints. Default: None
n_processes (int): For multi-host, number of processes/nodes. Default: None
host0_address (str): For multi-host, address of the host0. Default: None
host0_port (int): For multi-host, port of the host0. Default: None
process_id (int): For multi-host, index of the current process. Default: None
n_local_devices (int): For multi-host, number of local devices. Default: None
run_tensorboard (bool): Whether to enable TensorBoard logging. Default: False
wandb_init_kwargs (dict): wandb.init arguments if using wandb. Default: None
Source code in redco/deployers/deployer.py
def __init__(self,\n jax_seed,\n n_model_shards=1,\n verbose=True,\n workdir=None,\n n_processes=None,\n host0_address=None,\n host0_port=None,\n process_id=None,\n n_local_devices=None,\n run_tensorboard=False,\n wandb_init_kwargs=None):\n \"\"\" Initializes a Deployer.\n\n Args:\n jax_seed (int): Seed for random number generation.\n n_model_shards (int): Number of shards for running large model.\n verbose (bool): Whether to enable verbose logging.\n workdir (str): Directory for saving logs and checkpoints.\n n_processes (int): For multi-host, number of processes/nodes.\n host0_address (str): For multi-host, address of the host0.\n host0_port (int): For multi-host, port of the host0.\n process_id (int): For multi-host, index of the current process.\n n_local_devices (int): For multi-host, number of local devices.\n run_tensorboard (bool): Whether to enable TensorBoard logging.\n wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n \"\"\"\n if n_processes is None:\n if 'SLURM_JOB_NUM_NODES' in os.environ:\n n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n process_id = int(os.environ['SLURM_NODEID'])\n else:\n n_processes = 1\n\n if n_processes > 1:\n local_device_ids = None if n_local_devices is None \\\n else list(range(n_local_devices))\n\n if host0_port is None:\n host0_port = DEFAULT_HOST0_PORT\n\n jax.distributed.initialize(\n coordinator_address=f'{host0_address}:{host0_port}',\n num_processes=n_processes,\n process_id=process_id,\n local_device_ids=local_device_ids)\n\n if workdir is not None:\n os.makedirs(workdir, exist_ok=True)\n\n self._verbose = verbose\n self._workdir = workdir\n self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n if wandb_init_kwargs is not None and jax.process_index() == 0:\n import wandb\n wandb.init(**wandb_init_kwargs)\n self._wandb_log_fn = wandb.log\n else:\n self._wandb_log_fn = None\n\n if run_tensorboard and jax.process_index() == 0:\n from flax.metrics import tensorboard\n self._summary_writer = tensorboard.SummaryWriter(workdir)\n else:\n self._summary_writer = None\n\n self.log_info(\n f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n self._rng = jax.random.PRNGKey(seed=jax_seed)\n self._mesh = get_mesh(n_model_shards=n_model_shards)\n self._checkpointer = ocp.PyTreeCheckpointer()\n
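For orientation, a minimal construction sketch (not part of the generated reference; argument values are illustrative and only use parameters documented above):

```python
# Minimal sketch: a single-host Deployer with 2-way model sharding.
from redco import Deployer

deployer = Deployer(
    jax_seed=42,            # seed for random number generation
    n_model_shards=2,       # shard large-model params across 2 devices
    workdir='./workdir',    # checkpoints and logs are saved here
    run_tensorboard=False)  # set True to also write TensorBoard logs
```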
"},{"location":"deployer/#redco.deployers.deployer.Deployer.gen_rng","title":"gen_rng()
","text":"Get a new random number generator key and update the random state.
Source code in redco/deployers/deployer.py
def gen_rng(self):\n \"\"\"Get a new random number generator key and update the random state.\"\"\"\n self._rng, new_rng = jax.random.split(self._rng)\n return new_rng\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_accumulate_grad_batches","title":"get_accumulate_grad_batches(global_batch_size, per_device_batch_size)
","text":"Calculates the number of gradient accumulation batches.
Source code in redco/deployers/deployer.py
def get_accumulate_grad_batches(\n self, global_batch_size, per_device_batch_size):\n \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n assert global_batch_size % global_micro_batch_size == 0\n accumulate_grad_batches = global_batch_size // global_micro_batch_size\n\n return accumulate_grad_batches\n
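A short numeric illustration of the division above (the batch sizes are hypothetical):

```python
# With a global micro batch size of 64 (derived from the per-device batch size
# and the data-parallel world size) and a target global batch size of 256,
# gradients are accumulated over 256 // 64 = 4 micro batches per optimizer step.
global_batch_size = 256
global_micro_batch_size = 64
assert global_batch_size % global_micro_batch_size == 0
accumulate_grad_batches = global_batch_size // global_micro_batch_size  # -> 4
```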
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_local_global_micro_batch_size","title":"get_local_global_micro_batch_size(per_device_batch_size)
","text":"Get local/global micro batch sizes based on per-device batch size.
Source code in redco/deployers/deployer.py
def get_local_global_micro_batch_size(self, per_device_batch_size):\n \"\"\"Get local/global micro batch sizes based on per-device batch size.\"\"\"\n if self._mesh is None:\n local_micro_batch_size = \\\n per_device_batch_size * jax.local_device_count()\n global_micro_batch_size = \\\n local_micro_batch_size * jax.process_count()\n else:\n global_micro_batch_size = local_micro_batch_size = \\\n per_device_batch_size * self._mesh.shape['dp']\n\n return local_micro_batch_size, global_micro_batch_size\n
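A worked example of the two branches above (device and host counts are hypothetical):

```python
# Without a mesh: sizes scale with the local device count and the process count,
# e.g. 8 local devices on each of 4 hosts.
per_device_batch_size = 2
local_micro_batch_size = per_device_batch_size * 8    # jax.local_device_count() -> 16
global_micro_batch_size = local_micro_batch_size * 4  # jax.process_count()      -> 64

# With a mesh: both sizes equal per_device_batch_size * mesh.shape['dp'].
```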
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_lr_schedule_fn","title":"get_lr_schedule_fn(train_size, per_device_batch_size, n_epochs, learning_rate, schedule_type='linear', warmup_ratio=0.0, warmup_steps=None, init_learning_rate=0.0, end_learning_rate=0.0)
","text":"Creates a learning rate schedule function.
Parameters:
train_size (int): Number of training examples per epoch. Required.
per_device_batch_size (int): Batch size per device. Required.
n_epochs (int): Number of epochs. Required.
learning_rate (float): Peak learning rate. Required.
schedule_type (str): Type of lr schedule, "linear" or "cosine". Default: 'linear'
warmup_ratio (float): Ratio of lr warmup. Default: 0.0
warmup_steps (int): Number of warmup steps. Default: None
init_learning_rate (float): Initial learning rate before warmup. Default: 0.0
end_learning_rate (float): End learning rate for the schedule. Default: 0.0
Returns:
(Callable): A lr schedule function, step -> learning rate.
Source code in redco/deployers/deployer.py
def get_lr_schedule_fn(self,\n train_size,\n per_device_batch_size,\n n_epochs,\n learning_rate,\n schedule_type='linear',\n warmup_ratio=0.,\n warmup_steps=None,\n init_learning_rate=0.,\n end_learning_rate=0.):\n \"\"\"Creates a learning rate schedule function.\n\n Args:\n train_size (int): Number of training examples per epoch.\n per_device_batch_size (int): Batch size per device.\n n_epochs (int): Number of epochs.\n learning_rate (float): Peak learning rate.\n schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n warmup_ratio (float): Ratio of lr warmup.\n warmup_steps (int): Number of warmup steps.\n init_learning_rate (float): Initial learning rate before warmup.\n end_learning_rate (float): End learning rate for the schedule.\n\n Returns:\n (Callable): A lr schedule function, step -> learning rate.\n \"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n if warmup_steps is None:\n warmup_steps = int(total_train_steps * warmup_ratio)\n\n return get_lr_schedule_fn(\n schedule_type=schedule_type,\n total_train_steps=total_train_steps,\n warmup_steps=warmup_steps,\n init_learning_rate=init_learning_rate,\n learning_rate=learning_rate,\n end_learning_rate=end_learning_rate)\n
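A usage sketch (hyperparameters are illustrative; `deployer` is assumed to be a Deployer instance as in the earlier sketch):

```python
import optax

# Linear schedule with 10% warmup for 3 epochs over 10,000 training examples.
lr_schedule_fn = deployer.get_lr_schedule_fn(
    train_size=10_000,
    per_device_batch_size=8,
    n_epochs=3,
    learning_rate=2e-5,
    schedule_type='linear',
    warmup_ratio=0.1)

# The returned step -> learning-rate function can be passed directly to optax.
optimizer = optax.adamw(learning_rate=lr_schedule_fn)
```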
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_model_input_batches","title":"get_model_input_batches(examples, per_device_batch_size, collate_fn, shuffle, shuffle_rng, desc, is_train=False, accumulate_grad_batches=None)
","text":"Prepares model input batches from examples.
Parameters:
examples (list): List of input examples. Required.
per_device_batch_size (int): Batch size per device. Required.
collate_fn (Callable): Function to collate the examples. Required.
shuffle (bool): Whether to shuffle the examples. Required.
shuffle_rng (jax.numpy.Array): RNG for randomness of shuffling. Required.
desc (str): Description in the progress bar. Required.
is_train (bool): Whether the data is for training. Default: False
accumulate_grad_batches (int): Gradient accumulation batches. Default: None
Returns:
(generator): A python generator of batched model inputs.
Source code in redco/deployers/deployer.py
def get_model_input_batches(self,\n examples,\n per_device_batch_size,\n collate_fn,\n shuffle,\n shuffle_rng,\n desc,\n is_train=False,\n accumulate_grad_batches=None):\n \"\"\"Prepares model input batches from examples.\n\n Args:\n examples (list): List of input examples.\n per_device_batch_size (int): Batch size per device.\n collate_fn (Callable): Function to collate the examples.\n shuffle (bool): Whether to shuffle the examples.\n shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n desc (str): Description in the progress bar.\n is_train (bool): Whether the data is for training.\n accumulate_grad_batches (int): gradient accumulation batches.\n\n Returns:\n (generator): A python generator of batched model inputs.\n \"\"\"\n local_micro_batch_size, global_micro_batch_size = \\\n self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n\n examples = get_host_examples(\n examples=examples,\n global_micro_batch_size=global_micro_batch_size,\n shuffle=shuffle,\n shuffle_rng=shuffle_rng,\n mesh=self._mesh)\n\n if not is_train:\n desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n elif accumulate_grad_batches is None:\n desc = \\\n f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n else:\n desc = (f'{desc} ('\n f'global_micro_batch_size = {global_micro_batch_size}, '\n f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n return get_data_batches(\n examples=examples,\n batch_size=local_micro_batch_size,\n collate_fn=collate_fn,\n mesh=self._mesh,\n desc=desc,\n verbose=self._verbose)\n
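A sketch of driving this method directly (here `eval_examples` and `collate_fn` are user-provided, in the same spirit as the MNIST example below):

```python
# Iterate over evaluation batches; each batch is collated and, when a mesh is
# used, placed across hosts/devices automatically.
batches = deployer.get_model_input_batches(
    examples=eval_examples,
    per_device_batch_size=8,
    collate_fn=collate_fn,
    shuffle=False,
    shuffle_rng=None,
    desc='Evaluating')
for batch in batches:
    ...  # run an eval or predict step on each batch
```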
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_opt_state_spec","title":"get_opt_state_spec(params_shape_or_params, params_spec, optimizer)
","text":"Get optimizer state specs
Source code in redco/deployers/deployer.py
def get_opt_state_spec(\n self, params_shape_or_params, params_spec, optimizer):\n \"\"\"Get optimizer state specs\"\"\"\n return get_opt_state_spec(\n params_shape_or_params=params_shape_or_params,\n params_spec=params_spec,\n optimizer=optimizer)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_params_spec","title":"get_params_spec(params_shape_or_params, params_sharding_rules)
","text":"Generates parameter specs based on sharding rules.
Source code in redco/deployers/deployer.py
def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n return get_params_spec(\n params_shape_or_params=params_shape_or_params,\n params_sharding_rules=params_sharding_rules)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_sharding_rules","title":"get_sharding_rules(params_shape_or_params)
","text":"Get sharding rules based on the parameter shapes.
Source code in redco/deployers/deployer.py
def get_sharding_rules(self, params_shape_or_params):\n \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n if self._mesh is None:\n return None\n else:\n sharding_rules = get_sharding_rules(\n params_shape_or_params=params_shape_or_params,\n n_model_shards=self._mesh.shape['mp'])\n return sharding_rules\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.load_ckpt","title":"load_ckpt(ckpt_dir, params_sharding_rules=None, optimizer=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=False)
","text":"Loads a checkpoint from the specified directory.
Parameters:
ckpt_dir (str): Directory of the checkpoint. Required.
params_sharding_rules (list[tuple]): Sharding rules for parameters. Default: None
optimizer (optax optimizer): Optimizer for loading opt_state. Default: None
float_dtype (jax.numpy.dtype): Dtype for floating point numbers. Default: None
load_params (bool): Whether to load the parameters. Default: True
load_opt_state (bool): Whether to load the optimizer state. Default: True
update_rng (bool): Whether to update the random state of the deployer. Default: False
Returns:
(tuple): A tuple with the loaded checkpoint (in a dict with "params" and "opt_state") and additional information (in a dict, usually including "steps", "epoch_idx", and "rng").
Source code in redco/deployers/deployer.py
def load_ckpt(self,\n ckpt_dir,\n params_sharding_rules=None,\n optimizer=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=False):\n \"\"\"Loads a checkpoint from the specified directory.\n\n Args:\n ckpt_dir (str): Directory of the checkpoint.\n params_sharding_rules (list[tuple]): Sharding rules for parameters.\n optimizer (optax optimizer): Optimizer for loading opt_state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n load_params (bool): Whether to load the parameters.\n load_opt_state (bool): Whether to load the optimizer state.\n update_rng (bool): if updating the random state of the deployer.\n\n Returns:\n (tuple): A tuple with the loaded checkpoint (in a dict with\n `\"params\"` and `\"opt_state\"`) and additional information (in a\n dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n specs = {}\n if self._mesh is not None:\n if params_sharding_rules is None:\n params_sharding_rules = self.get_sharding_rules(\n params_shape_or_params=params_shape)\n\n specs['params'] = self.get_params_spec(\n params_shape_or_params=params_shape,\n params_sharding_rules=params_sharding_rules)\n if optimizer is not None:\n specs['opt_state'] = self.get_opt_state_spec(\n params_shape_or_params=params_shape,\n params_spec=specs['params'],\n optimizer=optimizer)\n\n ckpt, info = load_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params_shape_or_params=params_shape,\n optimizer=optimizer,\n float_dtype=float_dtype,\n mesh=self._mesh,\n specs=specs,\n load_params=load_params,\n load_opt_state=load_opt_state)\n\n for key, value in info.items():\n if not update_rng and key == 'rng':\n continue\n self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n if update_rng:\n self._rng = info['rng']\n self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n return ckpt, info\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.load_last_ckpt","title":"load_last_ckpt(optimizer=None, params_sharding_rules=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=True)
","text":"Loads the last checkpoint from the work directory (self.workdir). See load_ckpt() for the explanation of arguments.
Source code in redco/deployers/deployer.py
def load_last_ckpt(self,\n optimizer=None,\n params_sharding_rules=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=True):\n \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n See load_ckpt() for the explanation of arguments.\n \"\"\"\n try:\n last_ckpt_name = open(\n f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n except:\n self.log_info(\n f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n f'no ckpt loaded.')\n return None, None\n\n return self.load_ckpt(\n ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n optimizer=optimizer,\n float_dtype=float_dtype,\n params_sharding_rules=params_sharding_rules,\n load_params=load_params,\n load_opt_state=load_opt_state,\n update_rng=update_rng)\n
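A resumption sketch under these semantics (assumes `deployer` and `optimizer` were set up as in the earlier sketches):

```python
# Returns (None, None) when workdir/ckpts/last_ckpt.txt does not exist yet.
last_ckpt, last_ckpt_info = deployer.load_last_ckpt(optimizer=optimizer)
if last_ckpt is not None:
    params = last_ckpt['params']        # restored parameters
    opt_state = last_ckpt['opt_state']  # restored optimizer state
    # last_ckpt_info usually carries the saved step and epoch index for resuming.
```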
"},{"location":"deployer/#redco.deployers.deployer.Deployer.load_params_shape","title":"load_params_shape(ckpt_dir)
","text":"Loads the shape of the parameters from a checkpoint.
Source code in redco/deployers/deployer.py
def load_params_shape(self, ckpt_dir):\n \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n return load_params_shape(ckpt_dir=ckpt_dir)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.log_info","title":"log_info(info, title=None, step=None)
","text":"Logs a messages
Source code in redco/deployers/deployer.py
def log_info(self, info, title=None, step=None):\n \"\"\"Logs a messages\"\"\"\n log_info(\n info=info,\n title=title,\n logger=self._logger,\n summary_writer=self._summary_writer,\n step=step)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.log_metrics","title":"log_metrics(metrics, step)
","text":"Logs metrics to TensorBoard and Weights and Biases (wandb).
Source code in redco/deployers/deployer.py
def log_metrics(self, metrics, step):\n \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n if self._summary_writer is not None:\n for metric_name, value in metrics.items():\n self._summary_writer.scalar(metric_name, value, step=step)\n\n if self._wandb_log_fn is not None:\n self._wandb_log_fn(metrics, step)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.run_model_step","title":"run_model_step(step_fn, input_args)
","text":"Executes a model step function with the provided inputs.
Source code in redco/deployers/deployer.py
def run_model_step(self, step_fn, input_args):\n \"\"\"Executes a model step function with the provided inputs.\"\"\"\n if self._mesh is None:\n return step_fn(*input_args)\n else:\n with self._mesh:\n return step_fn(*input_args)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.save_ckpt","title":"save_ckpt(ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs)
","text":"Saves a checkpoint to the specified directory.
Parameters:
ckpt_dir (str): Directory to save the checkpoint. Required.
params (dict): Model parameters. Required.
opt_state (dict): Optimizer state. Default: None
float_dtype (jax.numpy.dtype): Dtype for floating point numbers. Default: None
**kwargs (dict): Additional information to be saved into info.json, e.g., current training step, epoch index, etc. Default: {}
Source code in redco/deployers/deployer.py
def save_ckpt(\n self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n \"\"\"Saves a checkpoint to the specified directory.\n\n Args:\n ckpt_dir (str): Directory to save the checkpoint.\n params (dict): Model parameters.\n opt_state (dict): Optimizer state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n **kwargs (dict): Additional information to be saved into\n info.json, e.g., current training step, epoch index, etc.\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n save_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n rng=self._rng,\n **kwargs)\n self.log_info(f'Ckpt saved into {ckpt_dir}')\n
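A matching save sketch (the directory name is hypothetical; the extra `step` keyword is stored in info.json, per the docstring above):

```python
# Save params (and optionally opt_state) plus bookkeeping info.
deployer.save_ckpt(
    ckpt_dir='./workdir/ckpts/my_ckpt',
    params=params,
    opt_state=opt_state,
    step=1000)  # extra kwargs like this land in info.json
```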
"},{"location":"deployer/#redco.deployers.deployer.Deployer.save_outputs","title":"save_outputs(outputs, desc, step)
","text":"Saves model outputs to workdir.
Source code in redco/deployers/deployer.py
def save_outputs(self, outputs, desc, step):\n \"\"\"Saves model outputs to workdir.\"\"\"\n if self._workdir is not None and jax.process_index() == 0:\n save_outputs(\n workdir=self._workdir,\n outputs=outputs,\n desc=desc,\n step=step,\n logger=self._logger,\n summary_writer=self._summary_writer)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.shard_params","title":"shard_params(params, params_spec, desc='params')
","text":"Distributes parameters to all devices based on the provided specs.
Source code in redco/deployers/deployer.py
def shard_params(self, params, params_spec, desc='params'):\n \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n self.log_info(info=f'Sharding {desc} ...')\n return shard_params(\n mesh=self._mesh, params=params, params_spec=params_spec)\n
"},{"location":"mnist/","title":"MNIST Example","text":"This is a trivial MNIST example with RedCoast. Runnable by
python main.py\n
To simulate multiple devices in CPU-only environments,
XLA_FLAGS=\"--xla_force_host_platform_device_count=8\" python main.py\n
"},{"location":"mnist/#source-code","title":"Source Code","text":"from functools import partial\nimport fire\nimport numpy as np\nfrom flax import linen as nn\nimport optax\nfrom torchvision.datasets import MNIST\nfrom redco import Deployer, Trainer, Predictor\n\n\n# A simple CNN model \n# Copied from https://github.com/google/flax/blob/main/examples/mnist/train.py\nclass CNN(nn.Module):\n @nn.compact\n def __call__(self, x):\n x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n x = nn.relu(x)\n x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n x = nn.relu(x)\n x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n x = x.reshape((x.shape[0], -1)) # flatten\n x = nn.Dense(features=256)(x)\n x = nn.relu(x)\n x = nn.Dense(features=10)(x)\n return x\n\n\n# Collate function converting a batch of raw examples to model inputs (in numpy) \ndef collate_fn(examples):\n images = np.stack(\n [np.array(example['image'])[:, :, None] for example in examples])\n labels = np.array([example['label'] for example in examples])\n\n return {'images': images, 'labels': labels}\n\n\n# Loss function converting model inputs to a scalar loss\ndef loss_fn(train_rng, state, params, batch, is_training):\n logits = state.apply_fn({'params': params}, batch['images'])\n return optax.softmax_cross_entropy_with_integer_labels(\n logits=logits, labels=batch['labels']).mean()\n\n\n# Predict function converting model inputs to the model outputs\ndef pred_fn(pred_rng, params, batch, model):\n accs = model.apply({'params': params}, batch['images']).argmax(axis=-1)\n return {'acc': accs}\n\n\n# (Optional) Evaluation function in trainer.fit. Here it computes accuracy.\ndef eval_metric_fn(examples, preds):\n preds = np.array([pred['acc'] for pred in preds])\n labels = np.array([example['label'] for example in examples])\n return {'acc': np.mean(preds == labels).item()}\n\n\ndef main(per_device_batch_size=64, learning_rate=1e-3, jax_seed=42):\n deployer = Deployer(jax_seed=jax_seed, workdir='./workdir')\n\n dataset = {\n 'train': [{'image': t[0], 'label': t[1]} for t in list(\n MNIST('./data', train=True, download=True))],\n 'test': [{'image': t[0], 'label': t[1]} for t in list(\n MNIST('./data', train=False, download=True))],\n }\n\n model = CNN()\n dummy_batch = collate_fn(examples=[dataset['train'][0]])\n params = model.init(deployer.gen_rng(), dummy_batch['images'])['params']\n\n trainer = Trainer(\n deployer=deployer,\n collate_fn=collate_fn,\n apply_fn=model.apply,\n loss_fn=loss_fn,\n params=params,\n optimizer=optax.adamw(learning_rate=learning_rate))\n\n predictor = Predictor(\n deployer=deployer,\n collate_fn=collate_fn,\n pred_fn=partial(pred_fn, model=model))\n\n trainer.fit(\n train_examples=dataset['train'],\n per_device_batch_size=per_device_batch_size,\n n_epochs=2,\n eval_examples=dataset['test'],\n eval_predictor=predictor,\n eval_metric_fn=eval_metric_fn)\n\n\nif __name__ == '__main__':\n fire.Fire(main)\n
"},{"location":"predictor/","title":"Predictor","text":""},{"location":"predictor/#redco.predictors.predictor.Predictor","title":"Predictor
","text":"Predictor class managing distributed inference process.
Attributes:
mesh (jax Mesh): Mesh used for distributed inference.
Source code in redco/predictors/predictor.py
class Predictor:\n \"\"\"Predictor class managing distributed inference process.\n\n Attributes:\n mesh (jax Mesh): Mesh used for distributed inference.\n \"\"\"\n def __init__(self,\n deployer,\n collate_fn,\n pred_fn,\n output_fn=None,\n params_sharding_rules=None):\n \"\"\"Initializes a Predictor instance.\n\n Args:\n deployer (Deployer): A deployer for low-level operations.\n collate_fn (Callable): A function making model inputs from raw data,\n e.g., tokenizing sentences into input_ids.\n pred_fn (Callable): A function producing model outputs from inputs,\n e.g., running beam search with a language model.\n output_fn (Callable): A function post-processing model outputs,\n e.g., decoding generated ids to text.\n params_sharding_rules (list[tuple]): Rules for sharding parameters.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n self._params_sharding_rules = params_sharding_rules\n self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n self._p_pred_step = None\n\n if output_fn is None:\n self._output_fn = default_output_fn\n else:\n self._output_fn = output_fn\n\n def setup_running_step(self, dummy_batch, params_shape_or_params):\n \"\"\"Sets up the prediction step function for distributed inference.\n\n Args:\n dummy_batch (PyTree): A dummy batch used to determine data shapes.\n params_shape_or_params (dict): The shape of params or actual params.\n \"\"\"\n pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n if self.mesh is None:\n self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params_shape_or_params,\n params_sharding_rules=self._params_sharding_rules)\n self._p_pred_step = pjit(\n pred_step_fn,\n in_shardings=(None, params_spec, data_spec),\n out_shardings=None)\n\n def predict(self,\n examples,\n per_device_batch_size,\n params,\n params_replicated=False,\n params_sharded=False,\n desc=None):\n \"\"\"Runs distributed prediction on a list of examples.\n\n Args:\n examples (list): Input examples for prediction.\n per_device_batch_size (int): Batch size per device.\n params (dict): Model parameters in a dict/FrozenDict.\n params_replicated (bool): if the params are already replicated.\n params_sharded (bool): if the parameters are already sharded.\n desc (str): Description to show in the progress bar.\n\n Returns:\n (list): A list of predictions corresponding to the input examples.\n \"\"\"\n raw_n_inputs = len(examples)\n _, global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n examples = examples + [examples[0]] * (global_micro_batch_size - 1)\n examples = add_idxes(examples=examples)\n\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n params = freeze(params)\n if (self.mesh is None) and (not params_replicated):\n params = replicate(params)\n if (self.mesh is not None) and (not params_sharded):\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n preds = []\n for batch in data_batches:\n if 
self._p_pred_step is None:\n self.setup_running_step(\n dummy_batch=batch, params_shape_or_params=params)\n\n pred_rng = self._deployer.gen_rng()\n if self.mesh is None:\n pred_rng = jax.random.split(\n pred_rng, num=jax.process_count())[jax.process_index()]\n pred_rng = shard_prng_key(pred_rng)\n\n batch_preds_with_idxes = self._deployer.run_model_step(\n step_fn=self._p_pred_step,\n input_args=(pred_rng, params, batch))\n batch_preds = process_batch_preds(\n batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n batch_preds = self._output_fn(batch_preds)\n\n assert isinstance(batch_preds, list) and \\\n len(batch_preds) == global_micro_batch_size\n preds.extend(batch_preds)\n\n return preds[:raw_n_inputs]\n\n @property\n def mesh(self):\n \"\"\"Returns the mesh used for distributed inference.\"\"\"\n return self._deployer.mesh\n
"},{"location":"predictor/#redco.predictors.predictor.Predictor.mesh","title":"mesh
property
","text":"Returns the mesh used for distributed inference.
"},{"location":"predictor/#redco.predictors.predictor.Predictor.__init__","title":"__init__(deployer, collate_fn, pred_fn, output_fn=None, params_sharding_rules=None)
","text":"Initializes a Predictor instance.
Parameters:
deployer (Deployer): A deployer for low-level operations. Required.
collate_fn (Callable): A function making model inputs from raw data, e.g., tokenizing sentences into input_ids. Required.
pred_fn (Callable): A function producing model outputs from inputs, e.g., running beam search with a language model. Required.
output_fn (Callable): A function post-processing model outputs, e.g., decoding generated ids to text. Default: None
params_sharding_rules (list[tuple]): Rules for sharding parameters. Default: None
Source code in redco/predictors/predictor.py
def __init__(self,\n deployer,\n collate_fn,\n pred_fn,\n output_fn=None,\n params_sharding_rules=None):\n \"\"\"Initializes a Predictor instance.\n\n Args:\n deployer (Deployer): A deployer for low-level operations.\n collate_fn (Callable): A function making model inputs from raw data,\n e.g., tokenizing sentences into input_ids.\n pred_fn (Callable): A function producing model outputs from inputs,\n e.g., running beam search with a language model.\n output_fn (Callable): A function post-processing model outputs,\n e.g., decoding generated ids to text.\n params_sharding_rules (list[tuple]): Rules for sharding parameters.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n self._params_sharding_rules = params_sharding_rules\n self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n self._p_pred_step = None\n\n if output_fn is None:\n self._output_fn = default_output_fn\n else:\n self._output_fn = output_fn\n
"},{"location":"predictor/#redco.predictors.predictor.Predictor.predict","title":"predict(examples, per_device_batch_size, params, params_replicated=False, params_sharded=False, desc=None)
","text":"Runs distributed prediction on a list of examples.
Parameters:
examples (list): Input examples for prediction. Required.
per_device_batch_size (int): Batch size per device. Required.
params (dict): Model parameters in a dict/FrozenDict. Required.
params_replicated (bool): Whether the params are already replicated. Default: False
params_sharded (bool): Whether the parameters are already sharded. Default: False
desc (str): Description to show in the progress bar. Default: None
Returns:
(list): A list of predictions corresponding to the input examples.
Source code in redco/predictors/predictor.py
def predict(self,\n examples,\n per_device_batch_size,\n params,\n params_replicated=False,\n params_sharded=False,\n desc=None):\n \"\"\"Runs distributed prediction on a list of examples.\n\n Args:\n examples (list): Input examples for prediction.\n per_device_batch_size (int): Batch size per device.\n params (dict): Model parameters in a dict/FrozenDict.\n params_replicated (bool): if the params are already replicated.\n params_sharded (bool): if the parameters are already sharded.\n desc (str): Description to show in the progress bar.\n\n Returns:\n (list): A list of predictions corresponding to the input examples.\n \"\"\"\n raw_n_inputs = len(examples)\n _, global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n examples = examples + [examples[0]] * (global_micro_batch_size - 1)\n examples = add_idxes(examples=examples)\n\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n params = freeze(params)\n if (self.mesh is None) and (not params_replicated):\n params = replicate(params)\n if (self.mesh is not None) and (not params_sharded):\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n preds = []\n for batch in data_batches:\n if self._p_pred_step is None:\n self.setup_running_step(\n dummy_batch=batch, params_shape_or_params=params)\n\n pred_rng = self._deployer.gen_rng()\n if self.mesh is None:\n pred_rng = jax.random.split(\n pred_rng, num=jax.process_count())[jax.process_index()]\n pred_rng = shard_prng_key(pred_rng)\n\n batch_preds_with_idxes = self._deployer.run_model_step(\n step_fn=self._p_pred_step,\n input_args=(pred_rng, params, batch))\n batch_preds = process_batch_preds(\n batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n batch_preds = self._output_fn(batch_preds)\n\n assert isinstance(batch_preds, list) and \\\n len(batch_preds) == global_micro_batch_size\n preds.extend(batch_preds)\n\n return preds[:raw_n_inputs]\n
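A usage sketch tying the pieces together (`collate_fn`, `pred_fn`, `params`, and `test_examples` are user-provided, as in the MNIST example):

```python
from redco import Predictor

predictor = Predictor(
    deployer=deployer,
    collate_fn=collate_fn,
    pred_fn=pred_fn)

# Padding of the final batch and restoring the original example order are
# handled internally; one prediction is returned per input example.
preds = predictor.predict(
    examples=test_examples,
    per_device_batch_size=16,
    params=params)
```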
"},{"location":"predictor/#redco.predictors.predictor.Predictor.setup_running_step","title":"setup_running_step(dummy_batch, params_shape_or_params)
","text":"Sets up the prediction step function for distributed inference.
Parameters:
dummy_batch (PyTree): A dummy batch used to determine data shapes. Required.
params_shape_or_params (dict): The shape of params or actual params. Required.
Source code in redco/predictors/predictor.py
def setup_running_step(self, dummy_batch, params_shape_or_params):\n \"\"\"Sets up the prediction step function for distributed inference.\n\n Args:\n dummy_batch (PyTree): A dummy batch used to determine data shapes.\n params_shape_or_params (dict): The shape of params or actual params.\n \"\"\"\n pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n if self.mesh is None:\n self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params_shape_or_params,\n params_sharding_rules=self._params_sharding_rules)\n self._p_pred_step = pjit(\n pred_step_fn,\n in_shardings=(None, params_spec, data_spec),\n out_shardings=None)\n
"},{"location":"trainer/","title":"Trainer","text":""},{"location":"trainer/#redco.trainers.trainer.Trainer","title":"Trainer
","text":"Trainer class managing distributed training process.
Attributes:
step (int): Current training step.
workdir (str): Working directory for saving checkpoints and logs.
mesh (jax Mesh): Mesh used for distributed training.
state (flax TrainState): Current training state.
Source code in redco/trainers/trainer.py
class Trainer:\n \"\"\"Trainer class managing distributed training process.\n\n Attributes:\n step (int): Current training step.\n workdir (str): Working directory for saving checkpoints and logs.\n mesh (jax Mesh): Mesh used for distributed training.\n state (flax TrainState): Current training state.\n \"\"\"\n def __init__(self,\n deployer,\n collate_fn,\n apply_fn,\n loss_fn,\n params,\n optimizer,\n opt_state=None,\n compute_dtype=jnp.float32,\n last_ckpt_info=None,\n lr_schedule_fn=None,\n accumulate_grad_batches=None,\n params_sharding_rules=None,\n train_step_fn=None):\n \"\"\"Initializes the Trainer with initial parameters, etc.\n\n Args:\n deployer (Deployer): A deployer supporting low-level operations.\n collate_fn (Callable): The function converting a data batch to model\n inputs, e.g., tokenizing sentences into input_ids.\n apply_fn (Callable): The function to apply the model, such as\n model.apply for Flax modules, or model itself for HuggingFace\n models. It would be set as state.apply_fn, and used in loss_fn.\n loss_fn (Callable): The loss function converting model inputs to a\n scalar loss, e.g., computing cross-entropy loss from input_ids.\n params (dict): Initial model parameters.\n optimizer (optax optimizer): The optimizer used for training.\n opt_state (dict): optimizer state.\n compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n independent of param dtypes. (for mixed-precision training)\n last_ckpt_info (dict): the beginning step and epoch.\n lr_schedule_fn (Callable): The learning rate schedule\n function converting step to learning rate.\n accumulate_grad_batches (int): Gradient accumulation step.\n params_sharding_rules (list): Sharding rules.\n train_step_fn (Callable): For fully customizing every training step,\n e.g., per-sample gradient noising for data-private training.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = collate_fn\n self._apply_fn = apply_fn\n self._loss_fn = loss_fn\n self._optimizer = optimizer\n self._compute_dtype = compute_dtype\n self._lr_schedule_fn = lr_schedule_fn\n self._accumulate_grad_batches = accumulate_grad_batches\n self._params_sharding_rules = params_sharding_rules\n self._train_step_fn = train_step_fn\n\n self._state = None\n self._state_spec = None\n self._p_train_step = None\n self._p_eval_step = None\n\n self._init_step = 0\n self._init_epoch_idx = 0\n if last_ckpt_info is not None:\n self._init_step = last_ckpt_info.get('step', 0)\n self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n n_params = sum([param.size for param in jax.tree.leaves(params)])\n self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n self.set_train_state(\n apply_fn=self._apply_fn,\n params=params,\n optimizer=self._optimizer,\n step=self._init_step,\n opt_state=opt_state)\n\n def set_train_state(\n self, apply_fn, params, optimizer, step, opt_state=None):\n \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n Args:\n apply_fn (Callable): The function to apply the model.\n params (dict): Model parameters.\n optimizer (dict): The optimizer used for training.\n step (int): The training step.\n opt_state (dict): The state of the optimizer.\n \"\"\"\n self._deployer.log_info('Setting train_state ...')\n params = freeze(params)\n\n if self.mesh is None:\n params = jax.device_put(params, jax.local_devices()[0])\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n else:\n opt_state = jax.device_put(opt_state, 
jax.local_devices()[0])\n\n self._state = train_state.TrainState(\n step=step,\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state)\n self._state = replicate(self._state)\n else:\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n\n opt_state_spec = self._deployer.get_opt_state_spec(\n params_shape_or_params=params,\n params_spec=params_spec,\n optimizer=optimizer)\n opt_state = self._deployer.shard_params(\n params=opt_state,\n params_spec=opt_state_spec,\n desc='opt_state')\n\n self._state = train_state.TrainState(\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state,\n step=step)\n\n self._state_spec = train_state.TrainState(\n apply_fn=apply_fn,\n params=params_spec,\n tx=optimizer,\n opt_state=opt_state_spec,\n step=None)\n\n def setup_running_step(self, dummy_batch):\n \"\"\"Sets up the running step functions for training and evaluation.\n\n Args:\n dummy_batch (PyTree): A dummy batch of data.\n \"\"\"\n train_step_fn = partial(\n self._train_step_fn or default_train_step,\n loss_fn=self._loss_fn,\n lr_schedule_fn=self._lr_schedule_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n eval_step_fn = partial(\n eval_step,\n loss_fn=self._loss_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n\n if self.mesh is None:\n self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n self._p_train_step = pjit(\n train_step_fn,\n in_shardings=(None, self._state_spec, data_spec),\n out_shardings=(self._state_spec, None),\n donate_argnums=(1, ))\n self._p_eval_step = pjit(\n eval_step_fn,\n in_shardings=(self._state_spec, data_spec),\n out_shardings=None)\n\n def train(self, examples, per_device_batch_size, desc=None):\n \"\"\"Trains the model on the provided examples.\n\n Args:\n examples (list): Training examples in python list.\n per_device_batch_size (int): The batch size per device.\n desc (str): Description in the progress bar.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=True,\n shuffle_rng=self._deployer.gen_rng(),\n desc=f'Training ({desc})' if desc is not None else 'Training',\n is_train=True,\n accumulate_grad_batches=self._accumulate_grad_batches)\n\n for batch in data_batches:\n if self._p_train_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n train_rng = self._deployer.gen_rng()\n if self.mesh is None:\n train_rng = jax.random.split(\n train_rng, num=jax.process_count())[jax.process_index()]\n train_rng = shard_prng_key(train_rng)\n self._state, metrics = self._deployer.run_model_step(\n step_fn=self._p_train_step,\n input_args=(train_rng, self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n data_batches.set_postfix(**metrics)\n self._deployer.log_metrics(metrics=metrics, step=self.step)\n\n def eval_loss(self, examples, per_device_batch_size, desc=None):\n \"\"\"Evaluates the loss on the provided examples.\n\n Args:\n examples (list): Evaluation examples in list.\n per_device_batch_size (int): The batch size per device.\n desc (str): 
Description in the progress bar.\n\n Returns:\n (float): The average loss over the evaluation examples.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n losses = []\n for batch in data_batches:\n if self._p_eval_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n metrics = self._deployer.run_model_step(\n step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n\n losses.append(metrics['loss'].item())\n data_batches.set_postfix(**metrics)\n\n return np.mean(losses).item()\n\n def fit(self,\n train_examples,\n per_device_batch_size,\n n_epochs,\n eval_examples=None,\n eval_per_device_batch_size=None,\n eval_loss=True,\n eval_predictor=None,\n eval_metric_fn=None,\n eval_sanity_check=True,\n save_every_ckpt=False,\n save_last_ckpt=False,\n save_argmin_ckpt_by_metrics=None,\n save_argmax_ckpt_by_metrics=None,\n save_opt_states=True,\n save_float_dtype=None):\n \"\"\"Fits the model on the training data for a given number of epochs,\n optionally evaluating and saving checkpoints.\n\n Args:\n train_examples (list or Callable): Training examples, can be a\n list or a function of epoch_idx (for assigning different\n examples in separate epochs/chunks),\n e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n per_device_batch_size (int): The batch size per device.\n n_epochs (int): Number of epochs to train.\n eval_examples (list): Examples for evaluation and prediction.\n eval_per_device_batch_size (int): Batch size for evaluation\n eval_loss (bool): Whether to evaluate loss.\n eval_predictor (Predictor): Predictor working on `eval_examples`.\n eval_metric_fn (Callable): Metric function for prediction.\n eval_sanity_check (bool): if to run a sanity check for\n evaluation & predict functions before training.\n save_every_ckpt (bool): if to save a ckpt after every epoch.\n save_last_ckpt (bool): Whether to save the last checkpoint.\n save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on minimum values.\n save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on maximum values.\n save_opt_states (bool): of to save optimizer states in ckpts.\n save_float_dtype (bool): The data type for saving checkpoints.\n \"\"\"\n if eval_per_device_batch_size is None:\n eval_per_device_batch_size = per_device_batch_size\n\n if save_argmax_ckpt_by_metrics is None:\n save_argmax_ckpt_by_metrics = []\n if save_argmin_ckpt_by_metrics is None:\n save_argmin_ckpt_by_metrics = []\n min_metrics, max_metrics = {}, {}\n\n if os.path.exists(f'{self.workdir}/min_metrics.json'):\n min_metrics = json.load(open(\n f'{self.workdir}/min_metrics.json'))\n self._deployer.log_info(\n json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n if os.path.exists(f'{self.workdir}/max_metrics.json'):\n max_metrics = json.load(open(\n f'{self.workdir}/max_metrics.json'))\n self._deployer.log_info(\n json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n if eval_sanity_check and eval_examples is not None:\n rng_backup = self._deployer._rng\n _, eval_global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=eval_per_device_batch_size)\n\n if eval_loss:\n self.eval_loss(\n 
examples=eval_examples[:eval_global_micro_batch_size],\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for evaluation loss) passed.')\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples[:eval_global_micro_batch_size],\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for prediction) passed.')\n\n if eval_metric_fn is not None:\n json.dumps(eval_metric_fn(\n examples=eval_examples[:eval_global_micro_batch_size],\n preds=preds))\n self._deployer.log_info(\n 'Sanity check (for evaluation metrics) passed.')\n\n self._deployer._rng = rng_backup\n\n for epoch_idx in range(self._init_epoch_idx, n_epochs):\n if isinstance(train_examples, list):\n epoch_train_examples = train_examples\n else:\n epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n self.train(\n examples=epoch_train_examples,\n per_device_batch_size=per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n save_ckpt_kwargs = {\n 'epoch_idx': epoch_idx,\n 'save_opt_state': save_opt_states,\n 'float_dtype': save_float_dtype\n }\n\n if eval_examples is None:\n self._deployer.log_info(\n 'No evaluation cuz \\'eval_examples\\' is None.')\n else:\n eval_metrics = {}\n\n if eval_loss:\n loss = self.eval_loss(\n examples=eval_examples,\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n eval_metrics['loss'] = loss\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples,\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n if eval_metric_fn is not None:\n eval_metrics.update(eval_metric_fn(\n examples=eval_examples, preds=preds))\n\n eval_outputs = [\n {'example': example, 'pred': pred}\n for example, pred in zip(eval_examples, preds)]\n\n self._deployer.save_outputs(\n outputs=eval_outputs,\n desc=f'epoch{epoch_idx}',\n step=self.step)\n\n self._deployer.log_info(\n info=json.dumps(eval_metrics, indent=4),\n title=f'Eval results',\n step=self.step)\n self._deployer.log_metrics(metrics={\n f'eval_{key}': value\n for key, value in eval_metrics.items()\n }, step=self.step)\n\n if self.workdir is not None:\n result_filepath = \\\n f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n json.dump(\n eval_metrics, open(result_filepath, 'w'), indent=4)\n self._deployer.log_info(\n f'eval_results saved into {result_filepath}.')\n\n for key in save_argmin_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] < min_metrics.get(key, float('inf')):\n min_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'minimal {key} updated to {min_metrics[key]}')\n json.dump(min_metrics, open(\n f'{self.workdir}/min_metrics.json', 'w'))\n\n self.save_ckpt(\n ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n for key in save_argmax_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n max_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'maximal {key} updated to {max_metrics[key]}')\n json.dump(max_metrics, open(\n f'{self.workdir}/max_metrics.json', 
'w'))\n\n self.save_ckpt(\n ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n if save_every_ckpt:\n self.save_ckpt(\n ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n elif save_last_ckpt:\n self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n\n def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n Args:\n epoch_idx (int): The current epoch index.\n ckpt_name (str): The name of the checkpoint.\n save_opt_state (bool): Whether to save the optimizer state.\n float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n \"\"\"\n if self.mesh is None:\n params = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.params)\n else:\n params = self._state.params\n\n opt_state = None\n if save_opt_state:\n if self.mesh is None:\n opt_state = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.opt_state)\n else:\n opt_state = self._state.opt_state\n\n ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n self._deployer.save_ckpt(\n ckpt_dir=ckpt_dir,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n step=self.step,\n epoch_idx=epoch_idx)\n\n if jax.process_index() == 0:\n open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n\n @property\n def step(self):\n \"\"\"Returns the current training step.\"\"\"\n if self.mesh is None:\n return unreplicate(self._state.step).item()\n else:\n return self._state.step.item()\n\n @property\n def workdir(self):\n \"\"\"Returns the working directory for saving checkpoints and logs.\"\"\"\n return self._deployer.workdir\n\n @property\n def mesh(self):\n \"\"\"Returns the mesh used for distributed training.\"\"\"\n return self._deployer.mesh\n\n @property\n def state(self):\n \"\"\"Returns the current training state.\"\"\"\n return self._state\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.mesh","title":"mesh
property
","text":"Returns the mesh used for distributed training.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.state","title":"state
property
","text":"Returns the current training state.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.step","title":"step
property
","text":"Returns the current training step.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.workdir","title":"workdir
property
","text":"Returns the working directory for saving checkpoints and logs.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.__init__","title":"__init__(deployer, collate_fn, apply_fn, loss_fn, params, optimizer, opt_state=None, compute_dtype=jnp.float32, last_ckpt_info=None, lr_schedule_fn=None, accumulate_grad_batches=None, params_sharding_rules=None, train_step_fn=None)
","text":"Initializes the Trainer with initial parameters, etc.
Parameters:
Name Type Description Defaultdeployer
Deployer
A deployer supporting low-level operations.
requiredcollate_fn
Callable
The function converting a data batch to model inputs, e.g., tokenizing sentences into input_ids.
requiredapply_fn
Callable
The function to apply the model, such as model.apply for Flax modules, or the model itself for HuggingFace models. It will be set as state.apply_fn and used inside loss_fn.
requiredloss_fn
Callable
The loss function converting model inputs to a scalar loss, e.g., computing cross-entropy loss from input_ids.
requiredparams
dict
Initial model parameters.
requiredoptimizer
optax optimizer
The optimizer used for training.
requiredopt_state
dict
Optimizer state (e.g., from a loaded checkpoint).
None
compute_dtype
dtype
Computation dtype, e.g., jnp.bfloat16
, independent of param dtypes (for mixed-precision training).
float32
last_ckpt_info
dict
The beginning step and epoch index, e.g., from a loaded checkpoint.
None
lr_schedule_fn
Callable
The learning rate schedule function converting step to learning rate.
None
accumulate_grad_batches
int
Number of gradient accumulation steps.
None
params_sharding_rules
list
Sharding rules for model parameters.
None
train_step_fn
Callable
For fully customizing every training step, e.g., per-sample gradient noising for data-private training.
None
Source code in redco/trainers/trainer.py
def __init__(self,\n deployer,\n collate_fn,\n apply_fn,\n loss_fn,\n params,\n optimizer,\n opt_state=None,\n compute_dtype=jnp.float32,\n last_ckpt_info=None,\n lr_schedule_fn=None,\n accumulate_grad_batches=None,\n params_sharding_rules=None,\n train_step_fn=None):\n \"\"\"Initializes the Trainer with initial parameters, etc.\n\n Args:\n deployer (Deployer): A deployer supporting low-level operations.\n collate_fn (Callable): The function converting a data batch to model\n inputs, e.g., tokenizing sentences into input_ids.\n apply_fn (Callable): The function to apply the model, such as\n model.apply for Flax modules, or model itself for HuggingFace\n models. It would be set as state.apply_fn, and used in loss_fn.\n loss_fn (Callable): The loss function converting model inputs to a\n scalar loss, e.g., computing cross-entropy loss from input_ids.\n params (dict): Initial model parameters.\n optimizer (optax optimizer): The optimizer used for training.\n opt_state (dict): optimizer state.\n compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n independent of param dtypes. (for mixed-precision training)\n last_ckpt_info (dict): the beginning step and epoch.\n lr_schedule_fn (Callable): The learning rate schedule\n function converting step to learning rate.\n accumulate_grad_batches (int): Gradient accumulation step.\n params_sharding_rules (list): Sharding rules.\n train_step_fn (Callable): For fully customizing every training step,\n e.g., per-sample gradient noising for data-private training.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = collate_fn\n self._apply_fn = apply_fn\n self._loss_fn = loss_fn\n self._optimizer = optimizer\n self._compute_dtype = compute_dtype\n self._lr_schedule_fn = lr_schedule_fn\n self._accumulate_grad_batches = accumulate_grad_batches\n self._params_sharding_rules = params_sharding_rules\n self._train_step_fn = train_step_fn\n\n self._state = None\n self._state_spec = None\n self._p_train_step = None\n self._p_eval_step = None\n\n self._init_step = 0\n self._init_epoch_idx = 0\n if last_ckpt_info is not None:\n self._init_step = last_ckpt_info.get('step', 0)\n self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n n_params = sum([param.size for param in jax.tree.leaves(params)])\n self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n self.set_train_state(\n apply_fn=self._apply_fn,\n params=params,\n optimizer=self._optimizer,\n step=self._init_step,\n opt_state=opt_state)\n
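As a usage sketch (not from the official examples): constructing a Trainer for a sharded HuggingFace Flax model might look roughly as below, where model, collate_fn, and loss_fn are user-provided placeholders and model.params follows the HuggingFace Flax convention.

import jax.numpy as jnp
import optax
from redco import Deployer, Trainer   # top-level imports assumed

deployer = Deployer(jax_seed=42, n_model_shards=4, workdir='./workdir')
trainer = Trainer(
    deployer=deployer,
    collate_fn=collate_fn,
    apply_fn=model,                      # HuggingFace Flax model used directly as apply_fn
    loss_fn=loss_fn,
    params=model.params,
    optimizer=optax.adamw(learning_rate=2e-5),
    compute_dtype=jnp.bfloat16,          # mixed-precision computation
    accumulate_grad_batches=deployer.get_accumulate_grad_batches(
        global_batch_size=256, per_device_batch_size=8),
    params_sharding_rules=deployer.get_sharding_rules(
        params_shape_or_params=model.params))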
"},{"location":"trainer/#redco.trainers.trainer.Trainer.eval_loss","title":"eval_loss(examples, per_device_batch_size, desc=None)
","text":"Evaluates the loss on the provided examples.
Parameters:
Name Type Description Defaultexamples
list
Evaluation examples in list.
requiredper_device_batch_size
int
The batch size per device.
requireddesc
str
Description in the progress bar.
None
Returns:
Type Descriptionfloat
The average loss over the evaluation examples.
Source code inredco/trainers/trainer.py
def eval_loss(self, examples, per_device_batch_size, desc=None):\n \"\"\"Evaluates the loss on the provided examples.\n\n Args:\n examples (list): Evaluation examples in list.\n per_device_batch_size (int): The batch size per device.\n desc (str): Description in the progress bar.\n\n Returns:\n (float): The average loss over the evaluation examples.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n losses = []\n for batch in data_batches:\n if self._p_eval_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n metrics = self._deployer.run_model_step(\n step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n\n losses.append(metrics['loss'].item())\n data_batches.set_postfix(**metrics)\n\n return np.mean(losses).item()\n
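A typical call (a sketch; dev_examples is a placeholder list of examples):

dev_loss = trainer.eval_loss(
    examples=dev_examples, per_device_batch_size=8, desc='dev')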
"},{"location":"trainer/#redco.trainers.trainer.Trainer.fit","title":"fit(train_examples, per_device_batch_size, n_epochs, eval_examples=None, eval_per_device_batch_size=None, eval_loss=True, eval_predictor=None, eval_metric_fn=None, eval_sanity_check=True, save_every_ckpt=False, save_last_ckpt=False, save_argmin_ckpt_by_metrics=None, save_argmax_ckpt_by_metrics=None, save_opt_states=True, save_float_dtype=None)
","text":"Fits the model on the training data for a given number of epochs, optionally evaluating and saving checkpoints.
Parameters:
Name Type Description Defaulttrain_examples
list or Callable
Training examples; can be a list or a function of epoch_idx (for assigning different examples to separate epochs/chunks), e.g., train_examples=lambda epoch_idx: load_data(epoch_idx)
per_device_batch_size
int
The batch size per device.
requiredn_epochs
int
Number of epochs to train.
requiredeval_examples
list
Examples for evaluation and prediction.
None
eval_per_device_batch_size
int
Batch size for evaluation (defaults to per_device_batch_size).
None
eval_loss
bool
Whether to evaluate loss.
True
eval_predictor
Predictor
Predictor working on eval_examples
.
None
eval_metric_fn
Callable
Metric function for prediction.
None
eval_sanity_check
bool
Whether to run a sanity check for the evaluation & prediction functions before training.
True
save_every_ckpt
bool
Whether to save a checkpoint after every epoch.
False
save_last_ckpt
bool
Whether to save the last checkpoint.
False
save_argmin_ckpt_by_metrics
list[str]
Metrics to save checkpoints based on minimum values.
None
save_argmax_ckpt_by_metrics
list[str]
Metrics to save checkpoints based on maximum values.
None
save_opt_states
bool
Whether to save optimizer states in checkpoints.
True
save_float_dtype
`jax.numpy.dtype`
The data type for saving checkpoints.
None
Source code in redco/trainers/trainer.py
def fit(self,\n train_examples,\n per_device_batch_size,\n n_epochs,\n eval_examples=None,\n eval_per_device_batch_size=None,\n eval_loss=True,\n eval_predictor=None,\n eval_metric_fn=None,\n eval_sanity_check=True,\n save_every_ckpt=False,\n save_last_ckpt=False,\n save_argmin_ckpt_by_metrics=None,\n save_argmax_ckpt_by_metrics=None,\n save_opt_states=True,\n save_float_dtype=None):\n \"\"\"Fits the model on the training data for a given number of epochs,\n optionally evaluating and saving checkpoints.\n\n Args:\n train_examples (list or Callable): Training examples, can be a\n list or a function of epoch_idx (for assigning different\n examples in separate epochs/chunks),\n e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n per_device_batch_size (int): The batch size per device.\n n_epochs (int): Number of epochs to train.\n eval_examples (list): Examples for evaluation and prediction.\n eval_per_device_batch_size (int): Batch size for evaluation\n eval_loss (bool): Whether to evaluate loss.\n eval_predictor (Predictor): Predictor working on `eval_examples`.\n eval_metric_fn (Callable): Metric function for prediction.\n eval_sanity_check (bool): if to run a sanity check for\n evaluation & predict functions before training.\n save_every_ckpt (bool): if to save a ckpt after every epoch.\n save_last_ckpt (bool): Whether to save the last checkpoint.\n save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on minimum values.\n save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on maximum values.\n save_opt_states (bool): of to save optimizer states in ckpts.\n save_float_dtype (bool): The data type for saving checkpoints.\n \"\"\"\n if eval_per_device_batch_size is None:\n eval_per_device_batch_size = per_device_batch_size\n\n if save_argmax_ckpt_by_metrics is None:\n save_argmax_ckpt_by_metrics = []\n if save_argmin_ckpt_by_metrics is None:\n save_argmin_ckpt_by_metrics = []\n min_metrics, max_metrics = {}, {}\n\n if os.path.exists(f'{self.workdir}/min_metrics.json'):\n min_metrics = json.load(open(\n f'{self.workdir}/min_metrics.json'))\n self._deployer.log_info(\n json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n if os.path.exists(f'{self.workdir}/max_metrics.json'):\n max_metrics = json.load(open(\n f'{self.workdir}/max_metrics.json'))\n self._deployer.log_info(\n json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n if eval_sanity_check and eval_examples is not None:\n rng_backup = self._deployer._rng\n _, eval_global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=eval_per_device_batch_size)\n\n if eval_loss:\n self.eval_loss(\n examples=eval_examples[:eval_global_micro_batch_size],\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for evaluation loss) passed.')\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples[:eval_global_micro_batch_size],\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for prediction) passed.')\n\n if eval_metric_fn is not None:\n json.dumps(eval_metric_fn(\n examples=eval_examples[:eval_global_micro_batch_size],\n preds=preds))\n self._deployer.log_info(\n 'Sanity check (for evaluation metrics) passed.')\n\n self._deployer._rng = 
rng_backup\n\n for epoch_idx in range(self._init_epoch_idx, n_epochs):\n if isinstance(train_examples, list):\n epoch_train_examples = train_examples\n else:\n epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n self.train(\n examples=epoch_train_examples,\n per_device_batch_size=per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n save_ckpt_kwargs = {\n 'epoch_idx': epoch_idx,\n 'save_opt_state': save_opt_states,\n 'float_dtype': save_float_dtype\n }\n\n if eval_examples is None:\n self._deployer.log_info(\n 'No evaluation cuz \\'eval_examples\\' is None.')\n else:\n eval_metrics = {}\n\n if eval_loss:\n loss = self.eval_loss(\n examples=eval_examples,\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n eval_metrics['loss'] = loss\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples,\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n if eval_metric_fn is not None:\n eval_metrics.update(eval_metric_fn(\n examples=eval_examples, preds=preds))\n\n eval_outputs = [\n {'example': example, 'pred': pred}\n for example, pred in zip(eval_examples, preds)]\n\n self._deployer.save_outputs(\n outputs=eval_outputs,\n desc=f'epoch{epoch_idx}',\n step=self.step)\n\n self._deployer.log_info(\n info=json.dumps(eval_metrics, indent=4),\n title=f'Eval results',\n step=self.step)\n self._deployer.log_metrics(metrics={\n f'eval_{key}': value\n for key, value in eval_metrics.items()\n }, step=self.step)\n\n if self.workdir is not None:\n result_filepath = \\\n f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n json.dump(\n eval_metrics, open(result_filepath, 'w'), indent=4)\n self._deployer.log_info(\n f'eval_results saved into {result_filepath}.')\n\n for key in save_argmin_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] < min_metrics.get(key, float('inf')):\n min_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'minimal {key} updated to {min_metrics[key]}')\n json.dump(min_metrics, open(\n f'{self.workdir}/min_metrics.json', 'w'))\n\n self.save_ckpt(\n ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n for key in save_argmax_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n max_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'maximal {key} updated to {max_metrics[key]}')\n json.dump(max_metrics, open(\n f'{self.workdir}/max_metrics.json', 'w'))\n\n self.save_ckpt(\n ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n if save_every_ckpt:\n self.save_ckpt(\n ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n elif save_last_ckpt:\n self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n
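A representative call exercising the evaluation and checkpointing switches above (a sketch; train_examples and dev_examples are placeholder lists):

trainer.fit(
    train_examples=train_examples,
    per_device_batch_size=8,
    n_epochs=3,
    eval_examples=dev_examples,
    eval_loss=True,
    save_argmin_ckpt_by_metrics=['loss'],   # keep the ckpt with the lowest eval loss
    save_last_ckpt=True)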
"},{"location":"trainer/#redco.trainers.trainer.Trainer.save_ckpt","title":"save_ckpt(epoch_idx, ckpt_name, save_opt_state, float_dtype)
","text":"Saves a checkpoint into {self.workdir}/ckpts
.
Parameters:
Name Type Description Defaultepoch_idx
int
The current epoch index.
requiredckpt_name
str
The name of the checkpoint.
requiredsave_opt_state
bool
Whether to save the optimizer state.
requiredfloat_dtype
`jax.numpy.dtype`
Data type for saving float params.
required Source code inredco/trainers/trainer.py
def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n Args:\n epoch_idx (int): The current epoch index.\n ckpt_name (str): The name of the checkpoint.\n save_opt_state (bool): Whether to save the optimizer state.\n float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n \"\"\"\n if self.mesh is None:\n params = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.params)\n else:\n params = self._state.params\n\n opt_state = None\n if save_opt_state:\n if self.mesh is None:\n opt_state = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.opt_state)\n else:\n opt_state = self._state.opt_state\n\n ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n self._deployer.save_ckpt(\n ckpt_dir=ckpt_dir,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n step=self.step,\n epoch_idx=epoch_idx)\n\n if jax.process_index() == 0:\n open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n
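For a manual save outside fit(), the call is roughly as follows (a sketch, assuming a workdir was set on the Deployer):

import jax.numpy as jnp

trainer.save_ckpt(
    epoch_idx=0, ckpt_name='manual', save_opt_state=True, float_dtype=jnp.float32)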
"},{"location":"trainer/#redco.trainers.trainer.Trainer.set_train_state","title":"set_train_state(apply_fn, params, optimizer, step, opt_state=None)
","text":"Sets/Resets the training state with given parameters and optimizer.
Parameters:
Name Type Description Defaultapply_fn
Callable
The function to apply the model.
requiredparams
dict
Model parameters.
requiredoptimizer
optax optimizer
The optimizer used for training.
requiredstep
int
The training step.
requiredopt_state
dict
The state of the optimizer.
None
Source code in redco/trainers/trainer.py
def set_train_state(\n self, apply_fn, params, optimizer, step, opt_state=None):\n \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n Args:\n apply_fn (Callable): The function to apply the model.\n params (dict): Model parameters.\n optimizer (dict): The optimizer used for training.\n step (int): The training step.\n opt_state (dict): The state of the optimizer.\n \"\"\"\n self._deployer.log_info('Setting train_state ...')\n params = freeze(params)\n\n if self.mesh is None:\n params = jax.device_put(params, jax.local_devices()[0])\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n else:\n opt_state = jax.device_put(opt_state, jax.local_devices()[0])\n\n self._state = train_state.TrainState(\n step=step,\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state)\n self._state = replicate(self._state)\n else:\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n\n opt_state_spec = self._deployer.get_opt_state_spec(\n params_shape_or_params=params,\n params_spec=params_spec,\n optimizer=optimizer)\n opt_state = self._deployer.shard_params(\n params=opt_state,\n params_spec=opt_state_spec,\n desc='opt_state')\n\n self._state = train_state.TrainState(\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state,\n step=step)\n\n self._state_spec = train_state.TrainState(\n apply_fn=apply_fn,\n params=params_spec,\n tx=optimizer,\n opt_state=opt_state_spec,\n step=None)\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.setup_running_step","title":"setup_running_step(dummy_batch)
","text":"Sets up the running step functions for training and evaluation.
Parameters:
Name Type Description Defaultdummy_batch
PyTree
A dummy batch of data.
required Source code inredco/trainers/trainer.py
def setup_running_step(self, dummy_batch):\n \"\"\"Sets up the running step functions for training and evaluation.\n\n Args:\n dummy_batch (PyTree): A dummy batch of data.\n \"\"\"\n train_step_fn = partial(\n self._train_step_fn or default_train_step,\n loss_fn=self._loss_fn,\n lr_schedule_fn=self._lr_schedule_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n eval_step_fn = partial(\n eval_step,\n loss_fn=self._loss_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n\n if self.mesh is None:\n self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n self._p_train_step = pjit(\n train_step_fn,\n in_shardings=(None, self._state_spec, data_spec),\n out_shardings=(self._state_spec, None),\n donate_argnums=(1, ))\n self._p_eval_step = pjit(\n eval_step_fn,\n in_shardings=(self._state_spec, data_spec),\n out_shardings=None)\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.train","title":"train(examples, per_device_batch_size, desc=None)
","text":"Trains the model on the provided examples.
Parameters:
Name Type Description Defaultexamples
list
Training examples as a Python list.
requiredper_device_batch_size
int
The batch size per device.
requireddesc
str
Description in the progress bar.
None
Source code in redco/trainers/trainer.py
def train(self, examples, per_device_batch_size, desc=None):\n \"\"\"Trains the model on the provided examples.\n\n Args:\n examples (list): Training examples in python list.\n per_device_batch_size (int): The batch size per device.\n desc (str): Description in the progress bar.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=True,\n shuffle_rng=self._deployer.gen_rng(),\n desc=f'Training ({desc})' if desc is not None else 'Training',\n is_train=True,\n accumulate_grad_batches=self._accumulate_grad_batches)\n\n for batch in data_batches:\n if self._p_train_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n train_rng = self._deployer.gen_rng()\n if self.mesh is None:\n train_rng = jax.random.split(\n train_rng, num=jax.process_count())[jax.process_index()]\n train_rng = shard_prng_key(train_rng)\n self._state, metrics = self._deployer.run_model_step(\n step_fn=self._p_train_step,\n input_args=(train_rng, self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n data_batches.set_postfix(**metrics)\n self._deployer.log_metrics(metrics=metrics, step=self.step)\n
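For a single pass over one chunk of data, without the evaluation and checkpointing logic of fit(), the call is simply (a sketch; train_examples is a placeholder list):

trainer.train(examples=train_examples, per_device_batch_size=8, desc='epoch 0')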
"}]}
\ No newline at end of file
+{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"RedCoast","text":"Red Coast (redco) is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.
RedCoast supports Large Models + Complex Algorithms, in a lightweight and user-friendly way:
With RedCoast, to define an ML pipeline, only three functions are needed:
Redco automates all the remaining parts of pipeline execution, such as data and model parallelism, multi-host processing, distributed checkpointing, randomness control, logging, etc.
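For illustration, a minimal end-to-end pipeline built from such user-defined functions might look like the sketch below. This is not an official example: the toy dense model, the collate_fn/loss_fn bodies, and the exact loss_fn argument list are assumptions based on typical RedCoast usage.

import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from redco import Deployer, Trainer   # top-level imports assumed

model = nn.Dense(features=1)   # toy regression model, for demonstration only

def collate_fn(examples):
    # Convert a list of raw examples into a batch of numpy model inputs.
    return {'inputs': np.stack([ex['x'] for ex in examples]),
            'labels': np.stack([ex['y'] for ex in examples])}

def loss_fn(train_rng, state, params, batch, is_training):
    # Map a batch of model inputs to a scalar loss
    # (argument list assumed from typical RedCoast examples).
    preds = state.apply_fn({'params': params}, batch['inputs'])
    return jnp.mean(jnp.square(preds[:, 0] - batch['labels']))

deployer = Deployer(jax_seed=0, workdir='./workdir')
params = model.init(deployer.gen_rng(), jnp.zeros((1, 8)))['params']
trainer = Trainer(
    deployer=deployer,
    collate_fn=collate_fn,
    apply_fn=model.apply,
    loss_fn=loss_fn,
    params=params,
    optimizer=optax.adamw(learning_rate=1e-3))
trainer.fit(
    train_examples=[{'x': np.random.rand(8), 'y': np.random.rand()}
                    for _ in range(64)],
    per_device_batch_size=8,
    n_epochs=2)

Everything else (device placement, sharding, checkpointing, randomness, logging) is handled by the Deployer and Trainer documented below.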
"},{"location":"deployer/","title":"Deployer","text":""},{"location":"deployer/#redco.deployers.deployer.Deployer","title":"Deployer
","text":"Handles low-level operations to support Trainer and Predictor, e.g., automatic data/model parallelism, distributed checkpointing, data processing, logging, randomness controlling, etc.
Attributes:
Name Type Descriptionworkdir
str
Working directory for saving checkpoints and logs.
mesh
jax Mesh
Mesh used for model sharding.
Source code inredco/deployers/deployer.py
class Deployer:\n \"\"\" Handles low-level operations to support Trainer and Predictor,\n e.g., automatic data/model parallelism, distributed checkpointing,\n data processing, logging, randomness controlling, etc.\n\n Attributes:\n workdir (str): Working directory for saving checkpoints and logs.\n mesh (jax Mesh): Mesh used for model sharding.\n \"\"\"\n def __init__(self,\n jax_seed,\n n_model_shards=1,\n verbose=True,\n workdir=None,\n n_processes=None,\n host0_address=None,\n host0_port=None,\n process_id=None,\n n_local_devices=None,\n run_tensorboard=False,\n wandb_init_kwargs=None):\n \"\"\" Initializes a Deployer.\n\n Args:\n jax_seed (int): Seed for random number generation.\n n_model_shards (int): Number of shards for running large model.\n verbose (bool): Whether to enable verbose logging.\n workdir (str): Directory for saving logs and checkpoints.\n n_processes (int): For multi-host, number of processes/nodes.\n host0_address (str): For multi-host, address of the host0.\n host0_port (int): For multi-host, port of the host0.\n process_id (int): For multi-host, index of the current process.\n n_local_devices (int): For multi-host, number of local devices.\n run_tensorboard (bool): Whether to enable TensorBoard logging.\n wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n \"\"\"\n if n_processes is None:\n if 'SLURM_JOB_NUM_NODES' in os.environ:\n n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n process_id = int(os.environ['SLURM_NODEID'])\n else:\n n_processes = 1\n\n if n_processes > 1:\n local_device_ids = None if n_local_devices is None \\\n else list(range(n_local_devices))\n\n if host0_port is None:\n host0_port = DEFAULT_HOST0_PORT\n\n jax.distributed.initialize(\n coordinator_address=f'{host0_address}:{host0_port}',\n num_processes=n_processes,\n process_id=process_id,\n local_device_ids=local_device_ids)\n\n if workdir is not None:\n os.makedirs(workdir, exist_ok=True)\n\n self._verbose = verbose\n self._workdir = workdir\n self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n if wandb_init_kwargs is not None and jax.process_index() == 0:\n import wandb\n wandb.init(**wandb_init_kwargs)\n self._wandb_log_fn = wandb.log\n else:\n self._wandb_log_fn = None\n\n if run_tensorboard and jax.process_index() == 0:\n from flax.metrics import tensorboard\n self._summary_writer = tensorboard.SummaryWriter(workdir)\n else:\n self._summary_writer = None\n\n self.log_info(\n f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n self._rng = jax.random.PRNGKey(seed=jax_seed)\n self._mesh = get_mesh(n_model_shards=n_model_shards)\n self._checkpointer = ocp.PyTreeCheckpointer()\n\n def get_local_global_micro_batch_size(self, per_device_batch_size):\n \"\"\"Get local/global micro batch sizes based on per-device batch size.\"\"\"\n if self._mesh is None:\n local_micro_batch_size = \\\n per_device_batch_size * jax.local_device_count()\n global_micro_batch_size = \\\n local_micro_batch_size * jax.process_count()\n else:\n global_micro_batch_size = local_micro_batch_size = \\\n per_device_batch_size * self._mesh.shape['dp']\n\n return local_micro_batch_size, global_micro_batch_size\n\n def get_accumulate_grad_batches(\n self, global_batch_size, per_device_batch_size):\n \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n assert global_batch_size % global_micro_batch_size == 0\n accumulate_grad_batches = 
global_batch_size // global_micro_batch_size\n\n return accumulate_grad_batches\n\n def get_model_input_batches(self,\n examples,\n per_device_batch_size,\n collate_fn,\n shuffle,\n shuffle_rng,\n desc,\n is_train=False,\n accumulate_grad_batches=None):\n \"\"\"Prepares model input batches from examples.\n\n Args:\n examples (list): List of input examples.\n per_device_batch_size (int): Batch size per device.\n collate_fn (Callable): Function to collate the examples.\n shuffle (bool): Whether to shuffle the examples.\n shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n desc (str): Description in the progress bar.\n is_train (bool): Whether the data is for training.\n accumulate_grad_batches (int): gradient accumulation batches.\n\n Returns:\n (generator): A python generator of batched model inputs.\n \"\"\"\n local_micro_batch_size, global_micro_batch_size = \\\n self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n\n examples = get_host_examples(\n examples=examples,\n global_micro_batch_size=global_micro_batch_size,\n shuffle=shuffle,\n shuffle_rng=shuffle_rng,\n mesh=self._mesh)\n\n if not is_train:\n desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n elif accumulate_grad_batches is None:\n desc = \\\n f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n else:\n desc = (f'{desc} ('\n f'global_micro_batch_size = {global_micro_batch_size}, '\n f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n return get_data_batches(\n examples=examples,\n batch_size=local_micro_batch_size,\n collate_fn=collate_fn,\n mesh=self._mesh,\n desc=desc,\n verbose=self._verbose)\n\n def get_lr_schedule_fn(self,\n train_size,\n per_device_batch_size,\n n_epochs,\n learning_rate,\n schedule_type='linear',\n warmup_ratio=0.,\n warmup_steps=None,\n init_learning_rate=0.,\n end_learning_rate=0.):\n \"\"\"Creates a learning rate schedule function.\n\n Args:\n train_size (int): Number of training examples per epoch.\n per_device_batch_size (int): Batch size per device.\n n_epochs (int): Number of epochs.\n learning_rate (float): Peak learning rate.\n schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n warmup_ratio (float): Ratio of lr warmup.\n warmup_steps (int): Number of warmup steps.\n init_learning_rate (float): Initial learning rate before warmup.\n end_learning_rate (float): End learning rate for the schedule.\n\n Returns:\n (Callable): A lr schedule function, step -> learning rate.\n \"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n if warmup_steps is None:\n warmup_steps = int(total_train_steps * warmup_ratio)\n\n return get_lr_schedule_fn(\n schedule_type=schedule_type,\n total_train_steps=total_train_steps,\n warmup_steps=warmup_steps,\n init_learning_rate=init_learning_rate,\n learning_rate=learning_rate,\n end_learning_rate=end_learning_rate)\n\n def get_sharding_rules(self, params_shape_or_params):\n \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n if self._mesh is None:\n return None\n else:\n sharding_rules = get_sharding_rules(\n params_shape_or_params=params_shape_or_params,\n n_model_shards=self._mesh.shape['mp'])\n return sharding_rules\n\n def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n return get_params_spec(\n 
params_shape_or_params=params_shape_or_params,\n params_sharding_rules=params_sharding_rules)\n\n def get_opt_state_spec(\n self, params_shape_or_params, params_spec, optimizer):\n \"\"\"Get optimizer state specs\"\"\"\n return get_opt_state_spec(\n params_shape_or_params=params_shape_or_params,\n params_spec=params_spec,\n optimizer=optimizer)\n\n def shard_params(self, params, params_spec, desc='params'):\n \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n self.log_info(info=f'Sharding {desc} ...')\n return shard_params(\n mesh=self._mesh, params=params, params_spec=params_spec)\n\n def run_model_step(self, step_fn, input_args):\n \"\"\"Executes a model step function with the provided inputs.\"\"\"\n if self._mesh is None:\n return step_fn(*input_args)\n else:\n with self._mesh:\n return step_fn(*input_args)\n\n def gen_rng(self):\n \"\"\"Get a new random number generator key and update the random state.\"\"\"\n self._rng, new_rng = jax.random.split(self._rng)\n return new_rng\n\n def log_info(self, info, title=None, step=None):\n \"\"\"Logs a messages\"\"\"\n log_info(\n info=info,\n title=title,\n logger=self._logger,\n summary_writer=self._summary_writer,\n step=step)\n\n def log_metrics(self, metrics, step):\n \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n if self._summary_writer is not None:\n for metric_name, value in metrics.items():\n self._summary_writer.scalar(metric_name, value, step=step)\n\n if self._wandb_log_fn is not None:\n self._wandb_log_fn(metrics, step)\n\n def save_outputs(self, outputs, desc, step):\n \"\"\"Saves model outputs to workdir.\"\"\"\n if self._workdir is not None and jax.process_index() == 0:\n save_outputs(\n workdir=self._workdir,\n outputs=outputs,\n desc=desc,\n step=step,\n logger=self._logger,\n summary_writer=self._summary_writer)\n\n def save_ckpt(\n self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n \"\"\"Saves a checkpoint to the specified directory.\n\n Args:\n ckpt_dir (str): Directory to save the checkpoint.\n params (dict): Model parameters.\n opt_state (dict): Optimizer state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n **kwargs (dict): Additional information to be saved into\n info.json, e.g., current training step, epoch index, etc.\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n save_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n rng=self._rng,\n **kwargs)\n self.log_info(f'Ckpt saved into {ckpt_dir}')\n\n def load_params_shape(self, ckpt_dir):\n \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n return load_params_shape(ckpt_dir=ckpt_dir)\n\n def load_ckpt(self,\n ckpt_dir,\n params_sharding_rules=None,\n optimizer=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=False):\n \"\"\"Loads a checkpoint from the specified directory.\n\n Args:\n ckpt_dir (str): Directory of the checkpoint.\n params_sharding_rules (list[tuple]): Sharding rules for parameters.\n optimizer (optax optimizer): Optimizer for loading opt_state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n load_params (bool): Whether to load the parameters.\n load_opt_state (bool): Whether to load the optimizer state.\n update_rng (bool): if updating the random state of the deployer.\n\n Returns:\n (tuple): A tuple with the loaded checkpoint (in a dict with\n `\"params\"` 
and `\"opt_state\"`) and additional information (in a\n dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n specs = {}\n if self._mesh is not None:\n if params_sharding_rules is None:\n params_sharding_rules = self.get_sharding_rules(\n params_shape_or_params=params_shape)\n\n specs['params'] = self.get_params_spec(\n params_shape_or_params=params_shape,\n params_sharding_rules=params_sharding_rules)\n if optimizer is not None:\n specs['opt_state'] = self.get_opt_state_spec(\n params_shape_or_params=params_shape,\n params_spec=specs['params'],\n optimizer=optimizer)\n\n ckpt, info = load_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params_shape_or_params=params_shape,\n optimizer=optimizer,\n float_dtype=float_dtype,\n mesh=self._mesh,\n specs=specs,\n load_params=load_params,\n load_opt_state=load_opt_state)\n\n for key, value in info.items():\n if not update_rng and key == 'rng':\n continue\n self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n if update_rng:\n self._rng = info['rng']\n self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n return ckpt, info\n\n def load_last_ckpt(self,\n optimizer=None,\n params_sharding_rules=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=True):\n \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n See load_ckpt() for the explanation of arguments.\n \"\"\"\n try:\n last_ckpt_name = open(\n f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n except:\n self.log_info(\n f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n f'no ckpt loaded.')\n return None, None\n\n return self.load_ckpt(\n ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n optimizer=optimizer,\n float_dtype=float_dtype,\n params_sharding_rules=params_sharding_rules,\n load_params=load_params,\n load_opt_state=load_opt_state,\n update_rng=update_rng)\n\n @property\n def mesh(self):\n \"\"\"Returns the mesh for model sharding\"\"\"\n return self._mesh\n\n @property\n def workdir(self):\n \"\"\"Returns the work directory.\"\"\"\n return self._workdir\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.mesh","title":"mesh
property
","text":"Returns the mesh for model sharding
"},{"location":"deployer/#redco.deployers.deployer.Deployer.workdir","title":"workdir
property
","text":"Returns the work directory.
"},{"location":"deployer/#redco.deployers.deployer.Deployer.__init__","title":"__init__(jax_seed, n_model_shards=1, verbose=True, workdir=None, n_processes=None, host0_address=None, host0_port=None, process_id=None, n_local_devices=None, run_tensorboard=False, wandb_init_kwargs=None)
","text":"Initializes a Deployer.
Parameters:
Name Type Description Defaultjax_seed
int
Seed for random number generation.
requiredn_model_shards
int
Number of shards for running a large model.
1
verbose
bool
Whether to enable verbose logging.
True
workdir
str
Directory for saving logs and checkpoints.
None
n_processes
int
For multi-host, number of processes/nodes.
None
host0_address
str
For multi-host, address of host0 (the coordinator).
None
host0_port
int
For multi-host, port of host0 (the coordinator).
None
process_id
int
For multi-host, index of the current process.
None
n_local_devices
int
For multi-host, number of local devices.
None
run_tensorboard
bool
Whether to enable TensorBoard logging.
False
wandb_init_kwargs
dict
wandb.init arguments if using wandb.
None
Source code in redco/deployers/deployer.py
def __init__(self,\n jax_seed,\n n_model_shards=1,\n verbose=True,\n workdir=None,\n n_processes=None,\n host0_address=None,\n host0_port=None,\n process_id=None,\n n_local_devices=None,\n run_tensorboard=False,\n wandb_init_kwargs=None):\n \"\"\" Initializes a Deployer.\n\n Args:\n jax_seed (int): Seed for random number generation.\n n_model_shards (int): Number of shards for running large model.\n verbose (bool): Whether to enable verbose logging.\n workdir (str): Directory for saving logs and checkpoints.\n n_processes (int): For multi-host, number of processes/nodes.\n host0_address (str): For multi-host, address of the host0.\n host0_port (int): For multi-host, port of the host0.\n process_id (int): For multi-host, index of the current process.\n n_local_devices (int): For multi-host, number of local devices.\n run_tensorboard (bool): Whether to enable TensorBoard logging.\n wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n \"\"\"\n if n_processes is None:\n if 'SLURM_JOB_NUM_NODES' in os.environ:\n n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n process_id = int(os.environ['SLURM_NODEID'])\n else:\n n_processes = 1\n\n if n_processes > 1:\n local_device_ids = None if n_local_devices is None \\\n else list(range(n_local_devices))\n\n if host0_port is None:\n host0_port = DEFAULT_HOST0_PORT\n\n jax.distributed.initialize(\n coordinator_address=f'{host0_address}:{host0_port}',\n num_processes=n_processes,\n process_id=process_id,\n local_device_ids=local_device_ids)\n\n if workdir is not None:\n os.makedirs(workdir, exist_ok=True)\n\n self._verbose = verbose\n self._workdir = workdir\n self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n if wandb_init_kwargs is not None and jax.process_index() == 0:\n import wandb\n wandb.init(**wandb_init_kwargs)\n self._wandb_log_fn = wandb.log\n else:\n self._wandb_log_fn = None\n\n if run_tensorboard and jax.process_index() == 0:\n from flax.metrics import tensorboard\n self._summary_writer = tensorboard.SummaryWriter(workdir)\n else:\n self._summary_writer = None\n\n self.log_info(\n f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n self._rng = jax.random.PRNGKey(seed=jax_seed)\n self._mesh = get_mesh(n_model_shards=n_model_shards)\n self._checkpointer = ocp.PyTreeCheckpointer()\n
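Two construction sketches, with made-up addresses and directories: a single-host (or SLURM-launched) run, and an explicit multi-host run where the same script is started once per node with its own process_id.

from redco import Deployer   # top-level import assumed

# Single host (under SLURM, the node count and rank are read from the environment).
deployer = Deployer(jax_seed=0, n_model_shards=1, workdir='./workdir')

# Explicit multi-host setup: run once per node, each with its own process_id.
deployer = Deployer(
    jax_seed=0,
    n_model_shards=8,
    workdir='./workdir',
    n_processes=2,
    host0_address='10.0.0.1',   # made-up coordinator address
    host0_port=8888,            # made-up port; falls back to DEFAULT_HOST0_PORT if omitted
    process_id=0)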
"},{"location":"deployer/#redco.deployers.deployer.Deployer.gen_rng","title":"gen_rng()
","text":"Get a new random number generator key and update the random state.
Source code inredco/deployers/deployer.py
def gen_rng(self):\n \"\"\"Get a new random number generator key and update the random state.\"\"\"\n self._rng, new_rng = jax.random.split(self._rng)\n return new_rng\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_accumulate_grad_batches","title":"get_accumulate_grad_batches(global_batch_size, per_device_batch_size)
","text":"Calculates the number of gradient accumulation batches.
Source code inredco/deployers/deployer.py
def get_accumulate_grad_batches(\n self, global_batch_size, per_device_batch_size):\n \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n assert global_batch_size % global_micro_batch_size == 0\n accumulate_grad_batches = global_batch_size // global_micro_batch_size\n\n return accumulate_grad_batches\n
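A worked example of the arithmetic, assuming 8 data-parallel devices: per_device_batch_size=2 gives a global micro batch size of 16, so a global batch of 64 requires 4 accumulation steps.

accumulate_grad_batches = deployer.get_accumulate_grad_batches(
    global_batch_size=64, per_device_batch_size=2)
# With 8 data-parallel devices: 64 // (2 * 8) == 4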
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_local_global_micro_batch_size","title":"get_local_global_micro_batch_size(per_device_batch_size)
","text":"Get local/global micro batch sizes based on per-device batch size.
Source code inredco/deployers/deployer.py
def get_local_global_micro_batch_size(self, per_device_batch_size):\n \"\"\"Get local/global micro batch sizes based on per-device batch size.\"\"\"\n if self._mesh is None:\n local_micro_batch_size = \\\n per_device_batch_size * jax.local_device_count()\n global_micro_batch_size = \\\n local_micro_batch_size * jax.process_count()\n else:\n global_micro_batch_size = local_micro_batch_size = \\\n per_device_batch_size * self._mesh.shape['dp']\n\n return local_micro_batch_size, global_micro_batch_size\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_lr_schedule_fn","title":"get_lr_schedule_fn(train_size, per_device_batch_size, n_epochs, learning_rate, schedule_type='linear', warmup_ratio=0.0, warmup_steps=None, init_learning_rate=0.0, end_learning_rate=0.0)
","text":"Creates a learning rate schedule function.
Parameters:
Name Type Description Defaulttrain_size
int
Number of training examples per epoch.
requiredper_device_batch_size
int
Batch size per device.
requiredn_epochs
int
Number of epochs.
requiredlearning_rate
float
Peak learning rate.
requiredschedule_type
str
Type of lr schedule, \"linear\" or \"cosine\".
'linear'
warmup_ratio
float
Ratio of lr warmup.
0.0
warmup_steps
int
Number of warmup steps.
None
init_learning_rate
float
Initial learning rate before warmup.
0.0
end_learning_rate
float
End learning rate for the schedule.
0.0
Returns:
Type DescriptionCallable
An lr schedule function, step -> learning rate.
Source code inredco/deployers/deployer.py
def get_lr_schedule_fn(self,\n train_size,\n per_device_batch_size,\n n_epochs,\n learning_rate,\n schedule_type='linear',\n warmup_ratio=0.,\n warmup_steps=None,\n init_learning_rate=0.,\n end_learning_rate=0.):\n \"\"\"Creates a learning rate schedule function.\n\n Args:\n train_size (int): Number of training examples per epoch.\n per_device_batch_size (int): Batch size per device.\n n_epochs (int): Number of epochs.\n learning_rate (float): Peak learning rate.\n schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n warmup_ratio (float): Ratio of lr warmup.\n warmup_steps (int): Number of warmup steps.\n init_learning_rate (float): Initial learning rate before warmup.\n end_learning_rate (float): End learning rate for the schedule.\n\n Returns:\n (Callable): A lr schedule function, step -> learning rate.\n \"\"\"\n _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n if warmup_steps is None:\n warmup_steps = int(total_train_steps * warmup_ratio)\n\n return get_lr_schedule_fn(\n schedule_type=schedule_type,\n total_train_steps=total_train_steps,\n warmup_steps=warmup_steps,\n init_learning_rate=init_learning_rate,\n learning_rate=learning_rate,\n end_learning_rate=end_learning_rate)\n
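A usage sketch feeding the schedule into an optax optimizer (the same schedule can also be passed to the Trainer via its lr_schedule_fn argument); train_examples is a placeholder list:

import optax

lr_schedule_fn = deployer.get_lr_schedule_fn(
    train_size=len(train_examples),
    per_device_batch_size=8,
    n_epochs=3,
    learning_rate=2e-5,
    schedule_type='linear',
    warmup_ratio=0.1)
optimizer = optax.adamw(learning_rate=lr_schedule_fn)   # optax accepts a schedule function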
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_model_input_batches","title":"get_model_input_batches(examples, per_device_batch_size, collate_fn, shuffle, shuffle_rng, desc, is_train=False, accumulate_grad_batches=None)
","text":"Prepares model input batches from examples.
Parameters:
Name Type Description Defaultexamples
list
List of input examples.
requiredper_device_batch_size
int
Batch size per device.
requiredcollate_fn
Callable
Function to collate the examples.
requiredshuffle
bool
Whether to shuffle the examples.
requiredshuffle_rng
`jax.numpy.Array`
RNG for randomness of shuffling.
requireddesc
str
Description in the progress bar.
requiredis_train
bool
Whether the data is for training.
False
accumulate_grad_batches
int
Number of gradient accumulation batches.
None
Returns:
Type Descriptiongenerator
A Python generator of batched model inputs.
Source code inredco/deployers/deployer.py
def get_model_input_batches(self,\n examples,\n per_device_batch_size,\n collate_fn,\n shuffle,\n shuffle_rng,\n desc,\n is_train=False,\n accumulate_grad_batches=None):\n \"\"\"Prepares model input batches from examples.\n\n Args:\n examples (list): List of input examples.\n per_device_batch_size (int): Batch size per device.\n collate_fn (Callable): Function to collate the examples.\n shuffle (bool): Whether to shuffle the examples.\n shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n desc (str): Description in the progress bar.\n is_train (bool): Whether the data is for training.\n accumulate_grad_batches (int): gradient accumulation batches.\n\n Returns:\n (generator): A python generator of batched model inputs.\n \"\"\"\n local_micro_batch_size, global_micro_batch_size = \\\n self.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n\n examples = get_host_examples(\n examples=examples,\n global_micro_batch_size=global_micro_batch_size,\n shuffle=shuffle,\n shuffle_rng=shuffle_rng,\n mesh=self._mesh)\n\n if not is_train:\n desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n elif accumulate_grad_batches is None:\n desc = \\\n f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n else:\n desc = (f'{desc} ('\n f'global_micro_batch_size = {global_micro_batch_size}, '\n f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n return get_data_batches(\n examples=examples,\n batch_size=local_micro_batch_size,\n collate_fn=collate_fn,\n mesh=self._mesh,\n desc=desc,\n verbose=self._verbose)\n
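A usage sketch (examples and collate_fn are placeholders); this is the same entry point the Trainer uses internally to batch its inputs:

data_batches = deployer.get_model_input_batches(
    examples=examples,
    per_device_batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    shuffle_rng=deployer.gen_rng(),
    desc='Training',
    is_train=True)
for batch in data_batches:
    ...  # each item is a collated batch for the current host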
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_opt_state_spec","title":"get_opt_state_spec(params_shape_or_params, params_spec, optimizer)
","text":"Get optimizer state specs
Source code inredco/deployers/deployer.py
def get_opt_state_spec(\n self, params_shape_or_params, params_spec, optimizer):\n \"\"\"Get optimizer state specs\"\"\"\n return get_opt_state_spec(\n params_shape_or_params=params_shape_or_params,\n params_spec=params_spec,\n optimizer=optimizer)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_params_spec","title":"get_params_spec(params_shape_or_params, params_sharding_rules)
","text":"Generates parameter specs based on sharding rules.
Source code inredco/deployers/deployer.py
def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n return get_params_spec(\n params_shape_or_params=params_shape_or_params,\n params_sharding_rules=params_sharding_rules)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.get_sharding_rules","title":"get_sharding_rules(params_shape_or_params)
","text":"Get sharding rules based on the parameter shapes.
Source code inredco/deployers/deployer.py
def get_sharding_rules(self, params_shape_or_params):\n \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n if self._mesh is None:\n return None\n else:\n sharding_rules = get_sharding_rules(\n params_shape_or_params=params_shape_or_params,\n n_model_shards=self._mesh.shape['mp'])\n return sharding_rules\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.load_ckpt","title":"load_ckpt(ckpt_dir, params_sharding_rules=None, optimizer=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=False)
","text":"Loads a checkpoint from the specified directory.
Parameters:
Name Type Description Defaultckpt_dir
str
Directory of the checkpoint.
requiredparams_sharding_rules
list[tuple]
Sharding rules for parameters.
None
optimizer
optax optimizer
Optimizer for loading opt_state.
None
float_dtype
`jax.numpy.dtype`
Dtype for floating point numbers.
None
load_params
bool
Whether to load the parameters.
True
load_opt_state
bool
Whether to load the optimizer state.
True
update_rng
bool
Whether to update the random state of the deployer.
False
Returns:
Type Descriptiontuple
A tuple with the loaded checkpoint (in a dict with \"params\"
and \"opt_state\"
) and additional information (in a dict, usually including \"steps\"
, \"epoch_idx\"
, and \"rng\"
).
redco/deployers/deployer.py
def load_ckpt(self,\n ckpt_dir,\n params_sharding_rules=None,\n optimizer=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=False):\n \"\"\"Loads a checkpoint from the specified directory.\n\n Args:\n ckpt_dir (str): Directory of the checkpoint.\n params_sharding_rules (list[tuple]): Sharding rules for parameters.\n optimizer (optax optimizer): Optimizer for loading opt_state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n load_params (bool): Whether to load the parameters.\n load_opt_state (bool): Whether to load the optimizer state.\n update_rng (bool): if updating the random state of the deployer.\n\n Returns:\n (tuple): A tuple with the loaded checkpoint (in a dict with\n `\"params\"` and `\"opt_state\"`) and additional information (in a\n dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n specs = {}\n if self._mesh is not None:\n if params_sharding_rules is None:\n params_sharding_rules = self.get_sharding_rules(\n params_shape_or_params=params_shape)\n\n specs['params'] = self.get_params_spec(\n params_shape_or_params=params_shape,\n params_sharding_rules=params_sharding_rules)\n if optimizer is not None:\n specs['opt_state'] = self.get_opt_state_spec(\n params_shape_or_params=params_shape,\n params_spec=specs['params'],\n optimizer=optimizer)\n\n ckpt, info = load_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params_shape_or_params=params_shape,\n optimizer=optimizer,\n float_dtype=float_dtype,\n mesh=self._mesh,\n specs=specs,\n load_params=load_params,\n load_opt_state=load_opt_state)\n\n for key, value in info.items():\n if not update_rng and key == 'rng':\n continue\n self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n if update_rng:\n self._rng = info['rng']\n self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n return ckpt, info\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.load_last_ckpt","title":"load_last_ckpt(optimizer=None, params_sharding_rules=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=True)
","text":"Loads the last checkpoint from the work directory (self.workdir). See load_ckpt() for the explanation of arguments.
Source code inredco/deployers/deployer.py
def load_last_ckpt(self,\n optimizer=None,\n params_sharding_rules=None,\n float_dtype=None,\n load_params=True,\n load_opt_state=True,\n update_rng=True):\n \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n See load_ckpt() for the explanation of arguments.\n \"\"\"\n try:\n last_ckpt_name = open(\n f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n except:\n self.log_info(\n f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n f'no ckpt loaded.')\n return None, None\n\n return self.load_ckpt(\n ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n optimizer=optimizer,\n float_dtype=float_dtype,\n params_sharding_rules=params_sharding_rules,\n load_params=load_params,\n load_opt_state=load_opt_state,\n update_rng=update_rng)\n
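A resume-from-checkpoint sketch combining load_last_ckpt() with the Trainer (placeholders: model, collate_fn, loss_fn, optimizer, init_params):

ckpt, last_ckpt_info = deployer.load_last_ckpt(optimizer=optimizer)
if ckpt is None:
    params, opt_state = init_params, None   # no previous checkpoint found
else:
    params, opt_state = ckpt['params'], ckpt['opt_state']

trainer = Trainer(
    deployer=deployer,
    collate_fn=collate_fn,
    apply_fn=model,
    loss_fn=loss_fn,
    params=params,
    optimizer=optimizer,
    opt_state=opt_state,
    last_ckpt_info=last_ckpt_info)   # restores the beginning step / epoch index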
"},{"location":"deployer/#redco.deployers.deployer.Deployer.load_params_shape","title":"load_params_shape(ckpt_dir)
","text":"Loads the shape of the parameters from a checkpoint.
Source code in redco/deployers/deployer.py
def load_params_shape(self, ckpt_dir):\n \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n return load_params_shape(ckpt_dir=ckpt_dir)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.log_info","title":"log_info(info, title=None, step=None)
","text":"Logs a messages
Source code in redco/deployers/deployer.py
def log_info(self, info, title=None, step=None):\n \"\"\"Logs a messages\"\"\"\n log_info(\n info=info,\n title=title,\n logger=self._logger,\n summary_writer=self._summary_writer,\n step=step)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.log_metrics","title":"log_metrics(metrics, step)
","text":"Logs metrics to TensorBoard and Weights and Biases (wandb).
Source code in redco/deployers/deployer.py
def log_metrics(self, metrics, step):\n \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n if self._summary_writer is not None:\n for metric_name, value in metrics.items():\n self._summary_writer.scalar(metric_name, value, step=step)\n\n if self._wandb_log_fn is not None:\n self._wandb_log_fn(metrics, step)\n
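A small usage sketch, assuming the Deployer was constructed with run_tensorboard=True (and/or wandb_init_kwargs) so the metrics are actually written somewhere; the metric names and values are illustrative.
from redco import Deployer

deployer = Deployer(jax_seed=0, workdir='./workdir', run_tensorboard=True)
# Scalar metrics go to TensorBoard under workdir (and to wandb, if configured).
deployer.log_metrics(metrics={'train_loss': 0.31, 'lr': 1e-3}, step=100)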
"},{"location":"deployer/#redco.deployers.deployer.Deployer.run_model_step","title":"run_model_step(step_fn, input_args)
","text":"Executes a model step function with the provided inputs.
Source code in redco/deployers/deployer.py
def run_model_step(self, step_fn, input_args):\n \"\"\"Executes a model step function with the provided inputs.\"\"\"\n if self._mesh is None:\n return step_fn(*input_args)\n else:\n with self._mesh:\n return step_fn(*input_args)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.save_ckpt","title":"save_ckpt(ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs)
","text":"Saves a checkpoint to the specified directory.
Parameters:
Name Type Description Defaultckpt_dir
str
Directory to save the checkpoint.
requiredparams
dict
Model parameters.
requiredopt_state
dict
Optimizer state.
None
float_dtype
`jax.numpy.dtype`
Dtype for floating point numbers.
None
**kwargs
dict
Additional information to be saved into info.json, e.g., current training step, epoch index, etc.
{}
Source code in redco/deployers/deployer.py
def save_ckpt(\n self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n \"\"\"Saves a checkpoint to the specified directory.\n\n Args:\n ckpt_dir (str): Directory to save the checkpoint.\n params (dict): Model parameters.\n opt_state (dict): Optimizer state.\n float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n **kwargs (dict): Additional information to be saved into\n info.json, e.g., current training step, epoch index, etc.\n \"\"\"\n ckpt_dir = os.path.abspath(ckpt_dir)\n self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n save_ckpt(\n ckpt_dir=ckpt_dir,\n checkpointer=self._checkpointer,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n rng=self._rng,\n **kwargs)\n self.log_info(f'Ckpt saved into {ckpt_dir}')\n
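A usage sketch of saving a checkpoint directly through the Deployer; the tiny Dense module only provides some parameters to save, and the directory name and extra keyword arguments are illustrative.
import jax.numpy as jnp
from flax import linen as nn
from redco import Deployer

deployer = Deployer(jax_seed=42, workdir='./workdir')
params = nn.Dense(features=4).init(deployer.gen_rng(), jnp.ones((1, 8)))['params']

deployer.save_ckpt(
    ckpt_dir='./workdir/ckpts/manual',  # hypothetical directory
    params=params,
    float_dtype=jnp.bfloat16,           # cast floating-point params to bf16 on disk
    step=1000,                          # extra kwargs are written into info.json
    epoch_idx=3)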
"},{"location":"deployer/#redco.deployers.deployer.Deployer.save_outputs","title":"save_outputs(outputs, desc, step)
","text":"Saves model outputs to workdir.
Source code in redco/deployers/deployer.py
def save_outputs(self, outputs, desc, step):\n \"\"\"Saves model outputs to workdir.\"\"\"\n if self._workdir is not None and jax.process_index() == 0:\n save_outputs(\n workdir=self._workdir,\n outputs=outputs,\n desc=desc,\n step=step,\n logger=self._logger,\n summary_writer=self._summary_writer)\n
"},{"location":"deployer/#redco.deployers.deployer.Deployer.shard_params","title":"shard_params(params, params_spec, desc='params')
","text":"Distributes parameters to all devices based on the provided specs.
Source code in redco/deployers/deployer.py
def shard_params(self, params, params_spec, desc='params'):\n \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n self.log_info(info=f'Sharding {desc} ...')\n return shard_params(\n mesh=self._mesh, params=params, params_spec=params_spec)\n
"},{"location":"mnist/","title":"MNIST Example","text":"This is a trivial MNIST example with RedCoast. Runnable by
python main.py\n
To simulate multiple devices in a CPU-only environment, run
XLA_FLAGS=\"--xla_force_host_platform_device_count=8\" python main.py\n
"},{"location":"mnist/#source-code","title":"Source Code","text":"from functools import partial\nimport fire\nimport numpy as np\nfrom flax import linen as nn\nimport optax\nfrom torchvision.datasets import MNIST\nfrom redco import Deployer, Trainer, Predictor\n\n\n# A simple CNN model \n# Copied from https://github.com/google/flax/blob/main/examples/mnist/train.py\nclass CNN(nn.Module):\n @nn.compact\n def __call__(self, x):\n x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n x = nn.relu(x)\n x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n x = nn.relu(x)\n x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n x = x.reshape((x.shape[0], -1)) # flatten\n x = nn.Dense(features=256)(x)\n x = nn.relu(x)\n x = nn.Dense(features=10)(x)\n return x\n\n\n# Collate function converting a batch of raw examples to model inputs (in numpy) \ndef collate_fn(examples):\n images = np.stack(\n [np.array(example['image'])[:, :, None] for example in examples])\n labels = np.array([example['label'] for example in examples])\n\n return {'images': images, 'labels': labels}\n\n\n# Loss function converting model inputs to a scalar loss\ndef loss_fn(train_rng, state, params, batch, is_training):\n logits = state.apply_fn({'params': params}, batch['images'])\n return optax.softmax_cross_entropy_with_integer_labels(\n logits=logits, labels=batch['labels']).mean()\n\n\n# Predict function converting model inputs to the model outputs\ndef pred_fn(pred_rng, params, batch, model):\n accs = model.apply({'params': params}, batch['images']).argmax(axis=-1)\n return {'acc': accs}\n\n\n# (Optional) Evaluation function in trainer.fit. Here it computes accuracy.\ndef eval_metric_fn(examples, preds):\n preds = np.array([pred['acc'] for pred in preds])\n labels = np.array([example['label'] for example in examples])\n return {'acc': np.mean(preds == labels).item()}\n\n\ndef main(per_device_batch_size=64, learning_rate=1e-3, jax_seed=42):\n deployer = Deployer(jax_seed=jax_seed, workdir='./workdir')\n\n dataset = {\n 'train': [{'image': t[0], 'label': t[1]} for t in list(\n MNIST('./data', train=True, download=True))],\n 'test': [{'image': t[0], 'label': t[1]} for t in list(\n MNIST('./data', train=False, download=True))],\n }\n\n model = CNN()\n dummy_batch = collate_fn(examples=[dataset['train'][0]])\n params = model.init(deployer.gen_rng(), dummy_batch['images'])['params']\n\n trainer = Trainer(\n deployer=deployer,\n collate_fn=collate_fn,\n apply_fn=model.apply,\n loss_fn=loss_fn,\n params=params,\n optimizer=optax.adamw(learning_rate=learning_rate))\n\n predictor = Predictor(\n deployer=deployer,\n collate_fn=collate_fn,\n pred_fn=partial(pred_fn, model=model))\n\n trainer.fit(\n train_examples=dataset['train'],\n per_device_batch_size=per_device_batch_size,\n n_epochs=2,\n eval_examples=dataset['test'],\n eval_predictor=predictor,\n eval_metric_fn=eval_metric_fn)\n\n\nif __name__ == '__main__':\n fire.Fire(main)\n
"},{"location":"predictor/","title":"Predictor","text":""},{"location":"predictor/#redco.predictors.predictor.Predictor","title":"Predictor
","text":"Predictor class managing distributed inference process.
Attributes:
Name Type Descriptionmesh
jax Mesh
Mesh used for distributed inference.
Source code in redco/predictors/predictor.py
class Predictor:\n \"\"\"Predictor class managing distributed inference process.\n\n Attributes:\n mesh (jax Mesh): Mesh used for distributed inference.\n \"\"\"\n def __init__(self,\n deployer,\n collate_fn,\n pred_fn,\n output_fn=None,\n params_sharding_rules=None):\n \"\"\"Initializes a Predictor instance.\n\n Args:\n deployer (Deployer): A deployer for low-level operations.\n collate_fn (Callable): A function making model inputs from raw data,\n e.g., tokenizing sentences into input_ids.\n pred_fn (Callable): A function producing model outputs from inputs,\n e.g., running beam search with a language model.\n output_fn (Callable): A function post-processing model outputs,\n e.g., decoding generated ids to text.\n params_sharding_rules (list[tuple]): Rules for sharding parameters.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n self._params_sharding_rules = params_sharding_rules\n self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n self._p_pred_step = None\n\n if output_fn is None:\n self._output_fn = default_output_fn\n else:\n self._output_fn = output_fn\n\n def setup_running_step(self, dummy_batch, params_shape_or_params):\n \"\"\"Sets up the prediction step function for distributed inference.\n\n Args:\n dummy_batch (PyTree): A dummy batch used to determine data shapes.\n params_shape_or_params (dict): The shape of params or actual params.\n \"\"\"\n pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n if self.mesh is None:\n self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params_shape_or_params,\n params_sharding_rules=self._params_sharding_rules)\n self._p_pred_step = pjit(\n pred_step_fn,\n in_shardings=(None, params_spec, data_spec),\n out_shardings=None)\n\n def predict(self,\n examples,\n per_device_batch_size,\n params,\n params_replicated=False,\n params_sharded=False,\n desc=None):\n \"\"\"Runs distributed prediction on a list of examples.\n\n Args:\n examples (list): Input examples for prediction.\n per_device_batch_size (int): Batch size per device.\n params (dict): Model parameters in a dict/FrozenDict.\n params_replicated (bool): if the params are already replicated.\n params_sharded (bool): if the parameters are already sharded.\n desc (str): Description to show in the progress bar.\n\n Returns:\n (list): A list of predictions corresponding to the input examples.\n \"\"\"\n raw_n_inputs = len(examples)\n _, global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n examples = examples + [examples[0]] * (global_micro_batch_size - 1)\n examples = add_idxes(examples=examples)\n\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n params = freeze(params)\n if (self.mesh is None) and (not params_replicated):\n params = replicate(params)\n if (self.mesh is not None) and (not params_sharded):\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n preds = []\n for batch in data_batches:\n if 
self._p_pred_step is None:\n self.setup_running_step(\n dummy_batch=batch, params_shape_or_params=params)\n\n pred_rng = self._deployer.gen_rng()\n if self.mesh is None:\n pred_rng = jax.random.split(\n pred_rng, num=jax.process_count())[jax.process_index()]\n pred_rng = shard_prng_key(pred_rng)\n\n batch_preds_with_idxes = self._deployer.run_model_step(\n step_fn=self._p_pred_step,\n input_args=(pred_rng, params, batch))\n batch_preds = process_batch_preds(\n batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n batch_preds = self._output_fn(batch_preds)\n\n assert isinstance(batch_preds, list) and \\\n len(batch_preds) == global_micro_batch_size\n preds.extend(batch_preds)\n\n return preds[:raw_n_inputs]\n\n @property\n def mesh(self):\n \"\"\"Returns the mesh used for distributed inference.\"\"\"\n return self._deployer.mesh\n
"},{"location":"predictor/#redco.predictors.predictor.Predictor.mesh","title":"mesh
property
","text":"Returns the mesh used for distributed inference.
"},{"location":"predictor/#redco.predictors.predictor.Predictor.__init__","title":"__init__(deployer, collate_fn, pred_fn, output_fn=None, params_sharding_rules=None)
","text":"Initializes a Predictor instance.
Parameters:
Name Type Description Defaultdeployer
Deployer
A deployer for low-level operations.
requiredcollate_fn
Callable
A function making model inputs from raw data, e.g., tokenizing sentences into input_ids.
requiredpred_fn
Callable
A function producing model outputs from inputs, e.g., running beam search with a language model.
requiredoutput_fn
Callable
A function post-processing model outputs, e.g., decoding generated ids to text.
None
params_sharding_rules
list[tuple]
Rules for sharding parameters.
None
Source code in redco/predictors/predictor.py
def __init__(self,\n deployer,\n collate_fn,\n pred_fn,\n output_fn=None,\n params_sharding_rules=None):\n \"\"\"Initializes a Predictor instance.\n\n Args:\n deployer (Deployer): A deployer for low-level operations.\n collate_fn (Callable): A function making model inputs from raw data,\n e.g., tokenizing sentences into input_ids.\n pred_fn (Callable): A function producing model outputs from inputs,\n e.g., running beam search with a language model.\n output_fn (Callable): A function post-processing model outputs,\n e.g., decoding generated ids to text.\n params_sharding_rules (list[tuple]): Rules for sharding parameters.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n self._params_sharding_rules = params_sharding_rules\n self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n self._p_pred_step = None\n\n if output_fn is None:\n self._output_fn = default_output_fn\n else:\n self._output_fn = output_fn\n
"},{"location":"predictor/#redco.predictors.predictor.Predictor.predict","title":"predict(examples, per_device_batch_size, params, params_replicated=False, params_sharded=False, desc=None)
","text":"Runs distributed prediction on a list of examples.
Parameters:
Name Type Description Defaultexamples
list
Input examples for prediction.
requiredper_device_batch_size
int
Batch size per device.
requiredparams
dict
Model parameters in a dict/FrozenDict.
requiredparams_replicated
bool
Whether the params are already replicated.
False
params_sharded
bool
Whether the parameters are already sharded.
False
desc
str
Description to show in the progress bar.
None
Returns:
Type Descriptionlist
A list of predictions corresponding to the input examples.
Source code in redco/predictors/predictor.py
def predict(self,\n examples,\n per_device_batch_size,\n params,\n params_replicated=False,\n params_sharded=False,\n desc=None):\n \"\"\"Runs distributed prediction on a list of examples.\n\n Args:\n examples (list): Input examples for prediction.\n per_device_batch_size (int): Batch size per device.\n params (dict): Model parameters in a dict/FrozenDict.\n params_replicated (bool): if the params are already replicated.\n params_sharded (bool): if the parameters are already sharded.\n desc (str): Description to show in the progress bar.\n\n Returns:\n (list): A list of predictions corresponding to the input examples.\n \"\"\"\n raw_n_inputs = len(examples)\n _, global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=per_device_batch_size)\n examples = examples + [examples[0]] * (global_micro_batch_size - 1)\n examples = add_idxes(examples=examples)\n\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n params = freeze(params)\n if (self.mesh is None) and (not params_replicated):\n params = replicate(params)\n if (self.mesh is not None) and (not params_sharded):\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n preds = []\n for batch in data_batches:\n if self._p_pred_step is None:\n self.setup_running_step(\n dummy_batch=batch, params_shape_or_params=params)\n\n pred_rng = self._deployer.gen_rng()\n if self.mesh is None:\n pred_rng = jax.random.split(\n pred_rng, num=jax.process_count())[jax.process_index()]\n pred_rng = shard_prng_key(pred_rng)\n\n batch_preds_with_idxes = self._deployer.run_model_step(\n step_fn=self._p_pred_step,\n input_args=(pred_rng, params, batch))\n batch_preds = process_batch_preds(\n batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n batch_preds = self._output_fn(batch_preds)\n\n assert isinstance(batch_preds, list) and \\\n len(batch_preds) == global_micro_batch_size\n preds.extend(batch_preds)\n\n return preds[:raw_n_inputs]\n
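A usage sketch of predict(), reusing CNN, collate_fn, and pred_fn from the MNIST example in these docs (assumed available here); test_examples stands for a list of {'image': ..., 'label': ...} dicts.
from functools import partial
from redco import Deployer, Predictor

deployer = Deployer(jax_seed=42)
model = CNN()  # assumed from the MNIST example
dummy_batch = collate_fn(examples=[test_examples[0]])
params = model.init(deployer.gen_rng(), dummy_batch['images'])['params']

predictor = Predictor(
    deployer=deployer,
    collate_fn=collate_fn,
    pred_fn=partial(pred_fn, model=model))

# Returns one prediction per input example, in the original order.
preds = predictor.predict(
    examples=test_examples, per_device_batch_size=64, params=params)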
"},{"location":"predictor/#redco.predictors.predictor.Predictor.setup_running_step","title":"setup_running_step(dummy_batch, params_shape_or_params)
","text":"Sets up the prediction step function for distributed inference.
Parameters:
Name Type Description Defaultdummy_batch
PyTree
A dummy batch used to determine data shapes.
requiredparams_shape_or_params
dict
The shape of params or actual params.
required Source code in redco/predictors/predictor.py
def setup_running_step(self, dummy_batch, params_shape_or_params):\n \"\"\"Sets up the prediction step function for distributed inference.\n\n Args:\n dummy_batch (PyTree): A dummy batch used to determine data shapes.\n params_shape_or_params (dict): The shape of params or actual params.\n \"\"\"\n pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n if self.mesh is None:\n self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params_shape_or_params,\n params_sharding_rules=self._params_sharding_rules)\n self._p_pred_step = pjit(\n pred_step_fn,\n in_shardings=(None, params_spec, data_spec),\n out_shardings=None)\n
"},{"location":"trainer/","title":"Trainer","text":""},{"location":"trainer/#redco.trainers.trainer.Trainer","title":"Trainer
","text":"Trainer class managing distributed training process.
Attributes:
Name Type Descriptionstep
int
Current training step.
workdir
str
Working directory for saving checkpoints and logs.
mesh
jax Mesh
Mesh used for distributed training.
state
flax TrainState
Current training state.
Source code in redco/trainers/trainer.py
class Trainer:\n \"\"\"Trainer class managing distributed training process.\n\n Attributes:\n step (int): Current training step.\n workdir (str): Working directory for saving checkpoints and logs.\n mesh (jax Mesh): Mesh used for distributed training.\n state (flax TrainState): Current training state.\n \"\"\"\n def __init__(self,\n deployer,\n collate_fn,\n apply_fn,\n loss_fn,\n params,\n optimizer,\n opt_state=None,\n compute_dtype=jnp.float32,\n last_ckpt_info=None,\n lr_schedule_fn=None,\n accumulate_grad_batches=None,\n params_sharding_rules=None,\n train_step_fn=None):\n \"\"\"Initializes the Trainer with initial parameters, etc.\n\n Args:\n deployer (Deployer): A deployer supporting low-level operations.\n collate_fn (Callable): The function converting a data batch to model\n inputs, e.g., tokenizing sentences into input_ids.\n apply_fn (Callable): The function to apply the model, such as\n model.apply for Flax modules, or model itself for HuggingFace\n models. It would be set as state.apply_fn, and used in loss_fn.\n loss_fn (Callable): The loss function converting model inputs to a\n scalar loss, e.g., computing cross-entropy loss from input_ids.\n params (dict): Initial model parameters.\n optimizer (optax optimizer): The optimizer used for training.\n opt_state (dict): optimizer state.\n compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n independent of param dtypes. (for mixed-precision training)\n last_ckpt_info (dict): the beginning step and epoch.\n lr_schedule_fn (Callable): The learning rate schedule\n function converting step to learning rate.\n accumulate_grad_batches (int): Gradient accumulation step.\n params_sharding_rules (list): Sharding rules.\n train_step_fn (Callable): For fully customizing every training step,\n e.g., per-sample gradient noising for data-private training.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = collate_fn\n self._apply_fn = apply_fn\n self._loss_fn = loss_fn\n self._optimizer = optimizer\n self._compute_dtype = compute_dtype\n self._lr_schedule_fn = lr_schedule_fn\n self._accumulate_grad_batches = accumulate_grad_batches\n self._params_sharding_rules = params_sharding_rules\n self._train_step_fn = train_step_fn\n\n self._state = None\n self._state_spec = None\n self._p_train_step = None\n self._p_eval_step = None\n\n self._init_step = 0\n self._init_epoch_idx = 0\n if last_ckpt_info is not None:\n self._init_step = last_ckpt_info.get('step', 0)\n self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n n_params = sum([param.size for param in jax.tree.leaves(params)])\n self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n self.set_train_state(\n apply_fn=self._apply_fn,\n params=params,\n optimizer=self._optimizer,\n step=self._init_step,\n opt_state=opt_state)\n\n def set_train_state(\n self, apply_fn, params, optimizer, step, opt_state=None):\n \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n Args:\n apply_fn (Callable): The function to apply the model.\n params (dict): Model parameters.\n optimizer (dict): The optimizer used for training.\n step (int): The training step.\n opt_state (dict): The state of the optimizer.\n \"\"\"\n self._deployer.log_info('Setting train_state ...')\n params = freeze(params)\n\n if self.mesh is None:\n params = jax.device_put(params, jax.local_devices()[0])\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n else:\n opt_state = jax.device_put(opt_state, 
jax.local_devices()[0])\n\n self._state = train_state.TrainState(\n step=step,\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state)\n self._state = replicate(self._state)\n else:\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n\n opt_state_spec = self._deployer.get_opt_state_spec(\n params_shape_or_params=params,\n params_spec=params_spec,\n optimizer=optimizer)\n opt_state = self._deployer.shard_params(\n params=opt_state,\n params_spec=opt_state_spec,\n desc='opt_state')\n\n self._state = train_state.TrainState(\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state,\n step=step)\n\n self._state_spec = train_state.TrainState(\n apply_fn=apply_fn,\n params=params_spec,\n tx=optimizer,\n opt_state=opt_state_spec,\n step=None)\n\n def setup_running_step(self, dummy_batch):\n \"\"\"Sets up the running step functions for training and evaluation.\n\n Args:\n dummy_batch (PyTree): A dummy batch of data.\n \"\"\"\n train_step_fn = partial(\n self._train_step_fn or default_train_step,\n loss_fn=self._loss_fn,\n lr_schedule_fn=self._lr_schedule_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n eval_step_fn = partial(\n eval_step,\n loss_fn=self._loss_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n\n if self.mesh is None:\n self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n self._p_train_step = pjit(\n train_step_fn,\n in_shardings=(None, self._state_spec, data_spec),\n out_shardings=(self._state_spec, None),\n donate_argnums=(1, ))\n self._p_eval_step = pjit(\n eval_step_fn,\n in_shardings=(self._state_spec, data_spec),\n out_shardings=None)\n\n def train(self, examples, per_device_batch_size, desc=None):\n \"\"\"Trains the model on the provided examples.\n\n Args:\n examples (list): Training examples in python list.\n per_device_batch_size (int): The batch size per device.\n desc (str): Description in the progress bar.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=True,\n shuffle_rng=self._deployer.gen_rng(),\n desc=f'Training ({desc})' if desc is not None else 'Training',\n is_train=True,\n accumulate_grad_batches=self._accumulate_grad_batches)\n\n for batch in data_batches:\n if self._p_train_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n train_rng = self._deployer.gen_rng()\n if self.mesh is None:\n train_rng = jax.random.split(\n train_rng, num=jax.process_count())[jax.process_index()]\n train_rng = shard_prng_key(train_rng)\n self._state, metrics = self._deployer.run_model_step(\n step_fn=self._p_train_step,\n input_args=(train_rng, self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n data_batches.set_postfix(**metrics)\n self._deployer.log_metrics(metrics=metrics, step=self.step)\n\n def eval_loss(self, examples, per_device_batch_size, desc=None):\n \"\"\"Evaluates the loss on the provided examples.\n\n Args:\n examples (list): Evaluation examples in list.\n per_device_batch_size (int): The batch size per device.\n desc (str): 
Description in the progress bar.\n\n Returns:\n (float): The average loss over the evaluation examples.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n losses = []\n for batch in data_batches:\n if self._p_eval_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n metrics = self._deployer.run_model_step(\n step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n\n losses.append(metrics['loss'].item())\n data_batches.set_postfix(**metrics)\n\n return np.mean(losses).item()\n\n def fit(self,\n train_examples,\n per_device_batch_size,\n n_epochs,\n eval_examples=None,\n eval_per_device_batch_size=None,\n eval_loss=True,\n eval_predictor=None,\n eval_metric_fn=None,\n eval_sanity_check=True,\n save_every_ckpt=False,\n save_last_ckpt=False,\n save_argmin_ckpt_by_metrics=None,\n save_argmax_ckpt_by_metrics=None,\n save_opt_states=True,\n save_float_dtype=None):\n \"\"\"Fits the model on the training data for a given number of epochs,\n optionally evaluating and saving checkpoints.\n\n Args:\n train_examples (list or Callable): Training examples, can be a\n list or a function of epoch_idx (for assigning different\n examples in separate epochs/chunks),\n e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n per_device_batch_size (int): The batch size per device.\n n_epochs (int): Number of epochs to train.\n eval_examples (list): Examples for evaluation and prediction.\n eval_per_device_batch_size (int): Batch size for evaluation\n eval_loss (bool): Whether to evaluate loss.\n eval_predictor (Predictor): Predictor working on `eval_examples`.\n eval_metric_fn (Callable): Metric function for prediction.\n eval_sanity_check (bool): if to run a sanity check for\n evaluation & predict functions before training.\n save_every_ckpt (bool): if to save a ckpt after every epoch.\n save_last_ckpt (bool): Whether to save the last checkpoint.\n save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on minimum values.\n save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on maximum values.\n save_opt_states (bool): of to save optimizer states in ckpts.\n save_float_dtype (bool): The data type for saving checkpoints.\n \"\"\"\n if eval_per_device_batch_size is None:\n eval_per_device_batch_size = per_device_batch_size\n\n if save_argmax_ckpt_by_metrics is None:\n save_argmax_ckpt_by_metrics = []\n if save_argmin_ckpt_by_metrics is None:\n save_argmin_ckpt_by_metrics = []\n min_metrics, max_metrics = {}, {}\n\n if os.path.exists(f'{self.workdir}/min_metrics.json'):\n min_metrics = json.load(open(\n f'{self.workdir}/min_metrics.json'))\n self._deployer.log_info(\n json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n if os.path.exists(f'{self.workdir}/max_metrics.json'):\n max_metrics = json.load(open(\n f'{self.workdir}/max_metrics.json'))\n self._deployer.log_info(\n json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n if eval_sanity_check and eval_examples is not None:\n rng_backup = self._deployer._rng\n _, eval_global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=eval_per_device_batch_size)\n\n if eval_loss:\n self.eval_loss(\n 
examples=eval_examples[:eval_global_micro_batch_size],\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for evaluation loss) passed.')\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples[:eval_global_micro_batch_size],\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for prediction) passed.')\n\n if eval_metric_fn is not None:\n json.dumps(eval_metric_fn(\n examples=eval_examples[:eval_global_micro_batch_size],\n preds=preds))\n self._deployer.log_info(\n 'Sanity check (for evaluation metrics) passed.')\n\n self._deployer._rng = rng_backup\n\n for epoch_idx in range(self._init_epoch_idx, n_epochs):\n if isinstance(train_examples, list):\n epoch_train_examples = train_examples\n else:\n epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n self.train(\n examples=epoch_train_examples,\n per_device_batch_size=per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n save_ckpt_kwargs = {\n 'epoch_idx': epoch_idx,\n 'save_opt_state': save_opt_states,\n 'float_dtype': save_float_dtype\n }\n\n if eval_examples is None:\n self._deployer.log_info(\n 'No evaluation cuz \\'eval_examples\\' is None.')\n else:\n eval_metrics = {}\n\n if eval_loss:\n loss = self.eval_loss(\n examples=eval_examples,\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n eval_metrics['loss'] = loss\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples,\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n if eval_metric_fn is not None:\n eval_metrics.update(eval_metric_fn(\n examples=eval_examples, preds=preds))\n\n eval_outputs = [\n {'example': example, 'pred': pred}\n for example, pred in zip(eval_examples, preds)]\n\n self._deployer.save_outputs(\n outputs=eval_outputs,\n desc=f'epoch{epoch_idx}',\n step=self.step)\n\n self._deployer.log_info(\n info=json.dumps(eval_metrics, indent=4),\n title=f'Eval results',\n step=self.step)\n self._deployer.log_metrics(metrics={\n f'eval_{key}': value\n for key, value in eval_metrics.items()\n }, step=self.step)\n\n if self.workdir is not None:\n result_filepath = \\\n f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n json.dump(\n eval_metrics, open(result_filepath, 'w'), indent=4)\n self._deployer.log_info(\n f'eval_results saved into {result_filepath}.')\n\n for key in save_argmin_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] < min_metrics.get(key, float('inf')):\n min_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'minimal {key} updated to {min_metrics[key]}')\n json.dump(min_metrics, open(\n f'{self.workdir}/min_metrics.json', 'w'))\n\n self.save_ckpt(\n ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n for key in save_argmax_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n max_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'maximal {key} updated to {max_metrics[key]}')\n json.dump(max_metrics, open(\n f'{self.workdir}/max_metrics.json', 
'w'))\n\n self.save_ckpt(\n ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n if save_every_ckpt:\n self.save_ckpt(\n ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n elif save_last_ckpt:\n self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n\n def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n Args:\n epoch_idx (int): The current epoch index.\n ckpt_name (str): The name of the checkpoint.\n save_opt_state (bool): Whether to save the optimizer state.\n float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n \"\"\"\n if self.mesh is None:\n params = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.params)\n else:\n params = self._state.params\n\n opt_state = None\n if save_opt_state:\n if self.mesh is None:\n opt_state = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.opt_state)\n else:\n opt_state = self._state.opt_state\n\n ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n self._deployer.save_ckpt(\n ckpt_dir=ckpt_dir,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n step=self.step,\n epoch_idx=epoch_idx)\n\n if jax.process_index() == 0:\n open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n\n @property\n def step(self):\n \"\"\"Returns the current training step.\"\"\"\n if self.mesh is None:\n return unreplicate(self._state.step).item()\n else:\n return self._state.step.item()\n\n @property\n def workdir(self):\n \"\"\"Returns the working directory for saving checkpoints and logs.\"\"\"\n return self._deployer.workdir\n\n @property\n def mesh(self):\n \"\"\"Returns the mesh used for distributed training.\"\"\"\n return self._deployer.mesh\n\n @property\n def state(self):\n \"\"\"Returns the current training state.\"\"\"\n return self._state\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.mesh","title":"mesh
property
","text":"Returns the mesh used for distributed training.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.state","title":"state
property
","text":"Returns the current training state.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.step","title":"step
property
","text":"Returns the current training step.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.workdir","title":"workdir
property
","text":"Returns the working directory for saving checkpoints and logs.
"},{"location":"trainer/#redco.trainers.trainer.Trainer.__init__","title":"__init__(deployer, collate_fn, apply_fn, loss_fn, params, optimizer, opt_state=None, compute_dtype=jnp.float32, last_ckpt_info=None, lr_schedule_fn=None, accumulate_grad_batches=None, params_sharding_rules=None, train_step_fn=None)
","text":"Initializes the Trainer with initial parameters, etc.
Parameters:
Name Type Description Defaultdeployer
Deployer
A deployer supporting low-level operations.
requiredcollate_fn
Callable
The function converting a data batch to model inputs, e.g., tokenizing sentences into input_ids.
requiredapply_fn
Callable
The function to apply the model, such as model.apply for Flax modules, or the model itself for HuggingFace models. It will be set as state.apply_fn and used in loss_fn.
requiredloss_fn
Callable
The loss function converting model inputs to a scalar loss, e.g., computing cross-entropy loss from input_ids.
requiredparams
dict
Initial model parameters.
requiredoptimizer
optax optimizer
The optimizer used for training.
requiredopt_state
dict
Optimizer state.
None
compute_dtype
dtype
Computation dtype, e.g., jnp.bfloat16, independent of param dtypes (for mixed-precision training).
float32
last_ckpt_info
dict
The beginning step and epoch index, e.g., loaded from the last checkpoint.
None
lr_schedule_fn
Callable
The learning rate schedule function converting step to learning rate.
None
accumulate_grad_batches
int
Number of gradient accumulation steps.
None
params_sharding_rules
list
Sharding rules.
None
train_step_fn
Callable
For fully customizing every training step, e.g., per-sample gradient noising for data-private training.
None
Source code in redco/trainers/trainer.py
def __init__(self,\n deployer,\n collate_fn,\n apply_fn,\n loss_fn,\n params,\n optimizer,\n opt_state=None,\n compute_dtype=jnp.float32,\n last_ckpt_info=None,\n lr_schedule_fn=None,\n accumulate_grad_batches=None,\n params_sharding_rules=None,\n train_step_fn=None):\n \"\"\"Initializes the Trainer with initial parameters, etc.\n\n Args:\n deployer (Deployer): A deployer supporting low-level operations.\n collate_fn (Callable): The function converting a data batch to model\n inputs, e.g., tokenizing sentences into input_ids.\n apply_fn (Callable): The function to apply the model, such as\n model.apply for Flax modules, or model itself for HuggingFace\n models. It would be set as state.apply_fn, and used in loss_fn.\n loss_fn (Callable): The loss function converting model inputs to a\n scalar loss, e.g., computing cross-entropy loss from input_ids.\n params (dict): Initial model parameters.\n optimizer (optax optimizer): The optimizer used for training.\n opt_state (dict): optimizer state.\n compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n independent of param dtypes. (for mixed-precision training)\n last_ckpt_info (dict): the beginning step and epoch.\n lr_schedule_fn (Callable): The learning rate schedule\n function converting step to learning rate.\n accumulate_grad_batches (int): Gradient accumulation step.\n params_sharding_rules (list): Sharding rules.\n train_step_fn (Callable): For fully customizing every training step,\n e.g., per-sample gradient noising for data-private training.\n \"\"\"\n self._deployer = deployer\n self._collate_fn = collate_fn\n self._apply_fn = apply_fn\n self._loss_fn = loss_fn\n self._optimizer = optimizer\n self._compute_dtype = compute_dtype\n self._lr_schedule_fn = lr_schedule_fn\n self._accumulate_grad_batches = accumulate_grad_batches\n self._params_sharding_rules = params_sharding_rules\n self._train_step_fn = train_step_fn\n\n self._state = None\n self._state_spec = None\n self._p_train_step = None\n self._p_eval_step = None\n\n self._init_step = 0\n self._init_epoch_idx = 0\n if last_ckpt_info is not None:\n self._init_step = last_ckpt_info.get('step', 0)\n self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n n_params = sum([param.size for param in jax.tree.leaves(params)])\n self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n self.set_train_state(\n apply_fn=self._apply_fn,\n params=params,\n optimizer=self._optimizer,\n step=self._init_step,\n opt_state=opt_state)\n
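A construction sketch showing the optional lr_schedule_fn and accumulate_grad_batches arguments; collate_fn, model, loss_fn, and params are placeholders for your own pipeline (e.g., the MNIST example), and the schedule values are arbitrary.
import optax
from redco import Deployer, Trainer

deployer = Deployer(jax_seed=42, workdir='./workdir')

lr_schedule_fn = optax.warmup_cosine_decay_schedule(
    init_value=0., peak_value=2e-5, warmup_steps=100, decay_steps=10_000)
optimizer = optax.adamw(learning_rate=lr_schedule_fn)

trainer = Trainer(
    deployer=deployer,
    collate_fn=collate_fn,          # placeholder
    apply_fn=model.apply,           # placeholder
    loss_fn=loss_fn,                # placeholder
    params=params,                  # placeholder
    optimizer=optimizer,
    lr_schedule_fn=lr_schedule_fn,  # maps the current step to a learning rate
    accumulate_grad_batches=4)      # one optimizer update per 4 micro-batches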
"},{"location":"trainer/#redco.trainers.trainer.Trainer.eval_loss","title":"eval_loss(examples, per_device_batch_size, desc=None)
","text":"Evaluates the loss on the provided examples.
Parameters:
Name Type Description Defaultexamples
list
Evaluation examples in a python list.
requiredper_device_batch_size
int
The batch size per device.
requireddesc
str
Description in the progress bar.
None
Returns:
Type Descriptionfloat
The average loss over the evaluation examples.
Source code in redco/trainers/trainer.py
def eval_loss(self, examples, per_device_batch_size, desc=None):\n \"\"\"Evaluates the loss on the provided examples.\n\n Args:\n examples (list): Evaluation examples in list.\n per_device_batch_size (int): The batch size per device.\n desc (str): Description in the progress bar.\n\n Returns:\n (float): The average loss over the evaluation examples.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=False,\n shuffle_rng=None,\n desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n losses = []\n for batch in data_batches:\n if self._p_eval_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n metrics = self._deployer.run_model_step(\n step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n\n losses.append(metrics['loss'].item())\n data_batches.set_postfix(**metrics)\n\n return np.mean(losses).item()\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.fit","title":"fit(train_examples, per_device_batch_size, n_epochs, eval_examples=None, eval_per_device_batch_size=None, eval_loss=True, eval_predictor=None, eval_metric_fn=None, eval_sanity_check=True, save_every_ckpt=False, save_last_ckpt=False, save_argmin_ckpt_by_metrics=None, save_argmax_ckpt_by_metrics=None, save_opt_states=True, save_float_dtype=None)
","text":"Fits the model on the training data for a given number of epochs, optionally evaluating and saving checkpoints.
Parameters:
Name Type Description Defaulttrain_examples
list or Callable
Training examples; can be a list or a function of epoch_idx (for assigning different examples to separate epochs/chunks), e.g., train_examples=lambda epoch_idx: load_data(epoch_idx)
per_device_batch_size
int
The batch size per device.
requiredn_epochs
int
Number of epochs to train.
requiredeval_examples
list
Examples for evaluation and prediction.
None
eval_per_device_batch_size
int
Batch size for evaluation.
None
eval_loss
bool
Whether to evaluate loss.
True
eval_predictor
Predictor
Predictor working on eval_examples.
None
eval_metric_fn
Callable
Metric function for prediction.
None
eval_sanity_check
bool
Whether to run a sanity check of the evaluation and prediction functions before training.
True
save_every_ckpt
bool
Whether to save a checkpoint after every epoch.
False
save_last_ckpt
bool
Whether to save the last checkpoint.
False
save_argmin_ckpt_by_metrics
list[str]
Metrics to save checkpoints based on minimum values.
None
save_argmax_ckpt_by_metrics
list[str]
Metrics to save checkpoints based on maximum values.
None
save_opt_states
bool
Whether to save optimizer states in checkpoints.
True
save_float_dtype
`jax.numpy.dtype`
The data type for saving checkpoints.
None
Source code in redco/trainers/trainer.py
def fit(self,\n train_examples,\n per_device_batch_size,\n n_epochs,\n eval_examples=None,\n eval_per_device_batch_size=None,\n eval_loss=True,\n eval_predictor=None,\n eval_metric_fn=None,\n eval_sanity_check=True,\n save_every_ckpt=False,\n save_last_ckpt=False,\n save_argmin_ckpt_by_metrics=None,\n save_argmax_ckpt_by_metrics=None,\n save_opt_states=True,\n save_float_dtype=None):\n \"\"\"Fits the model on the training data for a given number of epochs,\n optionally evaluating and saving checkpoints.\n\n Args:\n train_examples (list or Callable): Training examples, can be a\n list or a function of epoch_idx (for assigning different\n examples in separate epochs/chunks),\n e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n per_device_batch_size (int): The batch size per device.\n n_epochs (int): Number of epochs to train.\n eval_examples (list): Examples for evaluation and prediction.\n eval_per_device_batch_size (int): Batch size for evaluation\n eval_loss (bool): Whether to evaluate loss.\n eval_predictor (Predictor): Predictor working on `eval_examples`.\n eval_metric_fn (Callable): Metric function for prediction.\n eval_sanity_check (bool): if to run a sanity check for\n evaluation & predict functions before training.\n save_every_ckpt (bool): if to save a ckpt after every epoch.\n save_last_ckpt (bool): Whether to save the last checkpoint.\n save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on minimum values.\n save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n based on maximum values.\n save_opt_states (bool): of to save optimizer states in ckpts.\n save_float_dtype (bool): The data type for saving checkpoints.\n \"\"\"\n if eval_per_device_batch_size is None:\n eval_per_device_batch_size = per_device_batch_size\n\n if save_argmax_ckpt_by_metrics is None:\n save_argmax_ckpt_by_metrics = []\n if save_argmin_ckpt_by_metrics is None:\n save_argmin_ckpt_by_metrics = []\n min_metrics, max_metrics = {}, {}\n\n if os.path.exists(f'{self.workdir}/min_metrics.json'):\n min_metrics = json.load(open(\n f'{self.workdir}/min_metrics.json'))\n self._deployer.log_info(\n json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n if os.path.exists(f'{self.workdir}/max_metrics.json'):\n max_metrics = json.load(open(\n f'{self.workdir}/max_metrics.json'))\n self._deployer.log_info(\n json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n if eval_sanity_check and eval_examples is not None:\n rng_backup = self._deployer._rng\n _, eval_global_micro_batch_size = \\\n self._deployer.get_local_global_micro_batch_size(\n per_device_batch_size=eval_per_device_batch_size)\n\n if eval_loss:\n self.eval_loss(\n examples=eval_examples[:eval_global_micro_batch_size],\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for evaluation loss) passed.')\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples[:eval_global_micro_batch_size],\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'Sanity check')\n self._deployer.log_info(\n 'Sanity check (for prediction) passed.')\n\n if eval_metric_fn is not None:\n json.dumps(eval_metric_fn(\n examples=eval_examples[:eval_global_micro_batch_size],\n preds=preds))\n self._deployer.log_info(\n 'Sanity check (for evaluation metrics) passed.')\n\n self._deployer._rng = 
rng_backup\n\n for epoch_idx in range(self._init_epoch_idx, n_epochs):\n if isinstance(train_examples, list):\n epoch_train_examples = train_examples\n else:\n epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n self.train(\n examples=epoch_train_examples,\n per_device_batch_size=per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n save_ckpt_kwargs = {\n 'epoch_idx': epoch_idx,\n 'save_opt_state': save_opt_states,\n 'float_dtype': save_float_dtype\n }\n\n if eval_examples is None:\n self._deployer.log_info(\n 'No evaluation cuz \\'eval_examples\\' is None.')\n else:\n eval_metrics = {}\n\n if eval_loss:\n loss = self.eval_loss(\n examples=eval_examples,\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n eval_metrics['loss'] = loss\n\n if eval_predictor is not None:\n preds = eval_predictor.predict(\n examples=eval_examples,\n params=self._state.params,\n params_replicated=(self.mesh is None),\n params_sharded=(self.mesh is not None),\n per_device_batch_size=eval_per_device_batch_size,\n desc=f'epoch {epoch_idx} / {n_epochs}')\n\n if eval_metric_fn is not None:\n eval_metrics.update(eval_metric_fn(\n examples=eval_examples, preds=preds))\n\n eval_outputs = [\n {'example': example, 'pred': pred}\n for example, pred in zip(eval_examples, preds)]\n\n self._deployer.save_outputs(\n outputs=eval_outputs,\n desc=f'epoch{epoch_idx}',\n step=self.step)\n\n self._deployer.log_info(\n info=json.dumps(eval_metrics, indent=4),\n title=f'Eval results',\n step=self.step)\n self._deployer.log_metrics(metrics={\n f'eval_{key}': value\n for key, value in eval_metrics.items()\n }, step=self.step)\n\n if self.workdir is not None:\n result_filepath = \\\n f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n json.dump(\n eval_metrics, open(result_filepath, 'w'), indent=4)\n self._deployer.log_info(\n f'eval_results saved into {result_filepath}.')\n\n for key in save_argmin_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] < min_metrics.get(key, float('inf')):\n min_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'minimal {key} updated to {min_metrics[key]}')\n json.dump(min_metrics, open(\n f'{self.workdir}/min_metrics.json', 'w'))\n\n self.save_ckpt(\n ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n for key in save_argmax_ckpt_by_metrics:\n assert self.workdir is not None\n if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n max_metrics[key] = eval_metrics[key]\n\n if jax.process_index() == 0:\n self._deployer.log_info(\n f'maximal {key} updated to {max_metrics[key]}')\n json.dump(max_metrics, open(\n f'{self.workdir}/max_metrics.json', 'w'))\n\n self.save_ckpt(\n ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n if save_every_ckpt:\n self.save_ckpt(\n ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n elif save_last_ckpt:\n self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n
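A usage sketch of fit() with evaluation and metric-based checkpointing, building on the trainer, predictor, dataset, and eval_metric_fn from the MNIST example (assumed); the metric key 'acc' matches what that eval_metric_fn returns.
trainer.fit(
    train_examples=dataset['train'],
    eval_examples=dataset['test'],
    per_device_batch_size=64,
    n_epochs=5,
    eval_predictor=predictor,
    eval_metric_fn=eval_metric_fn,
    save_argmax_ckpt_by_metrics=['acc'],  # keep the checkpoint with the best accuracy
    save_last_ckpt=True)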
"},{"location":"trainer/#redco.trainers.trainer.Trainer.save_ckpt","title":"save_ckpt(epoch_idx, ckpt_name, save_opt_state, float_dtype)
","text":"Saves a checkpoint into {self.workdir}/ckpts
.
Parameters:
Name Type Description Defaultepoch_idx
int
The current epoch index.
requiredckpt_name
str
The name of the checkpoint.
requiredsave_opt_state
bool
Whether to save the optimizer state.
requiredfloat_dtype
`jax.numpy.dtype`
Data type for saving float params.
required Source code in redco/trainers/trainer.py
def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n Args:\n epoch_idx (int): The current epoch index.\n ckpt_name (str): The name of the checkpoint.\n save_opt_state (bool): Whether to save the optimizer state.\n float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n \"\"\"\n if self.mesh is None:\n params = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.params)\n else:\n params = self._state.params\n\n opt_state = None\n if save_opt_state:\n if self.mesh is None:\n opt_state = jax.tree.map(\n fully_replicated_host_local_array_to_global_array,\n self._state.opt_state)\n else:\n opt_state = self._state.opt_state\n\n ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n self._deployer.save_ckpt(\n ckpt_dir=ckpt_dir,\n params=params,\n opt_state=opt_state,\n float_dtype=float_dtype,\n step=self.step,\n epoch_idx=epoch_idx)\n\n if jax.process_index() == 0:\n open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.set_train_state","title":"set_train_state(apply_fn, params, optimizer, step, opt_state=None)
","text":"Sets/Resets the training state with given parameters and optimizer.
Parameters:
Name Type Description Defaultapply_fn
Callable
The function to apply the model.
requiredparams
dict
Model parameters.
requiredoptimizer
optax optimizer
The optimizer used for training.
requiredstep
int
The training step.
requiredopt_state
dict
The state of the optimizer.
None
Source code in redco/trainers/trainer.py
def set_train_state(\n self, apply_fn, params, optimizer, step, opt_state=None):\n \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n Args:\n apply_fn (Callable): The function to apply the model.\n params (dict): Model parameters.\n optimizer (dict): The optimizer used for training.\n step (int): The training step.\n opt_state (dict): The state of the optimizer.\n \"\"\"\n self._deployer.log_info('Setting train_state ...')\n params = freeze(params)\n\n if self.mesh is None:\n params = jax.device_put(params, jax.local_devices()[0])\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n else:\n opt_state = jax.device_put(opt_state, jax.local_devices()[0])\n\n self._state = train_state.TrainState(\n step=step,\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state)\n self._state = replicate(self._state)\n else:\n params_spec = self._deployer.get_params_spec(\n params_shape_or_params=params,\n params_sharding_rules=self._params_sharding_rules)\n params = self._deployer.shard_params(\n params=params, params_spec=params_spec)\n\n if opt_state is None:\n self._deployer.log_info('Initializing opt_state ...')\n opt_state = optimizer.init(params)\n\n opt_state_spec = self._deployer.get_opt_state_spec(\n params_shape_or_params=params,\n params_spec=params_spec,\n optimizer=optimizer)\n opt_state = self._deployer.shard_params(\n params=opt_state,\n params_spec=opt_state_spec,\n desc='opt_state')\n\n self._state = train_state.TrainState(\n apply_fn=apply_fn,\n params=params,\n tx=optimizer,\n opt_state=opt_state,\n step=step)\n\n self._state_spec = train_state.TrainState(\n apply_fn=apply_fn,\n params=params_spec,\n tx=optimizer,\n opt_state=opt_state_spec,\n step=None)\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.setup_running_step","title":"setup_running_step(dummy_batch)
","text":"Sets up the running step functions for training and evaluation.
Parameters:
Name Type Description Defaultdummy_batch
PyTree
A dummy batch of data.
required Source code in redco/trainers/trainer.py
def setup_running_step(self, dummy_batch):\n \"\"\"Sets up the running step functions for training and evaluation.\n\n Args:\n dummy_batch (PyTree): A dummy batch of data.\n \"\"\"\n train_step_fn = partial(\n self._train_step_fn or default_train_step,\n loss_fn=self._loss_fn,\n lr_schedule_fn=self._lr_schedule_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n eval_step_fn = partial(\n eval_step,\n loss_fn=self._loss_fn,\n mesh=self.mesh,\n compute_dtype=self._compute_dtype)\n\n if self.mesh is None:\n self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n else:\n data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n self._p_train_step = pjit(\n train_step_fn,\n in_shardings=(None, self._state_spec, data_spec),\n out_shardings=(self._state_spec, None),\n donate_argnums=(1, ))\n self._p_eval_step = pjit(\n eval_step_fn,\n in_shardings=(self._state_spec, data_spec),\n out_shardings=None)\n
"},{"location":"trainer/#redco.trainers.trainer.Trainer.train","title":"train(examples, per_device_batch_size, desc=None)
","text":"Trains the model on the provided examples.
Parameters:
Name Type Description Defaultexamples
list
Training examples in a python list.
requiredper_device_batch_size
int
The batch size per device.
requireddesc
str
Description in the progress bar.
None
Source code in redco/trainers/trainer.py
def train(self, examples, per_device_batch_size, desc=None):\n \"\"\"Trains the model on the provided examples.\n\n Args:\n examples (list): Training examples in python list.\n per_device_batch_size (int): The batch size per device.\n desc (str): Description in the progress bar.\n \"\"\"\n data_batches = self._deployer.get_model_input_batches(\n examples=examples,\n per_device_batch_size=per_device_batch_size,\n collate_fn=self._collate_fn,\n shuffle=True,\n shuffle_rng=self._deployer.gen_rng(),\n desc=f'Training ({desc})' if desc is not None else 'Training',\n is_train=True,\n accumulate_grad_batches=self._accumulate_grad_batches)\n\n for batch in data_batches:\n if self._p_train_step is None:\n self.setup_running_step(dummy_batch=batch)\n\n train_rng = self._deployer.gen_rng()\n if self.mesh is None:\n train_rng = jax.random.split(\n train_rng, num=jax.process_count())[jax.process_index()]\n train_rng = shard_prng_key(train_rng)\n self._state, metrics = self._deployer.run_model_step(\n step_fn=self._p_train_step,\n input_args=(train_rng, self._state, batch))\n\n if self.mesh is None:\n metrics = unreplicate(metrics)\n data_batches.set_postfix(**metrics)\n self._deployer.log_metrics(metrics=metrics, step=self.step)\n
"}]}
\ No newline at end of file
diff --git a/sitemap.xml b/sitemap.xml
index e9a3417..78bdd3d 100755
--- a/sitemap.xml
+++ b/sitemap.xml
@@ -2,27 +2,27 @@