Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Sep 13, 2024
1 parent 7935adb commit 64a102a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 26 deletions.
1 change: 1 addition & 0 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ def post(self):
# We report maintenance exception only if we couldn't find any jobs
if self.worker.maintenance:
raise e.WorkerMaintenance(self.worker.maintenance_msg)
logger.debug(self.skipped)
return {"id": None, "ids": [], "skipped": self.skipped}, 200

def get_sorted_wp(self, priority_user_ids=None):
Expand Down
8 changes: 0 additions & 8 deletions horde/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,14 +1058,6 @@ def count_skipped_image_wp(worker, models_list=None, blacklist=None, priority_us
skipped_wps = open_wp_list.filter(
ImageWaitingPrompt.extra_slow_workers == False, # noqa E712
).count()
test = (db.session.query(ImageWaitingPrompt)
.options(noload(ImageWaitingPrompt.processing_gens))
.outerjoin(
WPModels,
WPAllowedWorkers,
)
)
logger.debug(test.all())
if skipped_wps > 0:
ret_dict["performance"] = ret_dict.get("performance",0) + skipped_wps
# Count skipped WPs requiring trusted workers
Expand Down
34 changes: 16 additions & 18 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
print("test_simple_image_gen")
headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user
async_dict = {
"prompt": "a horde of cute stable robots in a sprawling server room repairing a massive mainframe",
Expand Down Expand Up @@ -37,7 +38,6 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
async_results = async_req.json()
req_id = async_results["id"]
# print(async_results)
print(async_results)
pop_dict = {
"name": "CICD Fake Dreamer",
"models": TEST_MODELS,
Expand All @@ -55,14 +55,14 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
}
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
try:
print(pop_req.text)
# print(pop_req.text)
assert pop_req.ok, pop_req.text
except AssertionError as err:
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
print("Request cancelled")
raise err
pop_results = pop_req.json()
print(json.dumps(pop_results, indent=4))
# print(json.dumps(pop_results, indent=4))

job_id = pop_results["id"]
try:
Expand All @@ -84,7 +84,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
assert retrieve_req.ok, retrieve_req.text
retrieve_results = retrieve_req.json()
print(json.dumps(retrieve_results, indent=4))
# print(json.dumps(retrieve_results, indent=4))
assert len(retrieve_results["generations"]) == 1
gen = retrieve_results["generations"][0]
assert len(gen["gen_metadata"]) == 0
Expand All @@ -94,15 +94,17 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert gen["state"] == "ok"
assert retrieve_results["kudos"] > 1
assert retrieve_results["done"] is True
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)

TEST_MODELS_FLUX = ["Flux.1-Schnell fp8 (Compact)"]



def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
print("test_flux_image_gen")
headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user
async_dict = {
"prompt": "a horde of cute stable robots in a sprawling server room repairing a massive mainframe",
"prompt": "a horde of cute flux robots in a sprawling server room repairing a massive mainframe",
"nsfw": True,
"censor_nsfw": False,
"r2": True,
Expand All @@ -125,7 +127,7 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert async_req.ok, async_req.text
async_results = async_req.json()
req_id = async_results["id"]
print(async_results)
# print(async_results)
pop_dict = {
"name": "CICD Fake Dreamer",
"models": TEST_MODELS_FLUX,
Expand Down Expand Up @@ -154,29 +156,28 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
print("Request cancelled")
raise err
pop_results = pop_req.json()
print(json.dumps(pop_results, indent=4))
# print(json.dumps(pop_results, indent=4))
try:
assert pop_results["id"] is None, pop_results
assert pop_results["skipped"]["step_count"] == 1, pop_results
assert pop_results["skipped"].get("step_count") == 1, pop_results
except AssertionError as err:
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
print("Request cancelled")
raise err
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)

# Test extra_slow_worker
pop_dict["limit_max_steps"] = False
async_dict["params"]["steps"] = 5
pop_dict["extra_slow_worker"] = True
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
try:
print(pop_req.text)
# print(pop_req.text)
assert pop_req.ok, pop_req.text
except AssertionError as err:
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
print("Request cancelled")
raise err
pop_results = pop_req.json()
print(json.dumps(pop_results, indent=4))
# print(json.dumps(pop_results, indent=4))
try:
assert pop_results["id"] is None, pop_results
assert pop_results["skipped"]["performance"] == 1, pop_results
Expand All @@ -192,17 +193,14 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert async_req.ok, async_req.text
async_results = async_req.json()
req_id = async_results["id"]
print(async_results)
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
try:
print(pop_req.text)
assert pop_req.ok, pop_req.text
except AssertionError as err:
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
print("Request cancelled")
raise err
pop_results = pop_req.json()
print(json.dumps(pop_results, indent=4))
job_id = pop_results["id"]
try:
assert job_id is not None, pop_results
Expand All @@ -223,7 +221,7 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
assert retrieve_req.ok, retrieve_req.text
retrieve_results = retrieve_req.json()
print(json.dumps(retrieve_results, indent=4))
# print(json.dumps(retrieve_results, indent=4))
assert len(retrieve_results["generations"]) == 1
gen = retrieve_results["generations"][0]
assert len(gen["gen_metadata"]) == 0
Expand All @@ -233,9 +231,9 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert gen["state"] == "ok"
assert retrieve_results["kudos"] > 1
assert retrieve_results["done"] is True

requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)

if __name__ == "__main__":
# "ci/cd#12285"
# test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")
test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")
test_flux_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")

0 comments on commit 64a102a

Please sign in to comment.