Skip to content

Commit

Permalink
add option for stochastic smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
LVeefkind committed Sep 25, 2024
1 parent dbd36c3 commit 2c4e5fc
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions neural_networks/train_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def main(
batch_size: int,
use_compile: bool,
label_smoothing: float,
stochastic_smoothing: bool,
lift: str,
):
torch.set_float32_matmul_precision("high")
Expand Down Expand Up @@ -424,6 +425,7 @@ def main(
optimizer=optimizer,
logging_interval=logging_interval,
label_smoothing=label_smoothing,
stochastic_smoothing=stochastic_smoothing,
)
val_step_f = partial(val_step_f, val_dataloader=val_dataloader)

Expand Down Expand Up @@ -531,6 +533,7 @@ def train_step(
logging_interval,
metrics_logger,
label_smoothing=0,
stochastic_smoothing=False,
):
# print("training")
model.train()
Expand All @@ -541,8 +544,11 @@ def train_step(
global_step += 1

data, labels = prepare_data_f(data, labels)
smoothed_label = (1 - label_smoothing) * labels + 0.5 * label_smoothing

# Stochastic smoothing factor
smoothing_factor = label_smoothing - (
torch.rand_like(labels) * label_smoothing * stochastic_smoothing
)
smoothed_label = (1 - smoothing_factor) * labels + 0.5 * smoothing_factor
data = augmentation(data)

optimizer.zero_grad(set_to_none=True)
Expand Down Expand Up @@ -729,6 +735,12 @@ def get_argparser():
help="Label smoothing factor",
)

parser.add_argument(
"--stochastic_smoothing",
action="store_true",
help="use stochastic smoothing",
)

parser.add_argument(
"--lift",
type=str,
Expand All @@ -755,7 +767,7 @@ def sanity_check_args(parsed_args):
print("Setting resize to 512 since vit_16_l is being used")
parsed_args.resize = 512
if parsed_args.model_name == "dino_v2" and parsed_args.resize == 0:
resize = 784
resize = 504
print(f"\n#######\nSetting resize to {resize} \n######\n")
parsed_args.resize = resize
assert parsed_args.resize % 14 == 0 or parsed_args.model_name != "dino_v2"
Expand Down

0 comments on commit 2c4e5fc

Please sign in to comment.