
Refactor training pipeline and add Llama2 tokenizer support #7

Merged · 552 commits · Nov 15, 2024

Conversation

saforem2 (Owner) commented Oct 11, 2024

Summary by Sourcery

Refactor the training pipeline and add support for the Llama2 tokenizer, along with logging enhancements for improved tracking and debugging.

New Features:

  • Add support for Llama2 tokenizer in the training pipeline.

Enhancements:

  • Refactor the training pipeline to improve code organization and readability.
  • Introduce logging enhancements for better tracking and debugging of the training process.
  • Implement additional logging for dataset building and loading processes.

saforem2 and others added 30 commits May 20, 2024 09:44
Fix path in `prof.export_chrome_trace()` from `pretrain_gpt_alcf.py`
Merge in `tokenizer-tests` branch into `main`
saforem2 and others added 24 commits October 14, 2024 21:55
[merge]: into `microsoft-main` $\leftarrow$ from `hzheng-data-fix`
saforem2 (Owner, Author) commented Nov 7, 2024

@sourcery-ai review

@sourcery-ai sourcery-ai bot left a comment

Hey @saforem2 - I've reviewed your changes - here's some feedback:

Overall Comments:

  • Please add documentation describing the new Llama2 tokenizer support and training pipeline changes, including examples of usage.
Here's what I looked at during the review
  • 🟡 General issues: 3 issues found
  • 🟡 Security: 1 issue found
  • 🟢 Testing: all looks good
  • 🟡 Complexity: 3 issues found
  • 🟡 Documentation: 3 issues found


"lr_state_dict.yaml"
)
log.info(f"Saving lr_state_dict to {lr_state_dict_fp.as_posix()}")
with lr_state_dict_fp.open('w') as f:
suggestion (bug_risk): Consider using atomic file writes for the lr state dict

Use atomic file operations (write to temp file and rename) to prevent corruption if the process crashes during write.

    temp_fp = lr_state_dict_fp.with_suffix('.tmp')
    with temp_fp.open('w') as f:
        yaml.dump({'iteration': args.iteration, 'lr': args.lr}, f)
    # rename() is atomic on POSIX, so a crash mid-write never leaves a partial file
    temp_fp.rename(lr_state_dict_fp)
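One portability note on the sketch above: `Path.rename` is atomic on POSIX but fails on Windows if the destination already exists; `os.replace` provides the same atomic overwrite on both platforms.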

step_size = lr
step_size_neg = step_size.neg()

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)

suggestion: Consider making epsilon configurable and potentially larger for better numerical stability

The hardcoded epsilon of 1e-15 might be too small for some use cases. Consider making this a configurable parameter with a default value of 1e-8 or similar.

            ratio = (exp_avg.abs() / (rho * bs * hess + self.eps)).clamp(None,1)
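A minimal sketch of how `self.eps` could be exposed (hypothetical: the class name and defaults here are illustrative, assuming the optimizer subclasses `torch.optim.Optimizer`):

    import torch

    class HessianAwareOptimizer(torch.optim.Optimizer):
        """Illustrative optimizer exposing the stability epsilon as a parameter."""

        def __init__(self, params, lr=1e-4, rho=0.04, eps=1e-8):
            # eps guards the division by the Hessian estimate; 1e-8 mirrors Adam's default
            if eps <= 0.0:
                raise ValueError(f"eps must be positive, got {eps}")
            defaults = dict(lr=lr, rho=rho, eps=eps)
            super().__init__(params, defaults)
            self.eps = eps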

in_list = ""
for i in json_gz_files:
in_list = in_list + " " +str(i)
command = "cat" + in_list + " > " + output_file

🚨 suggestion (security): Consider using Python's built-in file operations instead of shell commands for better security and error handling

Shell command injection vulnerabilities could be introduced here. Consider using Python's gzip and file operations directly.

    with open(output_file, 'wb') as outfile:
        for gz_file in json_gz_files:
            with open(gz_file, 'rb') as infile:
                outfile.write(infile.read())
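For large shards, the same approach can stream instead of buffering each file in memory (a sketch using `shutil.copyfileobj`; concatenated gzip members still decode as a single valid gzip stream):

    import shutil

    with open(output_file, 'wb') as outfile:
        for gz_file in json_gz_files:
            with open(gz_file, 'rb') as infile:
                # copy in fixed-size chunks rather than reading the whole shard at once
                shutil.copyfileobj(infile, outfile)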

    return json_gz_files

def combine_json_gz_files(json_gz_files, output_file):
    in_list = ""

suggestion (performance): Use list for accumulating filenames instead of string concatenation

String concatenation in loops is inefficient. Consider using a list and join() at the end if needed.

    in_list = [str(i) for i in json_gz_files]
    command = f"cat {' '.join(in_list)} > {output_file}"

@@ -0,0 +1,207 @@
# Converting Checkpoints

suggestion (documentation): Add context for checkpoint conversion

Consider adding a brief introduction explaining when and why users might need to convert checkpoints between Megatron and HuggingFace formats.

Suggested change
- # Converting Checkpoints
+ # Converting Checkpoints
+ Checkpoint conversion is necessary when moving between different deep learning frameworks. This guide explains how to convert model weights between Megatron-LM and Hugging Face formats, enabling interoperability between these popular frameworks.

@@ -0,0 +1,262 @@
6.322825248625475e-06 /gila/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0000_text_document megawika

issue (documentation): Dataset identifier 'megawika' differs from filename 'megawiki'

There appears to be an inconsistency between the filename (megawiki) and the dataset identifier used in the file (megawika). This should be made consistent to avoid confusion.

Comment on lines 2041 to 2215
1.7132649760565998e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0143_text_document
1.7492547092602047e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0144_text_document
1.7499951097392276e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0145_text_document
1.6632444789170958e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0146_text_document
1.6678802252361607e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0147_text_document
1.5519208704558896e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0148_text_document
1.652420992967167e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0149_text_document
1.6119931034508755e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0150_text_document
1.6638882076736552e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1.7_Llama2Tokenizer/megawika-0151_text_document
1.7198076782652946e-05 /flare/Aurora_deployment/AuroraGPT/datasets/dolma/data_v1

suggestion (documentation): Consider documenting the meaning of the floating point values in a README

While the file paths are self-explanatory, it would be helpful to have documentation explaining what the floating point values represent and how they're used in the system.

# Weights and file paths for training data sources
# Format: <weight> <file_path>
# Weights represent relative importance in training:
# - Reddit: ~0.0005-0.0006 
# - StackExchange: ~0.001
# - StarCoder: ~0.003-0.005
# - Tulu-FLAN: ~0.0002-0.0003
# - Wikipedia: ~0.003-0.004
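For context, the `DatasetConfig` sketch later in this review normalizes these weights into a sampling distribution; a hypothetical standalone parser for `<weight> <file_path> [<corpus>]` lines might look like:

    import numpy as np

    weights, paths = [], []
    with open("data_file_list.txt") as fin:  # hypothetical file name
        for line in fin:
            parts = line.split()
            weights.append(float(parts[0]))
            paths.append(parts[1])
    weights = np.asarray(weights)
    weights /= weights.sum()  # normalize so the weights form a sampling distribution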


elif args.optimizer.lower() == "galore_adamw":

issue (complexity): Consider refactoring optimizer initialization into factory functions to reduce complexity

The optimizer initialization can be simplified by extracting the creation logic into factory functions while maintaining all functionality. This reduces nesting and improves maintainability. Example:

def create_adam_optimizer(param_groups, args):
    """Factory function for Adam optimizer variants"""
    if args.ds_fused_adam:
        from deepspeed.ops.adam import FusedAdam
        adam_cls = FusedAdam
    else:
        adam_cls = torch.optim.Adam

    return adam_cls(
        param_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_eps
    )

def create_adamw_optimizer(param_groups, args):
    """Factory function for AdamW optimizer"""
    return torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_eps
    )

# In get_megatron_optimizer:
OPTIMIZER_FACTORIES = {
    'adam': create_adam_optimizer,
    'adamw': create_adamw_optimizer,
    # Add other optimizers similarly
}

# Use factory functions
if args.optimizer.lower() in OPTIMIZER_FACTORIES:
    optimizer = OPTIMIZER_FACTORIES[args.optimizer.lower()](param_groups, args)

This approach:

  • Reduces nesting depth and complexity
  • Makes it easier to add new optimizers
  • Keeps configuration explicit and debuggable
  • Maintains all existing functionality
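One defensive addition to the sketch (hypothetical, following the same factory-dict pattern above): fail loudly when the configured name has no factory.

    # in get_megatron_optimizer, after defining OPTIMIZER_FACTORIES
    opt_name = args.optimizer.lower()
    if opt_name not in OPTIMIZER_FACTORIES:
        raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")
    optimizer = OPTIMIZER_FACTORIES[opt_name](param_groups, args)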

)


def training_log(

issue (complexity): Consider refactoring the logging code into a dedicated metrics collection class to improve organization and maintainability

The logging code has unnecessary complexity from mixed concerns and repeated patterns. Consider refactoring to use a metrics collection class:

class TrainingMetrics:
    def __init__(self, args, writer, wandb_enabled=False):
        self.args = args
        self.writer = writer
        self.wandb_enabled = wandb_enabled
        self.metrics = {}

    def add_scalar(self, name, value, iteration, log_samples=True, log_tokens=True):
        self.metrics[name] = value
        if self.writer:
            self.writer.add_scalar(name, value, iteration)
            if log_samples:
                self.writer.add_scalar(
                    f"{name} vs samples", value, self.args.consumed_train_samples
                )
            if log_tokens:
                self.writer.add_scalar(
                    f"{name} vs tokens", value, self.args.consumed_train_tokens
                )

    def log_loss_dict(self, loss_dict, iteration):
        for key, value in loss_dict.items():
            self.add_scalar(f"lm-loss-training/{key}", value, iteration)

    def get_wandb_metrics(self):
        return self.metrics if self.wandb_enabled else {}

Usage example:

metrics = TrainingMetrics(args, writer, wandb is not None)
metrics.add_scalar("learning-rate/learning-rate", learning_rate, iteration)
metrics.log_loss_dict(loss_dict, iteration)
wandb_metrics = metrics.get_wandb_metrics()

This refactoring:

  1. Centralizes metric collection and formatting
  2. Reduces code duplication between logging paths
  3. Makes the code more maintainable by isolating logging logic
  4. Provides a consistent interface for adding metrics
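As a follow-up to the usage sketch (assuming `wandb` is the imported module, which may be `None` when unavailable, as in the check above):

    # flush the collected metrics once per iteration, alongside the tensorboard writer
    if wandb is not None:
        wandb.log(metrics.get_wandb_metrics(), step=iteration)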


os.makedirs(args.trace_dir, exist_ok=True)

corpus_all = []

issue (complexity): Consider refactoring the code into dedicated classes for dataset configuration and performance tracing

The code mixes multiple concerns making it hard to maintain. Here's how to reduce complexity while keeping all functionality:

  1. Extract dataset configuration into a dedicated class:
import numpy as np

class DatasetConfig:
    def __init__(self, data_file_list):
        self.files = []
        self.weights = []
        self.corpus_all = []
        self._load_config(data_file_list)

    def _load_config(self, data_file_list):
        with open(data_file_list, 'r') as fin:
            for line in fin:
                weight, fname, corpus = line.split()
                self.weights.append(float(weight))
                self.files.extend([float(weight), fname, corpus])
                if corpus not in self.corpus_all:
                    self.corpus_all.append(corpus)
        self.weights = np.array(self.weights)
        self.weights /= np.sum(self.weights)
  2. Wrap performance tracing in a context manager:
class PerformanceTracer:
    def __init__(self, args, comm):
        self.args = args
        self.comm = comm

    def __enter__(self):
        extra_path = os.getenv('DLIO_PROFILER_DATASET_DIR', '')
        trace_file = f"{self.args.trace_dir}/trace-{self.comm.rank}-of-{self.comm.size}.pfw"
        paths = f"{self.args.data_cache_path}:{extra_path}:{self.args.data_path}:{self.args.save}:{self.args.load}"
        PerfTrace.initialize_log(trace_file, paths, process_id=self.comm.rank)
        return Profile("TEST_BLENDABLEDATASET")

    def __exit__(self, *args):
        pass

This separates concerns while maintaining functionality. The main script becomes clearer:

dataset_config = DatasetConfig(args.data_file_list)
with PerformanceTracer(args, comm) as tracer:
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        dataset_config.files,
        args.data_impl,
        splits_string,
        train_valid_test_num_samples,
        args.seq_length,
        args.seed,
        not args.mmap_warmup,
        data_cache_path=args.data_cache_path)
    # Continue with data loading...

@saforem2 saforem2 closed this pull request by merging all changes into saforem2:main in 33962ee Nov 15, 2024