From 7a945458b687173fda5c9c539b97722b3c423b65 Mon Sep 17 00:00:00 2001 From: Noah Bryan Date: Tue, 26 Nov 2024 02:15:52 -0500 Subject: [PATCH] hparam tweaks for final output --- magenta/models/coconet/coconet_train.py | 2 +- magenta/models/coconet/lib_graph.py | 4 ++-- magenta/models/coconet/sample_bazel.sh | 2 +- magenta/models/coconet/train_bazel.sh | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/magenta/models/coconet/coconet_train.py b/magenta/models/coconet/coconet_train.py index 2f376fc2..09e174b3 100644 --- a/magenta/models/coconet/coconet_train.py +++ b/magenta/models/coconet/coconet_train.py @@ -113,7 +113,7 @@ flags.DEFINE_float('corrupt_ratio', 0.5, 'Fraction of variables to mask out.') # Run parameters. -flags.DEFINE_integer('num_epochs', 50, +flags.DEFINE_integer('num_epochs', 100, 'The number of epochs to train the model. Default ' 'is 0, which means to run until terminated ' 'manually.') diff --git a/magenta/models/coconet/lib_graph.py b/magenta/models/coconet/lib_graph.py index 35d48fdb..6f279fe6 100644 --- a/magenta/models/coconet/lib_graph.py +++ b/magenta/models/coconet/lib_graph.py @@ -112,8 +112,8 @@ def build(self): logits=self.logits, labels=self.pianorolls) self.compute_loss( - self.cpep_calculator.calculate_voice_range_penalty(self.predictions) + - self.cpep_calculator.calculate_kernel_penalty(self.predictions) + + self.cpep_calculator.calculate_voice_range_penalty(self.predictions) + + self.cpep_calculator.calculate_kernel_penalty(self.predictions) + self.cpep_calculator.calculate_parallel_perfect_penalty(self.predictions) + self.cross_entropy) self.setup_optimizer() diff --git a/magenta/models/coconet/sample_bazel.sh b/magenta/models/coconet/sample_bazel.sh index ec72b664..180a01e3 100755 --- a/magenta/models/coconet/sample_bazel.sh +++ b/magenta/models/coconet/sample_bazel.sh @@ -29,7 +29,7 @@ generation_output_dir=$HOME/samples # Generation parameters. # Number of samples to generate in a batch. -gen_batch_size=10 +gen_batch_size=20 piece_length=64 strategy=igibbs tfsample=true diff --git a/magenta/models/coconet/train_bazel.sh b/magenta/models/coconet/train_bazel.sh index af679eda..3c11063f 100755 --- a/magenta/models/coconet/train_bazel.sh +++ b/magenta/models/coconet/train_bazel.sh @@ -25,7 +25,7 @@ data_dir=$HOME/JSB-Chorales-dataset/ dataset=Jsb16thSeparated # Data preprocessing. -crop_piece_len=48 +crop_piece_len=64 separate_instruments=True quantization_level=0.125 # 16th notes