This project provides a flexible template for PyTorch-based machine learning experiments. It includes configuration management, logging with Weights & Biases (wandb), hyperparameter optimization with Optuna, and a modular structure for easy customization and experimentation.
The project is organized as follows:

- `config.py`: Defines the `RunConfig` and `OptimizeConfig` classes for managing experiment configurations and optimization settings.
- `main.py`: The entry point of the project, handling command-line arguments and experiment execution.
- `model.py`: Contains the model architecture (currently an MLP).
- `util.py`: Utility functions for data loading, device selection, training, and analysis.
- `configs/run_template.yaml`: Template for run configuration.
- `configs/optimize_template.yaml`: Template for optimization configuration.
- `analyze.py`: Script for analyzing completed runs and optimizations, using functions from `util.py`.

To install the template:

- Clone the repository:
  ```bash
  git clone https://github.com/yourusername/pytorch_template.git
  cd pytorch_template
  ```
- Install the required packages:
  ```bash
  # Use pip
  pip install torch wandb survey polars numpy optuna matplotlib scienceplots

  # Or use uv with requirements.txt (recommended)
  uv pip sync requirements.txt

  # Or use uv (fresh install)
  uv pip install -U torch wandb survey polars numpy optuna matplotlib scienceplots
  ```
- (Optional) Set up a Weights & Biases account for experiment tracking.

To run an experiment:

- Configure your experiment by modifying `configs/run_template.yaml` or creating a new YAML file based on it.
- (Optional) Configure hyperparameter optimization by modifying `configs/optimize_template.yaml` or creating a new YAML file based on it.
- Run the experiment:
  ```bash
  python main.py --run_config path/to/run_config.yaml [--optimize_config path/to/optimize_config.yaml]
  ```
  If `--optimize_config` is provided, the script performs hyperparameter optimization using Optuna.
- Analyze the results:
  ```bash
  python analyze.py
  ```
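
The command-line interface in the steps above can be sketched with Python's `argparse`. Only the `--run_config` and `--optimize_config` flag names come from this README; the `select_mode` helper and the help strings are illustrative assumptions, so treat this as a minimal sketch rather than the template's actual `main.py`.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical sketch of main.py's CLI; only the flag names come from the README."""
    parser = argparse.ArgumentParser(description="Run a PyTorch template experiment.")
    parser.add_argument(
        "--run_config",
        type=str,
        required=True,
        help="Path to the run configuration YAML file.",
    )
    parser.add_argument(
        "--optimize_config",
        type=str,
        default=None,
        help="Optional path to an Optuna optimization YAML file.",
    )
    return parser.parse_args(argv)


def select_mode(args) -> str:
    # Hyperparameter optimization runs only when --optimize_config is given;
    # otherwise a plain training run is performed.
    return "optimize" if args.optimize_config is not None else "run"
```

Calling `parse_args(["--run_config", "configs/run_template.yaml"])` leaves `optimize_config` as `None`, so `select_mode` returns `"run"`; passing both flags selects `"optimize"`.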

The run configuration file supports the following fields:

- `project`: Project name for wandb logging
- `device`: Device to run on (e.g., `'cpu'`, `'cuda:0'`)
- `net`: Model class to use
- `optimizer`: Optimizer class
- `scheduler`: Learning rate scheduler class
- `epochs`: Number of training epochs
- `batch_size`: Batch size for training
- `seeds`: List of random seeds for multiple runs
- `net_config`: Model-specific configuration
- `optimizer_config`: Optimizer-specific configuration
- `scheduler_config`: Scheduler-specific configuration
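
A minimal sketch of how these fields might map onto a Python class, assuming a dataclass-based `RunConfig` (the actual class in `config.py` may be implemented differently). The type annotations, including treating `net`, `optimizer`, and `scheduler` as strings naming classes, are assumptions. In practice the dictionary would come from parsing the YAML file (for example with PyYAML's `yaml.safe_load`); here it is passed in directly so the sketch needs only the standard library.

```python
from dataclasses import dataclass, field, fields
from typing import Any


@dataclass
class RunConfig:
    """Hypothetical RunConfig sketch; field names follow the README, types are assumed."""

    project: str
    device: str
    net: str  # assumed: name or import path of the model class
    optimizer: str  # assumed: name or import path of the optimizer class
    scheduler: str  # assumed: name or import path of the scheduler class
    epochs: int
    batch_size: int
    seeds: list[int]
    net_config: dict[str, Any] = field(default_factory=dict)
    optimizer_config: dict[str, Any] = field(default_factory=dict)
    scheduler_config: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RunConfig":
        # Reject unknown keys so that a typo in the YAML file fails loudly
        # instead of being silently ignored.
        known = {f.name for f in fields(cls)}
        unknown = set(data) - known
        if unknown:
            raise ValueError(f"Unknown config keys: {sorted(unknown)}")
        # Missing required keys raise TypeError from the generated __init__.
        return cls(**data)
```

Keeping the three `*_config` sections as free-form dictionaries lets each model, optimizer, and scheduler define its own parameters without changing the class.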

The optimization configuration file supports the following fields:

- `study_name`: Name of the optimization study
- `trials`: Number of optimization trials
- `seed`: Random seed for optimization
- `metric`: Metric to optimize
- `direction`: Direction of optimization (`'minimize'` or `'maximize'`)
- `sampler`: Optuna sampler configuration
- `pruner`: (Optional) Optuna pruner configuration
- `search_space`: Definition of the hyperparameter search space
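
To illustrate how a `search_space` definition can drive Optuna, the sketch below maps each parameter onto Optuna's `trial.suggest_int`, `trial.suggest_float`, and `trial.suggest_categorical` methods. The nested layout of `search_space` (section, then parameter, then a spec with `"type"`, `"min"`, `"max"`, `"log"`, or `"choices"`) and the `suggest_params` helper are assumptions for illustration, not the template's actual format.

```python
from typing import Any


def suggest_params(
    trial: Any, search_space: dict[str, dict[str, dict[str, Any]]]
) -> dict[str, dict[str, Any]]:
    """Sample one value per parameter using Optuna's trial.suggest_* methods.

    Assumed (hypothetical) layout of search_space:
        {config_section: {param_name: {"type": "int" | "float" | "categorical", ...}}}
    where config_section is e.g. "net_config" or "optimizer_config".
    """
    params: dict[str, dict[str, Any]] = {}
    for section, section_space in search_space.items():
        params[section] = {}
        for name, spec in section_space.items():
            # Prefix with the section so identically named parameters in
            # different sections stay distinct within the Optuna study.
            full_name = f"{section}.{name}"
            kind = spec["type"]
            if kind == "int":
                value = trial.suggest_int(full_name, spec["min"], spec["max"])
            elif kind == "float":
                value = trial.suggest_float(
                    full_name, spec["min"], spec["max"], log=spec.get("log", False)
                )
            elif kind == "categorical":
                value = trial.suggest_categorical(full_name, spec["choices"])
            else:
                raise ValueError(f"Unsupported parameter type: {kind!r}")
            params[section][name] = value
    return params
```

Inside an objective function, the sampled sections would be merged into the run configuration's `net_config`, `optimizer_config`, and so on before training. The study itself can be created with `optuna.create_study(study_name=..., direction=...)` and run with `study.optimize(objective, n_trials=...)`.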

To customize the template:

- Custom model: Modify or add models in `model.py`. Models should accept an `hparams` argument as a dictionary, with keys matching the `net_config` parameters in the run configuration YAML file.
- Custom data: Modify the `load_data` function in `util.py`. The current example uses cosine regression. The `load_data` function should return train and validation datasets compatible with PyTorch's `DataLoader`.
- Custom training: Customize the `Trainer` class in `util.py` by modifying the `step`, `train_epoch`, `val_epoch`, and `train` methods to suit your task. Ensure that `train` returns `val_loss` or a custom metric so that hyperparameter optimization has a value to optimize.
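
The three customization points above fit together roughly as follows. This is a hypothetical, dependency-free sketch, not the template's code: plain Python lists and a hand-written SGD update stand in for PyTorch tensors, `nn.Module`, and `torch.optim`, and the `init_w`/`init_b` hparams keys are invented for illustration. What it mirrors from the README is the cosine-regression data, a model that takes an `hparams` dictionary, and a `train` method that returns the validation loss.

```python
import math
import random


def load_data(n: int = 200, val_fraction: float = 0.2, seed: int = 42):
    """Hypothetical stand-in for util.load_data: noisy samples of y = cos(x) on [0, pi].

    Returns (train, val) as lists of (x, y) tuples. A list like this already works
    as a map-style dataset for torch.utils.data.DataLoader.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.uniform(0.0, math.pi)
        data.append((x, math.cos(x) + rng.gauss(0.0, 0.05)))
    split = n - round(n * val_fraction)
    return data[:split], data[split:]


class LinearModel:
    """Toy model taking an hparams dict, like the template's models (keys are hypothetical)."""

    def __init__(self, hparams: dict):
        self.w = hparams.get("init_w", 0.0)
        self.b = hparams.get("init_b", 0.0)

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


class Trainer:
    """Sketch of the step / train_epoch / val_epoch / train structure using plain SGD."""

    def __init__(self, model: LinearModel, lr: float = 0.05):
        self.model = model
        self.lr = lr

    def step(self, x: float, y: float) -> float:
        # One SGD update on the squared error; returns the loss before the update.
        err = self.model(x) - y
        self.model.w -= self.lr * 2 * err * x
        self.model.b -= self.lr * 2 * err
        return err * err

    def train_epoch(self, train_data) -> float:
        return sum(self.step(x, y) for x, y in train_data) / len(train_data)

    def val_epoch(self, val_data) -> float:
        # Evaluation only: no parameter updates.
        return sum((self.model(x) - y) ** 2 for x, y in val_data) / len(val_data)

    def train(self, train_data, val_data, epochs: int = 20) -> float:
        val_loss = float("inf")
        for _ in range(epochs):
            self.train_epoch(train_data)
            val_loss = self.val_epoch(val_data)
        # Returning the final validation loss lets an Optuna objective use it directly.
        return val_loss
```

For example, `Trainer(LinearModel({})).train(*load_data())` returns the final validation loss as a float. A PyTorch version keeps the same shape: `step` would run the forward pass, `loss.backward()`, and `optimizer.step()`, while `val_epoch` would run under `torch.no_grad()`.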

Key features:

- Configurable experiments using YAML files
- Integration with Weights & Biases for experiment tracking
- Hyperparameter optimization using Optuna
- Support for multiple random seeds
- Flexible model architecture (currently MLP)
- Device selection (CPU/CUDA)
- Learning rate scheduling
- Analysis tools for completed runs and optimizations

The `analyze.py` script uses functions defined in `util.py` to analyze completed runs and optimizations. Key functions include:

- `select_group`: Select a run group for analysis
- `select_seed`: Select a specific seed from a run group
- `select_device`: Choose a device for analysis
- `load_model`: Load a trained model and its configuration
- `load_study`: Load an Optuna study
- `load_best_model`: Load the best model from an optimization study
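
Conceptually, `load_best_model` has to decide which trial is "best" for the study's `direction` before loading that trial's model. With a real Optuna study, `study.best_trial` already does this; the dependency-free helper below only illustrates the rule, and both the `best_trial` name and its trial-record format (dictionaries with `"number"` and `"value"` keys) are assumptions.

```python
from typing import Optional


def best_trial(trials: list[dict], direction: str) -> Optional[dict]:
    """Pick the best completed trial, honoring the study's optimization direction.

    Hypothetical helper: each trial is assumed to be a dict with "number" and "value"
    keys, where "value" is None for pruned or failed trials.
    """
    if direction not in ("minimize", "maximize"):
        raise ValueError(f"direction must be 'minimize' or 'maximize', got {direction!r}")
    # Ignore trials that never reported a value (e.g. pruned or failed).
    completed = [t for t in trials if t.get("value") is not None]
    if not completed:
        return None
    chooser = min if direction == "minimize" else max
    return chooser(completed, key=lambda t: t["value"])
```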

To use the analysis tools:

- Run the `analyze.py` script:
  ```bash
  python analyze.py
  ```
- Follow the prompts to select the project, run group, and seed (if applicable).
- The script loads the selected model and performs basic analysis, such as calculating the validation loss.
- Extend the `main()` function in `analyze.py` to add custom analysis as needed, using the utility functions from `util.py`.
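
As an example of a custom analysis, a `main()` extension could report how much the validation loss varies across the seeds of a run group. The `summarize_seeds` helper below is a hypothetical sketch; how the per-seed losses are collected (for instance by calling `load_model` for each seed and evaluating it) depends on the template's actual `util.py` functions.

```python
import statistics


def summarize_seeds(val_losses: dict[int, float]) -> tuple[float, float]:
    """Return (mean, sample standard deviation) of per-seed validation losses.

    Hypothetical helper: val_losses maps each seed to the validation loss of the
    model trained with that seed.
    """
    losses = list(val_losses.values())
    if not losses:
        raise ValueError("val_losses must contain at least one seed")
    mean = statistics.fmean(losses)
    # statistics.stdev needs at least two data points; report 0.0 for a single seed.
    std = statistics.stdev(losses) if len(losses) > 1 else 0.0
    return mean, std
```

Reporting the spread alongside the mean makes it easier to tell whether a difference between two run groups is larger than the seed-to-seed noise.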
Contributions are welcome! Please feel free to submit a Pull Request.
This project is provided as a template and is intended to be freely used, modified, and distributed. Users of this template are encouraged to choose a license that best suits their specific project needs.
For the template itself:
- You are free to use, modify, and distribute this template.
- No attribution is required, although it is appreciated.
- The template is provided "as is", without warranty of any kind.
When using this template for your own project, please remember to:
- Remove this license section or replace it with your chosen license.
- Ensure all dependencies and libraries used in your project comply with their respective licenses.
For more information on choosing a license, visit choosealicense.com.