
fix trainer.py
0417keito committed Dec 2, 2023
1 parent: d399bde · commit: bbaa609
Showing 1 changed file with 16 additions and 15 deletions.
trainer.py: 31 changes (16 additions & 15 deletions)
@@ -71,13 +71,13 @@ def eval_all_tasks(self, rank=0):
             avg_loss = loss_sum / num_batches[task] if num_batches[task] > 0 else 0
             self.logger.info(f'Average validation loss for task {task}: {avg_loss}')
             if self.rank == 0:
-                scalars = {f'loss/val_{task}', avg_loss}
+                scalars = {f'loss/val_{task}': avg_loss}
                 summarize(writer=self.writer, global_step=self.global_step, scalars=scalars)
 
         avg_total_loss = total_loss / total_batches if total_batches > 0 else 0
         self.logger.info(f'Average total validation loss: {avg_total_loss}')
         if self.rank == 0:
-            scalars = {f'loss/val_total', avg_total_loss}
+            scalars = {'loss/val_total': avg_total_loss}
             summarize(writer=self.writer, global_step=self.global_step, scalars=scalars)
 
         self.model.train()
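This hunk fixes a subtle literal bug: `{tag, value}` builds a Python `set`, while a `summarize()`-style helper presumably needs a tag-to-value mapping it can iterate over. A minimal sketch of the difference, with made-up tag and loss values:

```python
# {a, b} is a set literal; {a: b} is a dict. Only the dict pairs each
# TensorBoard tag with its scalar value.
task, avg_loss = 'text_guided', 0.123

as_set = {f'loss/val_{task}', avg_loss}    # {'loss/val_text_guided', 0.123}: unordered, unpaired
as_dict = {f'loss/val_{task}': avg_loss}   # {'loss/val_text_guided': 0.123}

for tag, value in as_dict.items():         # e.g. writer.add_scalar(tag, value)
    print(tag, value)
```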
@@ -90,11 +90,11 @@ def eval(self, task, rank=0):
         for batch_idx, (audio_emb, metadata) in enumerate(self.valid_dl):
             b, _, _, device = *audio_emb.shape, self.config.device
             masked_input, mask, causal = self.random_mask(audio_emb, audio_emb.shape[2], task)
-            conditioning = self.conditioner(metadata)
+            conditioning = self.conditioner(metadata, self.config.device)
             conditioning['masked_input'] = masked_input
             conditioning['mask'] = mask
             conditioning = self.get_conditioning(conditioning)
-            num_timesteps = self.diffusion.steps
+            num_timesteps = self.diffusion.num_timesteps
             t = torch.randint(0, num_timesteps, (b,), device=device).long()
             with autocast(enabled=self.config.use_fp16):
                 loss = self.diffusion.training_loosses(self.model, audio_emb, t, conditioning, causal=causal)
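Two fixes here: the conditioner now receives the target device, and the timestep count is read from `num_timesteps` (the attribute the `train` method already uses) instead of `steps`. A small sketch of the timestep sampling, under assumed shapes:

```python
import torch

# Assumed shapes for illustration: audio_emb is (batch, channels, frames)
# and num_timesteps comes from the diffusion noise schedule.
audio_emb = torch.randn(4, 128, 1024)
num_timesteps = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

b = audio_emb.shape[0]
# One random timestep per example, created directly on the target device
# so it lives alongside the model and conditioning tensors (the reason the
# conditioner is now handed self.config.device).
t = torch.randint(0, num_timesteps, (b,), device=device).long()
```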
@@ -120,7 +120,7 @@ def train_loop(self):
             loss_dict = {}
             for task in self.tasks:
                 loss = self.train(task=task, audio_emb=audio_emb, metadata=metadata)
-                weighted_loss += loss.item()
+                weighted_loss += loss
                 loss_dict[task] = loss.item()
 
             self.optimizer.zero_grad()
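The `weighted_loss` change is the functional core of the commit: `.item()` returns a plain Python float detached from the autograd graph, so a total built from `.item()` values cannot drive `backward()`. Keeping the tensor lets the summed multi-task loss backpropagate, while `.item()` remains fine for the per-task numbers that are only logged. A toy demonstration:

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
loss_a = (w * 2.0) ** 2          # toy "task" losses sharing one parameter
loss_b = (w - 3.0) ** 2

total = loss_a + loss_b          # tensor sum keeps the autograd history
total.backward()
print(w.grad)                    # tensor(4.) -- gradients flow

detached = loss_a.item() + loss_b.item()  # plain float: no .backward() possible
```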
@@ -133,18 +133,19 @@ def train_loop(self):
 
             if self.rank == 0:
                 loss_text_guided = loss_dict['text_guided']
-                loss_inpaint = loss_dict['inpaint']
-                loss_cont = loss_dict['cont']
+                loss_inpaint = loss_dict['music_inpaint']
+                loss_cont = loss_dict['music_cont']
                 if self.global_step % self.config.log_interval == 0:
                     lr = self.optimizer.param_groups[0]['lr']
                     self.logger.info('Train Epoch: {}, [{:.0f}%]'.format(
                         epoch, 100. * batch_idx / len(self.train_dl)
                     ))
-                    self.logger.info(f'loss: {weighted_loss}, \
-                        loss_text_guided: {loss_text_guided} \
-                        loss_inpaint: {loss_inpaint} \
-                        loss_cont: {loss_cont} \
-                        global_step: {self.global_step}, lr:{lr}')
+                    self.logger.info(
+                        f'loss: {weighted_loss} '
+                        f'loss_text_guided: {loss_text_guided} '
+                        f'loss_inpaint: {loss_inpaint} '
+                        f'loss_cont: {loss_cont} '
+                        f'global_step: {self.global_step}, lr:{lr}')
                     scalars = {'loss/train': weighted_loss,
                                'loss_text_guided/train': loss_text_guided,
                                'loss_inpaint/train': loss_inpaint,
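Two fixes in this hunk: the log lookups now use the same task names that key `loss_dict` (which is filled from `self.tasks`, so `'inpaint'` would raise a `KeyError` when the task is registered as `'music_inpaint'`), and the log message is rebuilt from adjacent f-strings instead of backslash continuations, which embed each continuation line's leading whitespace in the message. A sketch of both points, assuming the task names used in this repo:

```python
tasks = ['text_guided', 'music_inpaint', 'music_cont']   # assumed task list
loss_dict = {task: 0.0 for task in tasks}

loss_dict['music_inpaint']       # OK
# loss_dict['inpaint']           # KeyError: the lookup this hunk fixes

# Adjacent (f-)string literals concatenate at compile time into one clean
# line; a backslash continuation inside a literal keeps the indentation
# as literal spaces.
msg = (f'loss: {0.5} '
       f'global_step: {100}')
print(msg)                       # loss: 0.5 global_step: 100
```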
@@ -155,7 +156,8 @@ def train_loop(self):
                     self.eval_all_tasks(rank=self.rank)
                     save_checkpoint(model=self.model, optimizer=self.optimizer,
                                     lr=self.config.optimizer_config.lr, iteration=epoch,
-                                    checkpoint_path=os.path.join(self.config.save_dir, f'Jen1_step_{self.global_step}.pth'))
+                                    checkpoint_path=os.path.join(self.config.save_dir, f'Jen1_step_{self.global_step}.pth'),
+                                    logger=self.logger)
 
             self.global_step += 1
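The checkpoint call now threads the trainer's logger into `save_checkpoint`. The helper's body is not shown in this diff; below is a hypothetical sketch consistent with the call site (the payload and behavior are assumptions, not the repo's actual implementation):

```python
import os
import torch

def save_checkpoint(model, optimizer, lr, iteration, checkpoint_path, logger=None):
    # Hypothetical: persist model and optimizer state so training can resume.
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr': lr,
                'iteration': iteration},
               checkpoint_path)
    if logger is not None:
        logger.info(f'Saved checkpoint to {checkpoint_path}')
```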

@@ -170,7 +172,6 @@ def train(self, task, audio_emb, metadata):
         num_timesteps = self.diffusion.num_timesteps
         t = torch.randint(0, num_timesteps, (b,), device=device).long()
         with autocast(enabled=self.config.use_fp16):
-            print('audio_emb.shape:', audio_emb.shape)
             loss = self.diffusion.training_loosses(self.model, audio_emb, t, conditioning, causal=causal)
         return loss

@@ -196,7 +197,7 @@ def random_mask(self, sequence, max_mask_length, task):
         elif task.lower() == 'music_cont':
             mask_length = random.randint(sequence_length*0.2, sequence_length*0.8)
 
-            mask = torch.onse((1, 1, sequence_length))
+            mask = torch.ones((1, 1, sequence_length))
             mask[:, :, -mask_length:] = 0
             masks.append(mask)
             causal = True
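Beyond the `torch.onse` to `torch.ones` typo fix, this branch builds the continuation mask: the whole sequence starts visible and a random 20-80% tail is zeroed out for the model to generate. A standalone sketch (note that `random.randint` requires integer bounds, so the float bounds in the unchanged context lines above would raise a `ValueError`; the sketch casts them):

```python
import random
import torch

sequence_length = 1024
# Hide a random 20-80% tail; int() casts are needed because
# random.randint rejects float bounds.
mask_length = random.randint(int(sequence_length * 0.2),
                             int(sequence_length * 0.8))

mask = torch.ones((1, 1, sequence_length))   # 1 = keep, 0 = masked
mask[:, :, -mask_length:] = 0                # the model must continue this tail
```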
