
Commit fea7529

fix

priyakasimbeg committed Oct 10, 2023
1 parent 7ba0c4c
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions algorithmic_efficiency/logger_utils.py
@@ -14,11 +14,11 @@
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.pytorch_utils import pytorch_setup

-try:
-  import wandb  # pylint: disable=g-import-not-at-top
-except ModuleNotFoundError:
-  logging.exception('Unable to import wandb.')
-  wandb = None
+# try:
+#   import wandb  # pylint: disable=g-import-not-at-top
+# except ModuleNotFoundError:
+#   logging.exception('Unable to import wandb.')
+wandb = None

 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()

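In effect, this hunk hard-disables Weights & Biases logging: the guarded import is commented out, so the module-level `wandb` name is now always None rather than None only when the import fails. A minimal standalone sketch of how downstream code can key off that binding (not from this commit; `log_metrics` is a hypothetical helper, though `wandb.log` is the standard W&B call):

import logging

wandb = None  # hard-disabled, as in this commit

def log_metrics(metrics: dict) -> None:
  # Use W&B when the module imported successfully; otherwise fall back to
  # the stdlib logger so metric logging never crashes the run.
  if wandb is not None:
    wandb.log(metrics)
  else:
    logging.info('metrics: %s', metrics)
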
@@ -78,8 +78,8 @@ def loss_fn(self,
               label_smoothing: float = 0.0) -> spec.Tensor:  # differentiable
     del label_smoothing
     losses = jnp.sum(
-        jnp.abs(outputs_batch - label_batch),
-        axis=tuple(range(1, outputs_batch.ndim)))
+        jnp.abs(logits_batch - label_batch),
+        axis=tuple(range(1, logits_batch.ndim)))
     # mask_batch is assumed to be shape [batch].
     if mask_batch is not None:
       losses *= mask_batch
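This hunk fixes a name mismatch: the hunk header shows the enclosing method is `def loss_fn(self, ...)`, whose parameter is evidently named `logits_batch`, so the old body's references to `outputs_batch` would raise a NameError once the loss was evaluated. A standalone sketch of the corrected masked L1 loss (the free function `l1_loss` and the example shapes are illustrative, not taken from the repo):

import jax.numpy as jnp

def l1_loss(logits_batch, label_batch, mask_batch=None):
  # Sum the absolute error over every non-batch axis, yielding one loss
  # per example with shape [batch].
  losses = jnp.sum(
      jnp.abs(logits_batch - label_batch),
      axis=tuple(range(1, logits_batch.ndim)))
  # mask_batch is assumed to be shape [batch]; zero out padded examples.
  if mask_batch is not None:
    losses *= mask_batch
  return losses

# Example: two 4x4 predictions, with the second example masked out.
print(l1_loss(jnp.ones((2, 4, 4)), jnp.zeros((2, 4, 4)), jnp.array([1., 0.])))
# -> [16.  0.]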
