Refactor training pipeline and add Llama2 tokenizer support #7
Conversation
Sunspot frameworks tests
Fix path in `prof.export_chrome_trace()` from `pretrain_gpt_alcf.py`
…tron-DeepSpeed into tokenizer-tests
Merge in `tokenizer-tests` branch into `main`
…into hzheng-data-fix
…tron-DeepSpeed into hzheng-data-fix
Pull in changes from [6acc370](6acc370) to [`megatron/utils.py`](https://github.com/argonne-lcf/Megatron-DeepSpeed)
[merge]: into `microsoft-main` $\leftarrow$ from `hzheng-data-fix`
@sourcery-ai review
Hey @saforem2 - I've reviewed your changes - here's some feedback:
Overall Comments:
- Please add documentation describing the new Llama2 tokenizer support and training pipeline changes, including examples of usage.
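For the usage examples, it may help to show what the Llama2 tokenizer actually does; Llama2 ships a SentencePiece model, so here is a minimal standalone sketch using the sentencepiece package directly (the model path is a placeholder and this bypasses the repo's own tokenizer wrapper):
# Illustrative only: Llama2 tokenization via sentencepiece, outside this repo's classes.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="/path/to/llama2/tokenizer.model")  # hypothetical path
ids = sp.encode("Megatron-DeepSpeed on Aurora", out_type=int)  # token ids
text = sp.decode(ids)                                          # round-trip back to text
print(len(ids), text)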
Here's what I looked at during the review
- 🟡 General issues: 3 issues found
- 🟡 Security: 1 issue found
- 🟢 Testing: all looks good
- 🟡 Complexity: 3 issues found
- 🟡 Documentation: 3 issues found
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
"lr_state_dict.yaml" | ||
) | ||
log.info(f"Saving lr_state_dict to {lr_state_dict_fp.as_posix()}") | ||
with lr_state_dict_fp.open('w') as f: |
suggestion (bug_risk): Consider using atomic file writes for the lr state dict
Use atomic file operations (write to temp file and rename) to prevent corruption if the process crashes during write.
# Write to a temporary file first, then rename it over the target so a crash
# mid-write cannot leave a truncated lr_state_dict.yaml behind.
temp_fp = lr_state_dict_fp.with_suffix('.tmp')
with temp_fp.open('w') as f:
    yaml.dump(
        {'iteration': args.iteration, 'lr': args.lr},
        f)
temp_fp.rename(lr_state_dict_fp)
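If the file may already exist, os.replace gives an atomic overwrite on both POSIX and Windows; a small variant of the same idea (names reused from the snippet above, purely illustrative):
import os
import yaml

# Write to a temp file in the same directory, then atomically swap it into place.
temp_fp = lr_state_dict_fp.with_suffix('.tmp')
with temp_fp.open('w') as f:
    yaml.dump({'iteration': args.iteration, 'lr': args.lr}, f)
os.replace(temp_fp, lr_state_dict_fp)  # atomic; overwrites any existing file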
step_size = lr
step_size_neg = step_size.neg()

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
suggestion: Consider making epsilon configurable and potentially larger for better numerical stability
The hardcoded epsilon of 1e-15 might be too small for some use cases. Consider making this a configurable parameter with a default value of 1e-8 or similar.
ratio = (exp_avg.abs() / (rho * bs * hess + self.eps)).clamp(None,1)
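A small illustrative helper (not code from this PR) showing the same expression with eps exposed as a parameter; whether it lives on the optimizer as self.eps or is passed per call is a design choice:
import torch

def hessian_clipped_ratio(exp_avg, hess, rho, bs, eps=1e-8):
    """Same update ratio as above, with a configurable epsilon for stability."""
    return (exp_avg.abs() / (rho * bs * hess + eps)).clamp(None, 1)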
in_list = ""
for i in json_gz_files:
    in_list = in_list + " " + str(i)
command = "cat" + in_list + " > " + output_file
🚨 suggestion (security): Consider using Python's built-in file operations instead of shell commands for better security and error handling
Shell command injection vulnerabilities could be introduced here. Consider using Python's gzip and file operations directly.
with open(output_file, 'wb') as outfile:
    for gz_file in json_gz_files:
        with open(gz_file, 'rb') as infile:
            outfile.write(infile.read())
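Byte-for-byte concatenation works here because a sequence of back-to-back gzip members is itself a valid gzip file. For large shards, shutil.copyfileobj avoids reading each file fully into memory (sketch, same hypothetical arguments as above):
import shutil

with open(output_file, 'wb') as outfile:
    for gz_file in json_gz_files:
        with open(gz_file, 'rb') as infile:
            shutil.copyfileobj(infile, outfile)  # streams in chunks instead of read()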
return json_gz_files

def combine_json_gz_files(json_gz_files, output_file):
    in_list = ""
suggestion (performance): Use list for accumulating filenames instead of string concatenation
String concatenation in loops is inefficient. Consider using a list and join() at the end if needed.
in_list = []
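A sketch of the loop rewritten this way (variable names reused from the snippet above, purely illustrative):
# Accumulate the shard paths in a list and join once; if the shell command is
# dropped (per the security comment above), the joined string may not be needed at all.
in_list = [str(i) for i in json_gz_files]
joined = " ".join(in_list)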
@@ -0,0 +1,207 @@
# Converting Checkpoints
suggestion (documentation): Add context for checkpoint conversion
Consider adding a brief introduction explaining when and why users might need to convert checkpoints between Megatron and HuggingFace formats.
# Converting Checkpoints

Checkpoint conversion is necessary when moving between different deep learning frameworks. This guide explains how to convert model weights between Megatron-LM and Hugging Face formats, enabling interoperability between these popular frameworks.
@@ -0,0 +1,262 @@
6.322825248625475e-06 /gila/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0000_text_document megawika
issue (documentation): Dataset identifier 'megawika' differs from filename 'megawiki'
There appears to be an inconsistency between the filename (megawiki) and the dataset identifier used in the file (megawika). This should be made consistent to avoid confusion.
1.7132649760565998e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0143_text_document
1.7492547092602047e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0144_text_document
1.7499951097392276e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0145_text_document
1.6632444789170958e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0146_text_document
1.6678802252361607e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0147_text_document
1.5519208704558896e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0148_text_document
1.652420992967167e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0149_text_document
1.6119931034508755e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0150_text_document
1.6638882076736552e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0151_text_document
1.7198076782652946e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1
suggestion (documentation): Consider documenting the meaning of the floating point values in a README
While the file paths are self-explanatory, it would be helpful to have documentation explaining what the floating point values represent and how they're used in the system.
# Weights and file paths for training data sources
# Format: <weight> <file_path>
# Weights represent relative importance in training:
# - Reddit: ~0.0005-0.0006
# - StackExchange: ~0.001
# - StarCoder: ~0.003-0.005
# - Tulu-FLAN: ~0.0002-0.0003
# - Wikipedia: ~0.003-0.004
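For the README, it may also help to show how these weights become sampling probabilities; a minimal sketch assuming the `<weight> <path> [<corpus>]` layout shown in these files (function name is illustrative):
# Illustrative parser: the per-file weights are relative and get normalized.
def load_blend(path):
    weights, paths = [], []
    with open(path) as fh:
        for line in fh:
            parts = line.split()
            if not parts:
                continue
            weights.append(float(parts[0]))
            paths.append(parts[1])
    total = sum(weights)
    return [w / total for w in weights], paths  # sampling probabilities, file paths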
elif args.optimizer.lower() == "galore_adamw":
issue (complexity): Consider refactoring optimizer initialization into factory functions to reduce complexity
The optimizer initialization can be simplified by extracting the creation logic into factory functions while maintaining all functionality. This reduces nesting and improves maintainability. Example:
def create_adam_optimizer(param_groups, args):
    """Factory function for Adam optimizer variants"""
    if args.ds_fused_adam:
        from deepspeed.ops.adam import FusedAdam
        adam_cls = FusedAdam
    else:
        adam_cls = torch.optim.Adam
    return adam_cls(
        param_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_eps
    )

def create_adamw_optimizer(param_groups, args):
    """Factory function for AdamW optimizer"""
    return torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_eps
    )

# In get_megatron_optimizer:
OPTIMIZER_FACTORIES = {
    'adam': create_adam_optimizer,
    'adamw': create_adamw_optimizer,
    # Add other optimizers similarly
}

# Use factory functions
if args.optimizer.lower() in OPTIMIZER_FACTORIES:
    optimizer = OPTIMIZER_FACTORIES[args.optimizer.lower()](param_groups, args)
This approach:
- Reduces nesting depth and complexity
- Makes it easier to add new optimizers
- Keeps configuration explicit and debuggable
- Maintains all existing functionality
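A small follow-on to the sketch above: an explicit failure path keeps misconfigured optimizer names easy to diagnose (error message wording is illustrative):
name = args.optimizer.lower()
factory = OPTIMIZER_FACTORIES.get(name)
if factory is None:
    # Fail loudly instead of silently falling through to a default optimizer.
    raise ValueError(
        f"Unknown optimizer '{args.optimizer}'; expected one of {sorted(OPTIMIZER_FACTORIES)}"
    )
optimizer = factory(param_groups, args)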
)

def training_log(
issue (complexity): Consider refactoring the logging code into a dedicated metrics collection class to improve organization and maintainability
The logging code has unnecessary complexity from mixed concerns and repeated patterns. Consider refactoring to use a metrics collection class:
class TrainingMetrics:
    def __init__(self, args, writer, wandb_enabled=False):
        self.args = args
        self.writer = writer
        self.wandb_enabled = wandb_enabled
        self.metrics = {}

    def add_scalar(self, name, value, iteration, log_samples=True, log_tokens=True):
        self.metrics[name] = value
        if self.writer:
            self.writer.add_scalar(f"{name}", value, iteration)
            if log_samples:
                self.writer.add_scalar(
                    f"{name} vs samples", value, self.args.consumed_train_samples
                )
            if log_tokens:
                self.writer.add_scalar(
                    f"{name} vs tokens", value, self.args.consumed_train_tokens
                )

    def log_loss_dict(self, loss_dict, iteration):
        for key, value in loss_dict.items():
            self.add_scalar(f"lm-loss-training/{key}", value, iteration)

    def get_wandb_metrics(self):
        return self.metrics if self.wandb_enabled else {}
Usage example:
metrics = TrainingMetrics(args, writer, wandb is not None)
metrics.add_scalar("learning-rate/learning-rate", learning_rate, iteration)
metrics.log_loss_dict(loss_dict, iteration)
wandb_metrics = metrics.get_wandb_metrics()
This refactoring:
- Centralizes metric collection and formatting
- Reduces code duplication between logging paths
- Makes the code more maintainable by isolating logging logic
- Provides a consistent interface for adding metrics
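If wandb is in use, the collected dict can then be forwarded from a single place instead of being scattered across the logging code (sketch; assumes wandb has already been initialized elsewhere):
import wandb

wandb_metrics = metrics.get_wandb_metrics()
if wandb_metrics:
    wandb.log(wandb_metrics, step=iteration)  # one sink for all collected scalars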
os.makedirs(args.trace_dir, exist_ok=True)

corpus_all = []
issue (complexity): Consider refactoring the code into dedicated classes for dataset configuration and performance tracing
The code mixes multiple concerns making it hard to maintain. Here's how to reduce complexity while keeping all functionality:
- Extract dataset configuration into a dedicated class:
class DatasetConfig:
    def __init__(self, data_file_list):
        self.files = []
        self.weights = []
        self.corpus_all = []
        self._load_config(data_file_list)

    def _load_config(self, data_file_list):
        with open(data_file_list, 'r') as fin:
            for line in fin:
                weight, fname, corpus = line.split()
                self.weights.append(float(weight))
                self.files.extend([float(weight), fname, corpus])
                if corpus not in self.corpus_all:
                    self.corpus_all.append(corpus)
        self.weights = np.array(self.weights)
        self.weights /= np.sum(self.weights)
- Wrap performance tracing in a context manager:
class PerformanceTracer:
    def __init__(self, args, comm):
        self.args = args
        self.comm = comm

    def __enter__(self):
        extra_path = os.getenv('DLIO_PROFILER_DATASET_DIR', '')
        trace_file = f"{self.args.trace_dir}/trace-{self.comm.rank}-of-{self.comm.size}.pfw"
        paths = f"{self.args.data_cache_path}:{extra_path}:{self.args.data_path}:{self.args.save}:{self.args.load}"
        PerfTrace.initialize_log(trace_file, paths, process_id=self.comm.rank)
        return Profile("TEST_BLENDABLEDATASET")

    def __exit__(self, *args):
        pass
This separates concerns while maintaining functionality. The main script becomes clearer:
dataset_config = DatasetConfig(args.data_file_list)

with PerformanceTracer(args, comm) as tracer:
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        dataset_config.files,
        args.data_impl,
        splits_string,
        train_valid_test_num_samples,
        args.seq_length,
        args.seed,
        not args.mmap_warmup,
        data_cache_path=args.data_cache_path)
    # Continue with data loading...
Summary by Sourcery
Refactor the training pipeline and add support for the Llama2 tokenizer, along with logging enhancements for improved tracking and debugging.
New Features:
Enhancements: