Skip to content

Commit

Permalink
masking demo (#29)
Browse files Browse the repository at this point in the history
* made client slightly more robust to funny inputs

* added masking demo

* version bump
  • Loading branch information
dmarx authored Sep 7, 2022
1 parent 9a29651 commit 2b060e4
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 12 deletions.
86 changes: 79 additions & 7 deletions nbs/demo_colab.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='stability-sdk',
version='0.2.0',
version='0.2.1',
author='Wes Brown',
author_email='[email protected]',
maintainer='David Marx',
Expand Down
8 changes: 4 additions & 4 deletions src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,10 @@ def generate(
if safety and classifiers is None:
classifiers = generation.ClassifierParameters()

if not prompt and not init_image:
if (prompt is None) and (init_image is None):
raise ValueError("prompt and/or init_image must be provided")

if mask_image and not init_image:
if (mask_image is not None) and (init_image is None):
raise ValueError("If mask_image is provided, init_image must also be provided")

request_id = str(uuid.uuid4())
Expand All @@ -256,7 +256,7 @@ def generate(
else:
raise TypeError("prompt must be a string or a sequence")

if init_image:
if (init_image is not None):
prompt += [image_to_prompt(init_image, init=True)]
parameters = generation.StepParameter(
scaled_step=0,
Expand All @@ -268,7 +268,7 @@ def generate(
end=end_schedule,
)
),
if mask_image:
if (mask_image is not None):
prompt += [image_to_prompt(mask_image, mask=True)]
else:
parameters = generation.StepParameter(
Expand Down

0 comments on commit 2b060e4

Please sign in to comment.