Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
= committed Mar 24, 2024
1 parent e9b939a commit b5e6297
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions nanodl/__src/models/ijepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,19 @@ def evaluate(self,
context_mask = jnp.repeat(context_mask[jnp.newaxis], batch_size, axis=0)
target_mask = jnp.repeat(target_mask[jnp.newaxis], batch_size, axis=0)

context_mask = context_mask.reshape((self.num_devices, batch_size_per_device, context_mask.shape[1], context_mask.shape[2]))
target_mask = target_mask.reshape((self.num_devices, batch_size_per_device, target_mask.shape[1], target_mask.shape[2]))
context_mask = context_mask.reshape((
self.num_devices,
batch_size_per_device,
context_mask.shape[1],
context_mask.shape[2]
))

target_mask = target_mask.reshape((
self.num_devices,
batch_size_per_device,
target_mask.shape[1],
target_mask.shape[2]
))

loss = self.evaluation_step(
state=self.state,
Expand Down

0 comments on commit b5e6297

Please sign in to comment.