Update Deprecated Functions #827

Open · wants to merge 2 commits into base: python_upgrades
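This PR replaces functions that recent JAX and Flax releases deprecate: `jax.tree_map` becomes `jax.tree.map`, `flax.linen.SelfAttention` becomes `flax.linen.MultiHeadDotProductAttention`, and `optax.ctc_loss` is called with explicit keyword arguments. A minimal sketch of the pytree change, assuming JAX ≥ 0.4.25 (where the `jax.tree` namespace is available); the batch and shapes below are invented for illustration:

```python
import jax
import jax.numpy as jnp

# Toy batch, shapes invented for illustration only.
batch = {'inputs': jnp.ones((8, 4)), 'targets': jnp.zeros((8,))}

# Deprecated spelling (emits a DeprecationWarning on recent JAX):
#   jax.tree_map(lambda x: x.reshape((2, -1, *x.shape[1:])), batch)

# Replacement adopted throughout this PR:
sharded = jax.tree.map(lambda x: x.reshape((2, -1, *x.shape[1:])), batch)
print(jax.tree.map(lambda x: x.shape, sharded))  # {'inputs': (2, 4, 4), 'targets': (2, 4)}
```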
2 changes: 1 addition & 1 deletion algorithmic_efficiency/data_utils.py
@@ -65,7 +65,7 @@ def _prepare(x):
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))

- return jax.tree_map(_prepare, batch)
+ return jax.tree.map(_prepare, batch)


def pad(tensor: np.ndarray,
2 changes: 1 addition & 1 deletion algorithmic_efficiency/param_utils.py
@@ -66,7 +66,7 @@ def pytorch_param_types(

def jax_param_shapes(
params: spec.ParameterContainer) -> spec.ParameterShapeTree:
- return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params)
+ return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)


def jax_param_types(param_shapes: spec.ParameterShapeTree,
@@ -207,4 +207,4 @@ def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
- return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
+ return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
@@ -264,7 +264,7 @@ def _eval_model_on_split(self,
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value

- eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
+ eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples),
eval_metrics)
return eval_metrics

@@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module):
def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
- y = nn.SelfAttention(
+ y = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform(),
deterministic=train,
@@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
x = x + y
else:
y = x
- y = nn.SelfAttention(
+ y = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform(),
deterministic=train,
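`nn.SelfAttention` is deprecated in recent Flax releases; `nn.MultiHeadDotProductAttention` computes self-attention when it is given only a single input, so the surrounding call sites can stay unchanged. A minimal sketch of the pattern, with hypothetical sizes rather than the workload's real config:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

# Self-attention via MultiHeadDotProductAttention: with only inputs_q,
# keys and values default to the queries.
attn = nn.MultiHeadDotProductAttention(
    num_heads=4,
    kernel_init=nn.initializers.xavier_uniform(),
    deterministic=True)  # disable attention dropout for this sketch

x = jnp.ones((2, 16, 32))  # (batch, tokens, features), invented sizes
params = attn.init(jax.random.PRNGKey(0), x)
y = attn.apply(params, x)
print(y.shape)  # (2, 16, 32)
```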
@@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train):
mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)

inputs = LayerNorm(dim=config.encoder_dim)(inputs)

attention_fn = functools.partial(
dot_product_attention, temperature=config.attention_temperature)
- result = nn.SelfAttention(
+ result = nn.MultiHeadDotProductAttention(
num_heads=config.num_attention_heads,
qkv_features=config.encoder_dim,
decode=False,
@@ -410,7 +409,8 @@ def __call__(self, inputs, paddings, train):
broadcast_dropout=False,
attention_fn=attention_fn,
dropout_rate=config.attention_dropout_rate,
- deterministic=not train)(inputs, attention_mask)
+ deterministic=not train)(
+ inputs_q=inputs, mask=attention_mask)

if config.attention_residual_dropout_rate is None:
attention_residual_dropout_rate = 0.1
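The conformer call site also switches to keyword arguments: with `nn.MultiHeadDotProductAttention`, a second positional argument would be taken as `inputs_k` rather than the mask, so the mask is passed explicitly. A rough sketch of the new call shape, with invented dimensions (the workload's custom `attention_fn` and dropout settings are omitted here):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

attn = nn.MultiHeadDotProductAttention(
    num_heads=4,
    qkv_features=32,
    broadcast_dropout=False,
    dropout_rate=0.0,
    deterministic=True)

inputs = jnp.ones((2, 16, 32))   # (batch, time, encoder_dim), invented sizes
paddings = jnp.zeros((2, 16))    # 1.0 would mark padded frames
mask = nn.make_attention_mask(paddings == 0, paddings == 0)

params = attn.init(jax.random.PRNGKey(0), inputs_q=inputs, mask=mask)
out = attn.apply(params, inputs_q=inputs, mask=mask)
print(out.shape)  # (2, 16, 32)
```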
@@ -227,11 +227,12 @@ def ctc_loss(self,
labels: spec.Tensor,
label_paddings: spec.Tensor,
blank_id: int = 0) -> spec.Tensor:
- return optax.ctc_loss(logits,
- logit_paddings,
- labels,
- label_paddings,
- blank_id)
+ return optax.ctc_loss(
+ logits=logits,
+ logit_paddings=logit_paddings,
+ labels=labels,
+ label_paddings=label_paddings,
+ blank_id=blank_id)

# Adapted from lingvo's greedy decoding logic here:
# https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138.
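Calling `optax.ctc_loss` with keyword arguments guards against mixing up the two padding tensors. A small sketch with invented shapes (the padding arrays use 1.0 for padded positions):

```python
import jax.numpy as jnp
import optax

batch, frames, label_len, vocab = 2, 10, 4, 5  # invented sizes
logits = jnp.zeros((batch, frames, vocab))
logit_paddings = jnp.zeros((batch, frames))
labels = jnp.ones((batch, label_len), dtype=jnp.int32)
label_paddings = jnp.zeros((batch, label_len))

per_example_loss = optax.ctc_loss(
    logits=logits,
    logit_paddings=logit_paddings,
    labels=labels,
    label_paddings=label_paddings,
    blank_id=0)
print(per_example_loss.shape)  # (2,)
```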
@@ -132,4 +132,4 @@ def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
- return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
+ return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/ogbg/input_pipeline.py
@@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir):

def _to_jraph(example):
"""Converts an example graph to jraph.GraphsTuple."""
- example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access
+ example = jax.tree.map(lambda x: x._numpy(), example) # pylint: disable=protected-access
edge_feat = example['edge_feat']
node_feat = example['node_feat']
edge_index = example['edge_index']
@@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
if count == num_shards:

def f(x):
- return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])
+ return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])

graphs_shards = f(graphs_shards)
labels_shards = f(labels_shards)
@@ -20,8 +20,8 @@

def _pytorch_map(inputs: Any) -> Any:
if USE_PYTORCH_DDP:
- return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
- return jax.tree_map(
+ return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
+ return jax.tree.map(
lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1),
inputs)
@@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any:
def _shard(inputs: Any) -> Any:
if not USE_PYTORCH_DDP:
return inputs
- return jax.tree_map(lambda tensor: tensor[RANK], inputs)
+ return jax.tree.map(lambda tensor: tensor[RANK], inputs)


def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple:
8 changes: 4 additions & 4 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py
@@ -86,7 +86,7 @@ def gather_fn(x):
return x
return x[batch_indices, beam_indices]

- return jax.tree_map(gather_fn, nested)
+ return jax.tree.map(gather_fn, nested)


def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
@@ -139,7 +139,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
# add beam dimension to attention cache pytree elements
- beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
+ beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
@@ -225,7 +225,7 @@ def beam_search_loop_body_fn(state):
(batch_size, beam_size, 1)))
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
- flat_cache = jax.tree_map(flatten_beam_dim, state.cache)
+ flat_cache = jax.tree.map(flatten_beam_dim, state.cache)

# Call fast-decoder model on current tokens to get next-position logits.
# --> [batch * beam, vocab]
@@ -236,7 +236,7 @@ def beam_search_loop_body_fn(state):
logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
- new_cache = jax.tree_map(
+ new_cache = jax.tree.map(
lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

# Gather log probabilities from logits
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -94,7 +94,7 @@ def eval_step(self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
replicated_eval_metrics = self.eval_step_pmapped(params, batch)
- return jax.tree_map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics)
+ return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics)

@functools.partial(
jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
@@ -291,7 +291,7 @@ def _normalize_eval_metrics(
"""Normalize eval metrics."""
del num_examples
eval_denominator = total_metrics.pop('denominator')
- return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics)
+ return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)


class WmtWorkloadPostLN(WmtWorkload):
8 changes: 4 additions & 4 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py
@@ -98,7 +98,7 @@ def gather_fn(x):
return x
return x[batch_indices, beam_indices]

- return jax.tree_map(gather_fn, nested)
+ return jax.tree.map(gather_fn, nested)


def gather_topk_beams(nested: Dict[str, Any],
@@ -164,7 +164,7 @@ def beam_init(batch_size: int,
dtype=torch.bool,
device=DEVICE)
# add beam dimension to attention cache pytree elements
- beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
+ beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
@@ -251,7 +251,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1])
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
- flat_cache = jax.tree_map(flatten_beam_dim, state.cache)
+ flat_cache = jax.tree.map(flatten_beam_dim, state.cache)

# Call fast-decoder model on current tokens to get next-position logits.
# --> [batch * beam, vocab]
@@ -262,7 +262,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
- new_cache = jax.tree_map(
+ new_cache = jax.tree.map(
lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

# Gather log probabilities from logits
@@ -347,7 +347,7 @@ def _normalize_eval_metrics(
dist.all_reduce(metric)
total_metrics = {k: v.item() for k, v in total_metrics.items()}
eval_denominator = total_metrics.pop('denominator')
- return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics)
+ return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)


class WmtWorkloadPostLN(WmtWorkload):
@@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9,
raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power)

def init_fn(params):
- mu = jax.tree_map(jnp.zeros_like, params) # First moment
- nu = jax.tree_map(jnp.zeros_like, params) # Second moment
+ mu = jax.tree.map(jnp.zeros_like, params) # First moment
+ nu = jax.tree.map(jnp.zeros_like, params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def update_fn(updates, state, params=None):
@@ -123,7 +123,7 @@ def update_fn(updates, state, params=None):
mu_hat = _update_moment(updates, mu, b1, 1)
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
- updates = jax.tree_map(
+ updates = jax.tree.map(
lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

@@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple):

def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
- return jax.tree_map(
+ return jax.tree.map(
lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
beta = 1 - decay**count
- return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)
+ return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


def scale_by_learning_rate(learning_rate, flip_sign=True):
@@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
b2=hyperparameters.beta2,
eps=1e-8,
weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
+ params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
workload.param_shapes)
optimizer_state = opt_init_fn(params_zeros_like)

@@ -236,15 +236,15 @@ def _loss_fn(params):
(summed_loss, n_valid_examples, grad) = lax.psum(
(summed_loss, n_valid_examples, grad), axis_name='batch')
loss = summed_loss / n_valid_examples
- grad = jax.tree_map(lambda x: x / n_valid_examples, grad)
+ grad = jax.tree.map(lambda x: x / n_valid_examples, grad)

grad_norm = jnp.sqrt(
sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
- grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)
+ grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)

updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
current_param_container)
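The NAdamW helpers above also rely on `jax.tree.map` accepting several trees of matching structure (for example `_update_moment` zips gradients with the running moments). A toy sketch of that multi-tree form, with made-up values:

```python
import jax
import jax.numpy as jnp

grads = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
moments = {'w': jnp.zeros(2), 'b': jnp.array(0.0)}
decay = 0.9

# Mirrors the shape of _update_moment(updates, moments, decay, order=1):
new_moments = jax.tree.map(lambda g, t: (1 - decay) * g + decay * t, grads, moments)
print(new_moments)  # {'b': ~0.05, 'w': ~[0.1, 0.2]}
```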
@@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9,
raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power)

def init_fn(params):
- mu = jax.tree_map(jnp.zeros_like, params) # First moment
- nu = jax.tree_map(jnp.zeros_like, params) # Second moment
+ mu = jax.tree.map(jnp.zeros_like, params) # First moment
+ nu = jax.tree.map(jnp.zeros_like, params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def update_fn(updates, state, params=None):
@@ -123,7 +123,7 @@ def update_fn(updates, state, params=None):
mu_hat = _update_moment(updates, mu, b1, 1)
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
- updates = jax.tree_map(
+ updates = jax.tree.map(
lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

@@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple):

def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
- return jax.tree_map(
+ return jax.tree.map(
lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
beta = 1 - decay**count
- return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)
+ return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


def scale_by_learning_rate(learning_rate, flip_sign=True):
@@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
b2=hyperparameters.beta2,
eps=1e-8,
weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
+ params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
workload.param_shapes)
optimizer_state = opt_init_fn(params_zeros_like)

@@ -236,15 +236,15 @@ def _loss_fn(params):
(summed_loss, n_valid_examples, grad) = lax.psum(
(summed_loss, n_valid_examples, grad), axis_name='batch')
loss = summed_loss / n_valid_examples
- grad = jax.tree_map(lambda x: x / n_valid_examples, grad)
+ grad = jax.tree.map(lambda x: x / n_valid_examples, grad)

grad_norm = jnp.sqrt(
sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
- grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)
+ grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)

updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
current_param_container)