Skip to content

Commit

Permalink
fixed issue with validation_steps=None arising on Google Colab
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Aug 15, 2019
1 parent fd6e842 commit 6258dd7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 3 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

## 0.2.1 (2019-08-15)

### New:
- N/A

### Changed:
- N/A

### Fixed:
- Fixed error related to validation_steps=None in call to fit_generator in ```ktrain.core``` on Google Colab


## 0.2.0 (2019-08-12)

### New:
Expand Down
2 changes: 0 additions & 2 deletions examples/vision/pets-ResNet50.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" \n",
"import sys\n",
"sys.path.append('..')\n",
"import ktrain\n",
"from ktrain import vision as vis"
]
Expand Down
4 changes: 4 additions & 0 deletions ktrain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,9 @@ def fit(self, lr, n_cycles, cycle_len=None, cycle_mult=1,
# handle callbacks
num_samples = U.nsamples_from_data(self.train_data)
steps_per_epoch = num_samples // self.train_data.batch_size
validation_steps = None
if self.val_data is not None:
validation_steps = U.nsamples_from_data(self.val_data)//self.val_data.batch_size

epochs = self._check_cycles(n_cycles, cycle_len, cycle_mult)
self.set_lr(lr)
Expand All @@ -1066,6 +1069,7 @@ def fit(self, lr, n_cycles, cycle_len=None, cycle_mult=1,
warnings.filterwarnings('ignore', message='.*Check your callbacks.*')
hist = self.model.fit_generator(self.train_data,
steps_per_epoch = steps_per_epoch,
validation_steps = validation_steps,
epochs=epochs,
validation_data=self.val_data,
workers=self.workers,
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.2.0'
__version__ = '0.2.1'

0 comments on commit 6258dd7

Please sign in to comment.