Dev -> main #812

Merged 26 commits into main from dev on Nov 22, 2024
Commits
b24812f
BN Fixes
adefazio Sep 5, 2024
e09bbf5
add `train_state` to all instances of `update_params', passing it by …
Niccolo-Ajroldi Oct 3, 2024
f15e227
update DOCS
Niccolo-Ajroldi Oct 3, 2024
107c6b6
update test
Niccolo-Ajroldi Oct 3, 2024
d4ad0eb
fix linting
Niccolo-Ajroldi Oct 3, 2024
cb7e162
fix isort
Niccolo-Ajroldi Oct 3, 2024
1f59285
fix import sort
Niccolo-Ajroldi Oct 3, 2024
f574bf0
add use_running_average_bn arg for jax
priyakasimbeg Oct 17, 2024
7ca8365
formatting
priyakasimbeg Oct 17, 2024
3913238
formatting
priyakasimbeg Oct 17, 2024
baac0a4
formatting
priyakasimbeg Oct 17, 2024
087fd5c
debugging
priyakasimbeg Oct 17, 2024
c5c36c2
add seperate model_fn for deepspeech jax without use_running_average_bn
priyakasimbeg Oct 18, 2024
783aab4
fix syntax error
priyakasimbeg Oct 18, 2024
28e7e21
fix
priyakasimbeg Oct 18, 2024
b063f9f
fix import order
priyakasimbeg Oct 18, 2024
894cd87
formatting
priyakasimbeg Oct 18, 2024
ce8eb18
ensure backward compatibility
Niccolo-Ajroldi Oct 25, 2024
5a06a0d
adding train_state to all submissions
Niccolo-Ajroldi Oct 25, 2024
86114ef
fix missing import Optional
Niccolo-Ajroldi Oct 25, 2024
1965241
fix yapf
Niccolo-Ajroldi Oct 25, 2024
7a3710f
Merge pull request #790 from Niccolo-Ajroldi/pass_train_state
priyakasimbeg Oct 29, 2024
787f7fb
Merge pull request #798 from mlcommons/bn_fixes_clean
priyakasimbeg Oct 29, 2024
f4c17f0
Update README.md
priyakasimbeg Nov 21, 2024
39e2e1c
Merge pull request #814 from mlcommons/priyakasimbeg-patch-4
priyakasimbeg Nov 21, 2024
5f6a2ff
Merge pull request #818 from mlcommons/main
priyakasimbeg Nov 21, 2024
2 changes: 2 additions & 0 deletions DOCUMENTATION.md
@@ -199,6 +199,7 @@ def update_params(
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState
@@ -212,6 +213,7 @@ def update_params(
- The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used.
- Allowed to update state for the optimizer.
- Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
- The submission can access the elapsed training time and get further information about the evaluation through `train_state`.
- The submission can access the target evaluation metric via the `workload` variable.
- **A call to this function will be considered a step**
- The time between a call to this function and the next call to this function will be considered the per-step time.
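To make the new hook concrete, here is a minimal submission-side sketch that scales a hyperparameter by the elapsed submission time read from `train_state`. The key `accumulated_submission_time` and the guard for a missing `train_state` are assumptions for this example, not guarantees of the spec excerpt above.

```python
from typing import Any, Dict, Optional


def pick_lr_scale(train_state: Optional[Dict[str, Any]],
                  max_runtime_sec: float) -> float:
  """Sketch: decay a learning-rate multiplier over the time budget.

  The key read below is an assumption for illustration; consult the harness
  for the exact contents of `train_state`.
  """
  if train_state is None:  # harness predating the `train_state` argument
    return 1.0
  elapsed = float(train_state.get('accumulated_submission_time', 0.0))
  frac = min(elapsed / max(max_runtime_sec, 1e-9), 1.0)
  return 1.0 - 0.9 * frac  # linear decay from 1.0 to 0.1
```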
2 changes: 1 addition & 1 deletion README.md
@@ -89,7 +89,7 @@ python3 submission_runner.py \
--workload=mnist \
--experiment_dir=$HOME/experiments \
--experiment_name=my_first_experiment \
--submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
--submission_path=reference_algorithms/paper_baselines/adamw/pytorch/submission.py \
--tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
```

11 changes: 7 additions & 4 deletions algorithmic_efficiency/pytorch_utils.py
@@ -67,10 +67,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
)
if isinstance(module, bn_layers):
if not update_batch_norm:
module.eval()
module.momentum_backup = module.momentum
if not hasattr(module, 'momentum_backup'):
module.momentum_backup = module.momentum

# module.momentum can be float or torch.Tensor.
module.momentum = 0. * module.momentum_backup
if torch.is_tensor(module.momentum_backup):
module.momentum = torch.zeros_like(module.momentum_backup)
else:
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup
module.track_running_stats = update_batch_norm
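The reworked momentum freeze is idempotent and type-preserving: the backup is taken only on the first call, so repeated calls cannot overwrite it with an already-zeroed value, and `torch.zeros_like` keeps a tensor-valued momentum a tensor. A standalone sketch of the same pattern on a plain `torch.nn.BatchNorm2d` (the helper names here are hypothetical):

```python
import torch
from torch import nn


def freeze_bn_momentum(module: nn.Module) -> None:
  """Zero BN momentum so running stats stop changing, backing it up once."""
  if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
    if not hasattr(module, 'momentum_backup'):
      module.momentum_backup = module.momentum
    if torch.is_tensor(module.momentum_backup):
      module.momentum = torch.zeros_like(module.momentum_backup)
    else:
      module.momentum = 0.0


def restore_bn_momentum(module: nn.Module) -> None:
  """Undo freeze_bn_momentum."""
  if hasattr(module, 'momentum_backup'):
    module.momentum = module.momentum_backup


bn = nn.BatchNorm2d(8)
bn.apply(freeze_bn_momentum)
bn.apply(freeze_bn_momentum)   # second call leaves the backup intact
bn.apply(restore_bn_momentum)
assert bn.momentum == 0.1      # PyTorch's default momentum is restored
```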
6 changes: 4 additions & 2 deletions algorithmic_efficiency/spec.py
@@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload,
OptimizerState,
List[Tuple[int, float]],
int,
RandomState
RandomState,
Optional[Dict[str, Any]]
],
UpdateReturn]

@@ -424,7 +425,8 @@ def update_params(workload: Workload,
optimizer_state: OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
rng: RandomState,
train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass

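Because `train_state` is appended as a trailing argument with a `None` default, a caller can keep supporting submissions written against the old signature. One hypothetical way to do that (an illustration only, not necessarily how `submission_runner.py` implements it):

```python
import inspect
from typing import Any, Callable, Dict, Optional


def call_update_params(update_params_fn: Callable[..., Any],
                       train_state: Optional[Dict[str, Any]],
                       **kwargs: Any) -> Any:
  """Pass `train_state` only if the submission's update_params accepts it."""
  accepts_train_state = (
      'train_state' in inspect.signature(update_params_fn).parameters)
  if accepts_train_state:
    return update_params_fn(train_state=train_state, **kwargs)
  return update_params_fn(**kwargs)
```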
9 changes: 7 additions & 2 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
@@ -28,11 +28,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: bool = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
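The new `use_running_average_bn` argument decouples which statistics BatchNorm normalizes with from `update_batch_norm`, while the `None` default preserves the previous coupling for old callers. A self-contained sketch of the same pattern in a toy Flax module (hypothetical `TinyBNBlock`, not the benchmark ResNet):

```python
import functools
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyBNBlock(nn.Module):
  """Dense -> BatchNorm -> ReLU, with the backward-compatible BN default."""

  @nn.compact
  def __call__(self,
               x: jnp.ndarray,
               update_batch_norm: bool = True,
               use_running_average_bn: Optional[bool] = None) -> jnp.ndarray:
    # Old callers pass only update_batch_norm; the normalization mode then
    # falls back to the previous, coupled behavior.
    if use_running_average_bn is None:
      use_running_average_bn = not update_batch_norm
    norm = functools.partial(
        nn.BatchNorm,
        use_running_average=use_running_average_bn,
        momentum=0.9,
        epsilon=1e-5)
    return nn.relu(norm()(nn.Dense(16)(x)))


x = jnp.ones((4, 8))
model = TinyBNBlock()
variables = model.init(jax.random.PRNGKey(0), x)
# Normalize with the current batch's statistics; keeping the old `variables`
# and discarding the returned batch_stats leaves the running averages as-is.
y, _ = model.apply(
    variables,
    x,
    update_batch_norm=False,
    use_running_average_bn=False,
    mutable=['batch_stats'])
```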
10 changes: 7 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -110,7 +110,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -119,14 +121,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
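On the workload side, the flag is simply threaded through to `Model.apply`. For workloads whose `model_fn` accepts it, a submission could request, for example, a forward pass that normalizes with the stored running averages without touching them. A hypothetical helper (the `spec` aliases are the ones used throughout the repo; the flag combination shown is only one possible use):

```python
from typing import Dict, Tuple

from algorithmic_efficiency import spec


def forward_with_running_average_bn(
    workload: spec.Workload,
    params: spec.ParameterContainer,
    model_state: spec.ModelAuxiliaryState,
    batch: Dict[str, spec.Tensor],
    rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
  """Hypothetical submission-side helper, not part of the workload API.

  Requests normalization with the stored running averages
  (use_running_average_bn=True) while leaving batch-norm state untouched
  (update_batch_norm=False).
  """
  return workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.TRAIN,
      rng,
      update_batch_norm=False,
      use_running_average_bn=True)
```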
@@ -84,11 +84,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
@@ -148,7 +148,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -157,14 +159,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
@@ -454,15 +454,24 @@ def setup(self):
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
update_batch_norm,
use_running_average_bn):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))

padding = jnp.expand_dims(input_paddings, -1)
momentum = self.config.batch_norm_momentum
epsilon = self.config.batch_norm_epsilon

if train:
if use_running_average_bn:
mean = self.ra_mean.value
var = self.ra_var.value

else:
# compute batch statistics
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
@@ -478,16 +487,13 @@ def __call__(self, inputs, input_paddings, train):

var = sum_vv / count_v

self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var
else:
mean = self.ra_mean.value
var = self.ra_var.value
if update_batch_norm:
self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var

inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)

bn_output = (inputs - mean) * inv + self.beta
bn_output *= 1.0 - padding

@@ -517,7 +523,12 @@ class ConvolutionBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average_bn):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

@@ -546,7 +557,10 @@ def __call__(self, inputs, input_paddings, train):
kernel_init=nn.initializers.xavier_uniform())(
inputs)

inputs = BatchNorm(config)(inputs, input_paddings, train)
inputs = BatchNorm(config)(inputs,
input_paddings,
update_batch_norm,
use_running_average_bn)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
@@ -586,7 +600,12 @@ class ConformerBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)

@@ -597,7 +616,12 @@ def __call__(self, inputs, input_paddings, train):
inputs, input_paddings, train)

inputs = inputs + \
ConvolutionBlock(config)(inputs, input_paddings, train)
ConvolutionBlock(config)(inputs,
input_paddings,
train,
update_batch_norm,
use_running_average
)

inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
inputs, padding_mask, train)
@@ -629,12 +653,23 @@ def setup(self):
.use_dynamic_time_mask_max_frames)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None):
config = self.config

outputs = inputs
output_paddings = input_paddings

# Set BN args if not supplied for backwards compatibility
if update_batch_norm is None:
update_batch_norm = train
if use_running_average_bn is None:
use_running_average_bn = not train

# Compute normalized log mel spectrograms from input audio signal.
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
@@ -660,7 +695,11 @@ def __call__(self, inputs, input_paddings, train):

# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
outputs = ConformerBlock(config)(outputs, output_paddings, train)
outputs = ConformerBlock(config)(outputs,
output_paddings,
train,
update_batch_norm,
use_running_average_bn)

outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
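The Conformer's custom BatchNorm now takes `update_batch_norm` and `use_running_average_bn` in place of the single `train` flag: the former controls whether the running averages are updated, the latter controls which statistics normalize the activations. Below is a simplified, framework-free sketch of that logic on plain JAX arrays (hypothetical `masked_batch_norm`; the real module stores `ra_mean`/`ra_var` as Flax variables):

```python
from typing import Tuple

import jax.numpy as jnp


def masked_batch_norm(
    inputs: jnp.ndarray,          # [batch, time, features]
    input_paddings: jnp.ndarray,  # [batch, time], 1.0 marks padded frames
    ra_mean: jnp.ndarray,
    ra_var: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    momentum: float,
    epsilon: float,
    update_batch_norm: bool,
    use_running_average_bn: bool
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Returns (outputs, new_ra_mean, new_ra_var)."""
  padding = jnp.expand_dims(input_paddings, -1)
  mask = 1.0 - padding

  if use_running_average_bn:
    mean, var = ra_mean, ra_var
  else:
    # Batch statistics over non-padded positions only.
    count = jnp.maximum(jnp.sum(mask), 1.0)
    mean = jnp.sum(inputs * mask, axis=(0, 1)) / count
    var = jnp.sum(jnp.square(inputs - mean) * mask, axis=(0, 1)) / count
    if update_batch_norm:
      ra_mean = momentum * ra_mean + (1 - momentum) * mean
      ra_var = momentum * ra_var + (1 - momentum) * var

  inv = (1 + gamma) / jnp.sqrt(var + epsilon)
  outputs = ((inputs - mean) * inv + beta) * mask
  return outputs, ra_mean, ra_var
```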
@@ -107,7 +107,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
@@ -118,15 +120,17 @@ def model_fn(
input_paddings,
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
variables,
inputs,
input_paddings,
train=False,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), model_state

def _build_input_queue(
@@ -40,7 +40,7 @@ class ConformerConfig:
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
input_dropout_rate: float = 0.1
batch_norm_momentum: float = 0.999
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
attention_temperature: float = 1.0
@@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings):
mean = (masked_inp).sum(dim=(0, 1)) / count
var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count

self.running_mean = self.momentum * self.running_mean + (
1 - self.momentum) * mean.detach()
self.running_var = self.momentum * self.running_var + (
1 - self.momentum) * var.detach()
self.running_mean = (1 - self.momentum) * self.running_mean + (
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()

else:
mean = self.running_mean
var = self.running_var
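The last two hunks reparameterize the running-statistics update in the PyTorch Conformer: `batch_norm_momentum` becomes `1 - 0.999` and the EMA weights are swapped, so the smoothing is numerically unchanged while `momentum` now follows the `torch.nn.BatchNorm` convention of weighting the new batch statistic. A quick check of the equivalence (illustrative only):

```python
import torch

batch_mean = torch.tensor(1.0)

# Old convention: `momentum` weights the previous running value.
m_old = 0.999
running_old = m_old * torch.tensor(0.0) + (1 - m_old) * batch_mean

# New convention: `momentum` weights the new batch statistic.
m_new = 1 - 0.999
running_new = (1 - m_new) * torch.tensor(0.0) + m_new * batch_mean

assert torch.allclose(running_old, running_new)
```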