Skip to content

Commit

Permalink
fixing ruffs
Browse files Browse the repository at this point in the history
  • Loading branch information
SamChou05 committed Dec 3, 2024
1 parent 76f789d commit c58cd20
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions afqinsight/nn/tests/test_tf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def run_tensorflow_model(model, data_loaders, n_epochs=20):
Smoke testing on Tensorflow models to ensure it trains and tests correctly.
Args:
model (model):
model (function):
Tensorflow model to train and test
data_loaders (tuple):
Train dataset,
test dataset,
validation datasets
n_epoch (int):
Number of epochs to train the model,
default is 100
default is 20
Returns
"""
Expand Down Expand Up @@ -117,7 +117,17 @@ def run_tensorflow_model(model, data_loaders, n_epochs=20):
)
def test_tensorflow_models(model, data_loaders):
"""
Test multiple PyTorch models
Test multiple Tensorflow models
Args:
model (function):
Tensorflow model to train and test
data_loaders (tuple):
Train dataset,
test dataset,
validation datasets
Returns
"""

train_dataset, X_test, X_train, y_test, val_dataset = data_loaders
Expand Down

0 comments on commit c58cd20

Please sign in to comment.