Skip to content

Commit

Permalink
add test for training model
Browse files Browse the repository at this point in the history
  • Loading branch information
= committed Mar 22, 2024
1 parent 2b5919d commit 59bebc9
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,16 +401,16 @@ def test_ijepa_model_initialization_and_processing(self):
params = self.model.init(
jax.random.key(0),
self.x,
context_mask,
target_mask,
context_mask[jnp.newaxis],
target_mask[jnp.newaxis],
training=False
)

outputs , _ = self.model.apply(
params,
self.x,
context_mask,
target_mask,
context_mask[jnp.newaxis],
target_mask[jnp.newaxis],
training=False
)

Expand All @@ -419,5 +419,26 @@ def test_ijepa_model_initialization_and_processing(self):
self.assertEqual(outputs[0][0].shape, outputs[0][1].shape)


def test_ijepa_training(self):
x = jax.random.normal(
jax.random.PRNGKey(0),
(9, self.image_size, self.image_size, self.num_channels)
)

dataset = ArrayDataset(x)

dataloader = DataLoader(dataset,
batch_size=3,
shuffle=True,
drop_last=False)

data_sampler = IJEPADataSampler(
image_size=self.image_size,
patch_size=self.patch_size
)

trainer = IJEPADataParallelTrainer(self.model, x.shape, 'params.pkl', data_sampler=data_sampler)
trainer.train(dataloader, 10, dataloader)

if __name__ == '__main__':
unittest.main()

0 comments on commit 59bebc9

Please sign in to comment.