Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes for eval #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,13 +844,6 @@ def evaluate(forward_step_func,
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

# Sum LBLs across pipeline-model-parallel shards.
if args.model_type == ModelType.encoder_or_decoder_with_lbl:
assert "load balancing loss" in total_loss_dict
torch.distributed.all_reduce(
total_loss_dict["load balancing loss"],
group=mpu.get_pipeline_model_parallel_group())

return total_loss_dict, collected_non_loss_data

def evaluate_and_print_results(prefix, forward_step_func,
Expand Down
2 changes: 1 addition & 1 deletion pretrain_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def forward_step(data_iterator, model):
labels=labels)

loss_fn = (
moe_loss_func if args.moe_num_experts is not None else loss_func)
moe_loss_func if args.moe_num_experts is not None and model.training else loss_func)
return output_tensor, partial(loss_fn, loss_mask)

def train_valid_test_datasets_provider(train_val_test_num_samples):
Expand Down