
Commit 7ca8365: formatting
priyakasimbeg committed Oct 17, 2024
1 parent f574bf0 commit 7ca8365
Showing 8 changed files with 49 additions and 24 deletions.
algorithmic_efficiency/pytorch_utils.py (3 changes: 2 additions & 1 deletion)
@@ -57,6 +57,7 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX)
return time_tensor.item()


def update_batch_norm_fn(module: spec.ParameterContainer,
update_batch_norm: bool) -> None:
bn_layers = (
@@ -75,4 +76,4 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
else:
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup
module.momentum = module.momentum_backup
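
For context on the hunk above: update_batch_norm_fn freezes or unfreezes a BatchNorm layer's running statistics by swapping out its momentum. Below is a minimal self-contained sketch of that pattern; the bn_layers tuple is truncated in the diff, so standard torch.nn BatchNorm classes are assumed, and the helper name update_bn_momentum is illustrative.

import torch
from torch import nn


def update_bn_momentum(module: nn.Module, update_batch_norm: bool) -> None:
  # Sketch of the backup/restore pattern from update_batch_norm_fn above.
  bn_layers = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)  # assumed types
  if isinstance(module, bn_layers):
    if not update_batch_norm:
      if not hasattr(module, 'momentum_backup'):
        module.momentum_backup = module.momentum
      # With momentum 0.0 the running mean/var are left untouched by forward().
      module.momentum = 0.0
    elif hasattr(module, 'momentum_backup'):
      module.momentum = module.momentum_backup


# Example: freeze BN statistics, run a forward pass, then restore.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.apply(lambda m: update_bn_momentum(m, update_batch_norm=False))
_ = model(torch.randn(2, 3, 16, 16))
model.apply(lambda m: update_bn_momentum(m, update_batch_norm=True))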
algorithmic_efficiency/workloads/cifar/cifar_jax/models.py (6 changes: 3 additions & 3 deletions)
@@ -31,10 +31,10 @@ def __call__(self,
update_batch_norm: bool = True,
use_running_average_bn: bool = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
# Preserve default behavior for backwards compatibility

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=use_running_average_bn,
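
The models.py hunk above only moves the backwards-compatibility comment and re-indents the assignment; the logic is unchanged. For context, here is a small self-contained Flax sketch, using an illustrative TinyBNModel rather than the repo's ResNet, of how that default couples the two flags before they reach nn.BatchNorm.

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyBNModel(nn.Module):
  """Illustrative stand-in for the workload model."""

  @nn.compact
  def __call__(self, x, update_batch_norm=True, use_running_average_bn=None):
    # Preserve default behavior for backwards compatibility: callers that only
    # pass update_batch_norm still get the old coupling of the two flags.
    if use_running_average_bn is None:
      use_running_average_bn = not update_batch_norm
    norm = functools.partial(
        nn.BatchNorm,
        use_running_average=use_running_average_bn,
        momentum=0.9,
        epsilon=1e-5)
    x = nn.Dense(4)(x)
    return norm()(x)


model = TinyBNModel()
x = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), x)
# Training-style call: batch statistics are computed and updated.
y, updated_state = model.apply(
    variables, x, update_batch_norm=True, mutable=['batch_stats'])
# Eval-style call: running averages are read, nothing is mutated.
y_eval = model.apply(variables, x, update_batch_norm=False)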
algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py (3 changes: 2 additions & 1 deletion)
@@ -111,7 +111,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
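
The workload.py hunk above simply wraps the new model_fn signature onto two lines. As orientation, the sketch below shows the usual Flax apply pattern behind such a signature: when update_batch_norm is True the batch_stats collection is marked mutable and the updated state is returned. The module and argument names here are illustrative, not the repo's exact code.

from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyBNModel(nn.Module):
  """Illustrative stand-in for the workload's Flax model."""

  @nn.compact
  def __call__(self, x, update_batch_norm=True, use_running_average_bn=None):
    if use_running_average_bn is None:
      use_running_average_bn = not update_batch_norm
    x = nn.Dense(8)(x)
    return nn.BatchNorm(use_running_average=use_running_average_bn)(x)


def model_fn(model,
             params,
             batch,
             model_state,
             update_batch_norm: bool,
             use_running_average_bn: Optional[bool] = None):
  """Hypothetical model_fn showing the apply/mutable pattern."""
  variables = {'params': params, **model_state}
  if update_batch_norm:
    # Batch statistics will change: expose them as a mutable collection and
    # hand the updated state back to the caller.
    logits, new_model_state = model.apply(
        variables,
        batch['inputs'],
        update_batch_norm=update_batch_norm,
        use_running_average_bn=use_running_average_bn,
        mutable=['batch_stats'])
    return logits, new_model_state
  logits = model.apply(
      variables,
      batch['inputs'],
      update_batch_norm=update_batch_norm,
      use_running_average_bn=use_running_average_bn,
      mutable=False)
  return logits, model_state


model = TinyBNModel()
x = jnp.ones((2, 3))
init_vars = model.init(jax.random.PRNGKey(0), x)
params = init_vars['params']
model_state = {'batch_stats': init_vars['batch_stats']}
logits, model_state = model_fn(
    model, params, {'inputs': x}, model_state, update_batch_norm=True)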
(next changed file; file name not shown)
@@ -88,9 +88,9 @@ def __call__(self,
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=use_running_average_bn,
(next changed file; file name not shown)
@@ -149,7 +149,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
(next changed file; file name not shown)
@@ -454,15 +454,19 @@ def setup(self):
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

@nn.compact
def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn):
def __call__(self,
inputs,
input_paddings,
update_batch_norm,
use_running_average_bn):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))

padding = jnp.expand_dims(input_paddings, -1)
momentum = self.config.batch_norm_momentum
epsilon = self.config.batch_norm_epsilon

if use_running_average_bn:
if use_running_average_bn:
mean = self.ra_mean.value
var = self.ra_var.value

@@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag
keepdims=True)

var = sum_vv / count_v

if update_batch_norm:
self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var

inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
bn_output = (inputs - mean) * inv + self.beta
bn_output *= 1.0 - padding
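
The BatchNorm hunk above is mostly line re-wrapping; the underlying computation is a padding-aware batch norm whose fragments (reduce_over_dims, sum_vv / count_v, the momentum update) appear in the diff. The sketch below reconstructs that computation in a self-contained form, assuming input_paddings marks padded frames with 1.0, as implied by bn_output *= 1.0 - padding.

import jax.numpy as jnp


def masked_batch_stats(inputs, input_paddings):
  """Per-feature mean and variance over non-padded positions (sketch)."""
  rank = inputs.ndim
  reduce_over_dims = list(range(0, rank - 1))
  padding = jnp.expand_dims(input_paddings, -1)  # 1.0 marks padded frames
  mask = 1.0 - padding
  count_v = jnp.sum(jnp.ones_like(inputs) * mask,
                    axis=reduce_over_dims,
                    keepdims=True)
  sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
  mean = sum_v / count_v
  sum_vv = jnp.sum((inputs - mean) * (inputs - mean) * mask,
                   axis=reduce_over_dims,
                   keepdims=True)
  var = sum_vv / count_v
  return mean, var


def ema(running, batch_value, momentum):
  """Running-average update applied only when update_batch_norm is True."""
  return momentum * running + (1.0 - momentum) * batch_value


# Toy example: 2 sequences, 3 frames, 4 features; the last frame of the first
# sequence is padding and is excluded from the statistics.
x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)
paddings = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
mean, var = masked_batch_stats(x, paddings)
ra_mean = ema(jnp.zeros_like(mean), mean, momentum=0.999)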
@@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average_bn):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

@@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running
kernel_init=nn.initializers.xavier_uniform())(
inputs)

inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn)
inputs = BatchNorm(config)(inputs,
input_paddings,
update_batch_norm,
use_running_average_bn)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
@@ -588,7 +600,12 @@ class ConformerBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)

@@ -631,12 +648,12 @@ def setup(self):
.use_dynamic_time_mask_max_frames)

@nn.compact
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None):
config = self.config

outputs = inputs
@@ -673,7 +690,11 @@ def __call__(self,

# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn)
outputs = ConformerBlock(config)(outputs,
output_paddings,
train,
update_batch_norm,
use_running_average_bn)

outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
(next changed file; file name not shown)
@@ -108,7 +108,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
(next changed file; file name not shown)
@@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings):
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()

else:
mean = self.running_mean
var = self.running_var
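
The PyTorch hunk above only drops a stray blank line, but the surrounding update is worth a note: the running statistics are refreshed with detached batch statistics so the exponential moving average stays out of the autograd graph. A short illustrative sketch of that update follows.

import torch


def update_running_stats(running_mean, running_var, mean, var, momentum):
  # Detach the batch statistics so the EMA update does not accumulate a
  # gradient history (mirrors the lines in the hunk above).
  new_mean = (1 - momentum) * running_mean + momentum * mean.detach()
  new_var = (1 - momentum) * running_var + momentum * var.detach()
  return new_mean, new_var


# Toy example.
x = torch.randn(8, 4, requires_grad=True)
batch_mean, batch_var = x.mean(dim=0), x.var(dim=0, unbiased=False)
running_mean, running_var = torch.zeros(4), torch.ones(4)
running_mean, running_var = update_running_stats(
    running_mean, running_var, batch_mean, batch_var, momentum=0.1)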
