[WIP] Allow benchmark between multiple configs #703

Open

wants to merge 1 commit into base: gh/H-Huang/17/base
60 changes: 60 additions & 0 deletions metric_reader.py
@@ -0,0 +1,60 @@
import os
from typing import Any, Dict

from torchtitan.metrics import MetricRetriever

metrics: Dict[str, Dict[int, Dict[str, Any]]] = {}

run_id = 0
dump_dir = f"test_out/my_example"
log_dir = os.path.join(dump_dir, "tb", str(run_id))
print(log_dir)
metric_retriever = MetricRetriever(log_dir)

metrics[f"Run ID {run_id}"] = metric_retriever.get_metrics()


def print_metrics(
metrics: Dict[str, Dict[int, Dict[str, Any]]],
filter_keys=[
"wps",
"mfu(%)",
"memory/max_active(GiB)",
"memory/max_active(%)",
"memory/max_reserved(%)",
"loss_metrics/global_avg_loss",
"loss_metrics/global_max_loss",
],
) -> None:
for run_id, all_step_metrics in metrics.items():
print("=" * 100)
print(run_id)
print("=" * 100)
if all_step_metrics:
last_step = next(reversed(all_step_metrics))
last_step_metrics = all_step_metrics[last_step]
# Print the column headers
if filter_keys:
filtered_keys = [key for key in filter_keys if key in last_step_metrics]
else:
filtered_keys = list(last_step_metrics.keys())

max_key_length = max(len(key) for key in filtered_keys)
# Add an empty header for the run_id column
header_row = " | ".join(
[" " * 10] + [f"{key.ljust(max_key_length)}" for key in filtered_keys]
)
print(header_row)
print("-" * len(header_row))
# Print the run_id and the values
value_row = " | ".join(
[f"{run_id:10}"]
+ [
f"{str(last_step_metrics[key]).ljust(max_key_length)}"
for key in filtered_keys
]
)
print(value_row)


print_metrics(metrics)
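
The script above reads a single run (run_id = 0). As a minimal sketch, assuming the test_out/my_example layout produced by the test_runner.py changes below, the same helpers could compare every run directory under <dump_dir>/tb; the loop below is illustrative only and not part of this PR:

# Sketch only: collect metrics for every run subfolder under <dump_dir>/tb
# and reuse the print_metrics helper defined above.
import os
from typing import Any, Dict

from torchtitan.metrics import MetricRetriever

dump_dir = "test_out/my_example"
tb_root = os.path.join(dump_dir, "tb")

all_metrics: Dict[str, Dict[int, Dict[str, Any]]] = {}
for run_id in sorted(os.listdir(tb_root)):
    retriever = MetricRetriever(os.path.join(tb_root, run_id))
    all_metrics[f"Run ID {run_id}"] = retriever.get_metrics()

print_metrics(all_metrics)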
3 binary files not shown.
89 changes: 86 additions & 3 deletions test_runner.py
@@ -10,7 +10,9 @@
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence
from typing import Any, Dict, Sequence
Contributor comment: RE the binary files above, not sure if you meant to include those in the PR.


from torchtitan.metrics import MetricRetriever

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -362,6 +364,26 @@ def build_test_list():
"enable_cpu_offload+PP",
ngpu=4,
),
OverrideDefinitions(
[
[
"--experimental.pipeline_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--metrics.enable_tensorboard",
],
[
"--training.data_parallel_shard_degree 4",
"--metrics.enable_tensorboard",
],
[
"--training.tensor_parallel_degree 4",
"--metrics.enable_tensorboard",
],
],
"example",
"my_example",
ngpu=4,
),
]
return integration_tests_flavors

@@ -376,10 +398,55 @@ def _run_cmd(cmd):
)


def print_metrics(
metrics: Dict[str, Dict[int, Dict[str, Any]]],
run_id_to_args: Dict[str, Sequence[str]],
filter_keys=[
"wps",
"mfu(%)",
"memory/max_active(GiB)",
"memory/max_active(%)",
"memory/max_reserved(%)",
"loss_metrics/global_avg_loss",
"loss_metrics/global_max_loss",
],
) -> None:
for run_id, args in run_id_to_args.items():
print(f"Run ID: {run_id}, args: {args}")

for run_id, all_step_metrics in metrics.items():
if all_step_metrics:
last_step = next(reversed(all_step_metrics))
last_step_metrics = all_step_metrics[last_step]
# Print the column headers
if filter_keys:
filtered_keys = [key for key in filter_keys if key in last_step_metrics]
else:
filtered_keys = list(last_step_metrics.keys())

max_key_length = max(len(key) for key in filtered_keys)
# Add an empty header for the run_id column
header_row = " | ".join(
[" " * 10] + [f"{key.ljust(max_key_length)}" for key in filtered_keys]
)
print(header_row)
print("-" * len(header_row))
# Print the run_id and the values
value_row = " | ".join(
[f"{run_id:10}"]
+ [
f"{str(last_step_metrics[key]).ljust(max_key_length)}"
for key in filtered_keys
]
)
print(value_row)


def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
# run_test supports sequence of tests.
test_name = test_flavor.test_name
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
dump_dir = f"{output_dir}/{test_name}"
dump_folder_arg = f"--job.dump_folder {dump_dir}"
model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

@@ -391,12 +458,18 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
result = _run_cmd(cmd)
logger.info(result.stdout)

for override_arg in test_flavor.override_args:
# Store all metrics here
metrics: Dict[str, Dict[int, Dict[str, Any]]] = {}
run_id_to_args: Dict[str, Sequence[str]] = {}
for run_id, override_arg in enumerate(test_flavor.override_args):
run_id_arg = f"--metrics.run_id_folder {run_id}"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
if test_name == "fsdp2_mem_tracker":
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
cmd += " " + dump_folder_arg
cmd += " " + model_flavor_arg
cmd += " " + run_id_arg
if override_arg:
cmd += " " + " ".join(override_arg)
logger.info(
@@ -409,6 +482,15 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
)

print("=" * 100)
print(cmd)
log_dir = os.path.join(dump_dir, "tb", str(run_id))
print(log_dir)
metric_retriever = MetricRetriever(log_dir)
metrics[str(run_id)] = metric_retriever.get_metrics()
run_id_to_args[str(run_id)] = override_arg
print_metrics(metrics, run_id_to_args)


def run_tests(args):
integration_tests_flavors = build_test_list()
@@ -447,6 +529,7 @@ def main():
)
parser.add_argument("--ngpu", default=4, type=int)
args = parser.parse_args()
print(args)

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
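For orientation, the new "my_example" flavor above runs three override-arg sets under one dump folder, so each run gets its own run_id and its own TensorBoard subfolder. Assuming the runner is pointed at an output_dir of test_out (the directory metric_reader.py reads), the resulting layout would be roughly the sketch below; the mapping of run IDs to configs follows the enumerate order above.

test_out/my_example/tb/0   # PP=2 + FSDP=2
test_out/my_example/tb/1   # FSDP=4
test_out/my_example/tb/2   # TP=4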
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -147,6 +147,12 @@ def __init__(self):
which is the only stage that computes loss metrics.
""",
)
self.parser.add_argument(
"--metrics.run_id_folder",
type=str,
default="",
help="Subfolder to store TensorBoard runs. This is used to identify metrics between runs.",
)

# model configs
self.parser.add_argument(
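To illustrate the new flag, the test runner appends it to the launch command for each run. A hand-built equivalent for run 0 is sketched below in the same style test_runner.py uses; the config path is a placeholder, and the sketch assumes run_llama_train.sh forwards extra arguments to train.py, as the test runner already relies on.

# Sketch: a launch command along the lines of what test_runner.py builds for run 0.
# <path/to/config.toml> is a placeholder, not a path from this PR.
run_id = 0
cmd = (
    "CONFIG_FILE=<path/to/config.toml> NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
    " --job.dump_folder test_out/my_example"
    " --metrics.enable_tensorboard"
    f" --metrics.run_id_folder {run_id}"
)
print(cmd)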
52 changes: 40 additions & 12 deletions torchtitan/metrics.py
@@ -5,11 +5,13 @@
# LICENSE file in the root directory of this source tree.

import os
from collections import namedtuple
from collections import namedtuple, OrderedDict
from datetime import datetime
from typing import Any, Dict, Optional

import torch

from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter
from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
@@ -93,9 +95,32 @@ def build_gpu_memory_monitor():
return gpu_memory_monitor


class MetricRetriever:

def __init__(self, log_dir: str):
self.log_dir = log_dir

def get_metrics(self) -> Dict[int, Dict[str, Any]]:
# Initialize an EventAccumulator to read the event files
ea = event_accumulator.EventAccumulator(self.log_dir)
ea.Reload() # Load the event files
# Extract scalar data
metrics = {}
for tag in ea.Tags()["scalars"]:
events = ea.Scalars(tag)
for event in events:
step = event.step
if step not in metrics:
metrics[step] = {}
metrics[step][tag] = event.value
return metrics


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
self.tag = tag
self.log_dir = log_dir
print(f"!!!! Tensorboard log dir: {self.log_dir}")
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)
@@ -111,24 +136,23 @@ def close(self):
self.writer.close()


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
def _get_metrics_rank(job_config: JobConfig) -> int:
"""
Returns global rank 0 in non-pipeline-parallel configs, and returns the global
rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
"""
if parallel_dims.pp_enabled:
world_size = parallel_dims.world_size
pp_size = parallel_dims.pp
pp_size = job_config.experimental.pipeline_parallel_degree
pp_enabled = pp_size > 1
if pp_enabled:
world_size = int(os.environ["WORLD_SIZE"])
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
else:
metrics_log_rank = 0

return metrics_log_rank


def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
def build_metric_logger(job_config: JobConfig, tag: Optional[str] = None):
"""
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
@@ -138,17 +162,21 @@ def build_metric_logger(
dump_dir = job_config.job.dump_folder
tb_config = job_config.metrics
save_tb_folder = tb_config.save_tb_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
# if we don't have run id, use current minute as the identifier
run_id_folder = (
datetime.now().strftime("%Y%m%d-%H%M")
if not tb_config.run_id_folder
else tb_config.run_id_folder
)
log_dir = os.path.join(dump_dir, save_tb_folder, run_id_folder)

enable_tb = tb_config.enable_tensorboard
if enable_tb:
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(job_config)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)
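Two small worked examples of the logic changed above, as a sketch rather than part of the diff. The first shows how run_id_folder picks the TensorBoard directory that MetricRetriever later reads back (using "tb" as the folder name, matching the test runner's log_dir construction); the second shows the metrics rank for the PP=2 test added in test_runner.py, assuming WORLD_SIZE=4.

import os

# Log dir selection in build_metric_logger: dump_folder / save_tb_folder / run_id_folder,
# falling back to a timestamp when --metrics.run_id_folder is unset.
dump_folder, save_tb_folder, run_id_folder = "test_out/my_example", "tb", "0"
log_dir = os.path.join(dump_folder, save_tb_folder, run_id_folder)
print(log_dir)  # test_out/my_example/tb/0 (on POSIX), the path MetricRetriever reads

# Metrics rank with pipeline parallelism: rank 0 of the last pipeline stage,
# which is the only stage that computes loss metrics.
world_size, pp_size = 4, 2
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
print(metrics_log_rank)  # 2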
2 changes: 1 addition & 1 deletion train.py
@@ -214,7 +214,7 @@ def loss_fn(pred, labels):
"All the substages will be initialized with random weights with same RNG state which can affect convergence."
)

metric_logger = build_metric_logger(job_config, parallel_dims)
metric_logger = build_metric_logger(job_config)

# plot losses loaded from checkpoint (if any) to TensorBoard
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.