Skip to content

Commit

Permalink
dev(narugo): add better generic clasisification
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Mar 20, 2024
1 parent c4d1062 commit 3d33bbf
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion imgutils/generic/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,16 @@ def _open_label(self, model_name: str) -> List[str]:

def _raw_predict(self, image: ImageTyping, model_name: str):
image = load_image(image, force_background='white', mode='RGB')
input_ = _img_encode(image)[None, ...]
model = self._open_model(model_name)
batch, channels, height, width = model.get_inputs()[0].shape
if channels != 3:
raise RuntimeError(f'Model {model_name!r} required {[batch, channels, height, width]!r}, '
f'channels not 3.') # pragma: no cover

if isinstance(height, int) and isinstance(width, int):
input_ = _img_encode(image, size=(width, height))[None, ...]

Check warning on line 91 in imgutils/generic/classify.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/classify.py#L91

Added line #L91 was not covered by tests
else:
input_ = _img_encode(image)[None, ...]
output, = self._open_model(model_name).run(['output'], {'input': input_})
return output

Expand Down

0 comments on commit 3d33bbf

Please sign in to comment.