From 2c4e5fc1f4a502a7e79c1b2cac46b26d51e8ef7b Mon Sep 17 00:00:00 2001 From: LVeefkind Date: Wed, 25 Sep 2024 13:13:44 +0200 Subject: [PATCH] add option for stochastic smoothing --- neural_networks/train_nn.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/neural_networks/train_nn.py b/neural_networks/train_nn.py index 37f26f73..3c9a8e73 100644 --- a/neural_networks/train_nn.py +++ b/neural_networks/train_nn.py @@ -350,6 +350,7 @@ def main( batch_size: int, use_compile: bool, label_smoothing: float, + stochastic_smoothing: bool, lift: str, ): torch.set_float32_matmul_precision("high") @@ -424,6 +425,7 @@ def main( optimizer=optimizer, logging_interval=logging_interval, label_smoothing=label_smoothing, + stochastic_smoothing=stochastic_smoothing, ) val_step_f = partial(val_step_f, val_dataloader=val_dataloader) @@ -531,6 +533,7 @@ def train_step( logging_interval, metrics_logger, label_smoothing=0, + stochastic_smoothing=False, ): # print("training") model.train() @@ -541,8 +544,11 @@ def train_step( global_step += 1 data, labels = prepare_data_f(data, labels) - smoothed_label = (1 - label_smoothing) * labels + 0.5 * label_smoothing - + # Stochastic smoothing factor + smoothing_factor = label_smoothing - ( + torch.rand_like(labels) * label_smoothing * stochastic_smoothing + ) + smoothed_label = (1 - smoothing_factor) * labels + 0.5 * smoothing_factor data = augmentation(data) optimizer.zero_grad(set_to_none=True) @@ -729,6 +735,12 @@ def get_argparser(): help="Label smoothing factor", ) + parser.add_argument( + "--stochastic_smoothing", + action="store_true", + help="use stochastic smoothing", + ) + parser.add_argument( "--lift", type=str, @@ -755,7 +767,7 @@ def sanity_check_args(parsed_args): print("Setting resize to 512 since vit_16_l is being used") parsed_args.resize = 512 if parsed_args.model_name == "dino_v2" and parsed_args.resize == 0: - resize = 784 + resize = 504 print(f"\n#######\nSetting resize to {resize} \n######\n") parsed_args.resize = resize assert parsed_args.resize % 14 == 0 or parsed_args.model_name != "dino_v2"