Skip to content

Commit

Permalink
fix bulk runner for img2img
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Oct 3, 2023
1 parent ed3c0cf commit dc59890
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions recipes/BulkRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def render_form_v2(self):
field_props = schema["properties"][field]
title = field_props["title"]
keys = None
if field_props["type"] == "array":
if is_arr(field_props):
try:
ref = field_props["items"]["$ref"]
props = schema["definitions"][ref]["properties"]
Expand All @@ -103,7 +103,7 @@ def render_form_v2(self):
keys = {k: k for k in sr.state[field][0].keys()}
except (KeyError, IndexError, AttributeError):
pass
elif field_props["type"] == "object":
elif field_props.get("type") == "object":
try:
keys = {k: k for k in sr.state[field].keys()}
except (KeyError, AttributeError):
Expand Down Expand Up @@ -277,7 +277,7 @@ def build_requests_for_df(df, request, df_ix, arr_len):
for field, col in request.input_columns.items():
parts = field.split(".")
field_props = properties.get(parts[0]) or properties.get(parts)
if field_props["type"] == "array":
if is_arr(field_props):
arr = request_body.setdefault(parts[0], [])
for arr_ix in range(arr_len):
value = df.at[df_ix + arr_ix, col]
Expand All @@ -289,7 +289,7 @@ def build_requests_for_df(df, request, df_ix, arr_len):
if len(arr) <= arr_ix:
arr.append(None)
arr[arr_ix] = value
elif len(parts) > 1 and field_props["type"] == "object":
elif len(parts) > 1 and field_props.get("type") == "object":
obj = request_body.setdefault(parts[0], {})
obj[parts[1]] = df.at[df_ix, col]
else:
Expand All @@ -312,7 +312,7 @@ def slice_request_df(df, request):
properties = schema["properties"]

for field, col in request.input_columns.items():
if properties.get(field.split(".")[0])["type"] != "array":
if is_arr(properties.get(field.split(".")[0])):
non_array_cols.add(col)
non_array_df = df[list(non_array_cols)]

Expand All @@ -327,6 +327,16 @@ def slice_request_df(df, request):
df_ix += arr_len


def is_arr(field_props: dict) -> bool:
try:
return field_props["type"] == "array"
except KeyError:
for props in field_props.get("anyOf", []):
if props["type"] == "array":
return True
return False


def _read_df(f: str) -> "pd.DataFrame":
import pandas as pd

Expand Down

0 comments on commit dc59890

Please sign in to comment.