From ef38f73f39deb1a6111c59d56036afc9d6e24e7c Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 13 Sep 2024 11:35:13 +0200 Subject: [PATCH] flush --- horde/classes/base/worker.py | 34 +++++++++++++++++----------------- horde/classes/kobold/worker.py | 5 +++-- horde/classes/stable/worker.py | 2 +- tests/test_image.py | 4 ++-- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/horde/classes/base/worker.py b/horde/classes/base/worker.py index cb5cdbd4..f8bf2ef8 100644 --- a/horde/classes/base/worker.py +++ b/horde/classes/base/worker.py @@ -197,7 +197,7 @@ def report_suspicion(self, amount=1, reason=Suspicions.WORKER_PROFANITY, formats f"Last suspicion log: {reason.name}.\n" f"Total Suspicion {self.get_suspicion()}", ) - db.session.commit() + db.session.flush() def get_suspicion_reasons(self): return set([s.suspicion_id for s in self.suspicions]) @@ -205,7 +205,7 @@ def get_suspicion_reasons(self): def reset_suspicion(self): """Clears the worker's suspicion and resets their reasons""" db.session.query(WorkerSuspicions).filter_by(worker_id=self.id).delete() - db.session.commit() + db.session.flush() def get_suspicion(self): return len(self.suspicions) @@ -226,7 +226,7 @@ def set_name(self, new_name): if len(new_name) > 100: return "Too Long" self.name = sanitize_string(new_name) - db.session.commit() + db.session.flush() return "OK" def set_info(self, new_info): @@ -237,12 +237,12 @@ def set_info(self, new_info): if len(new_info) > 1000: return "Too Long" self.info = sanitize_string(new_info) - db.session.commit() + db.session.flush() return "OK" def set_team(self, new_team): self.team_id = new_team.id - db.session.commit() + db.session.flush() return "OK" # This should be overwriten by each specific horde @@ -254,11 +254,11 @@ def toggle_maintenance(self, is_maintenance_active, maintenance_msg=None): self.maintenance_msg = self.default_maintenance_msg if self.maintenance and maintenance_msg not in [None, ""]: self.maintenance_msg = sanitize_string(maintenance_msg) - db.session.commit() + db.session.flush() def toggle_paused(self, is_paused_active): self.paused = is_paused_active - db.session.commit() + db.session.flush() # This should be extended by each worker type def check_in(self, **kwargs): @@ -343,7 +343,7 @@ def record_contribution(self, raw_things, kudos, things_per_sec): ).delete(synchronize_session=False) new_performance = WorkerPerformance(worker_id=self.id, performance=things_per_sec) db.session.add(new_performance) - db.session.commit() + db.session.flush() if things_per_sec / hv.thing_divisors[self.wtype] > hv.suspicion_thresholds[self.wtype]: self.report_suspicion( reason=Suspicions.UNREASONABLY_FAST, @@ -356,10 +356,10 @@ def modify_kudos(self, kudos, action="generated"): if not kudos_details: kudos_details = WorkerStats(worker_id=self.id, action=action, value=round(kudos, 2)) db.session.add(kudos_details) - db.session.commit() + db.session.flush() else: kudos_details.value = round(kudos_details.value + kudos, 2) - db.session.commit() + db.session.flush() logger.trace([kudos_details, kudos_details.value]) def log_aborted_job(self): @@ -391,7 +391,7 @@ def log_aborted_job(self): self.report_suspicion(reason=Suspicions.TOO_MANY_JOBS_ABORTED) self.aborted_jobs = 0 self.uncompleted_jobs += 1 - db.session.commit() + db.session.flush() # def is_slow(self): @@ -430,19 +430,19 @@ def import_kudos_details(self, kudos_details): for key in kudos_details: new_kd = WorkerStats(worker_id=self.id, action=key, value=kudos_details[key]) db.session.add(new_kd) - db.session.commit() + db.session.flush() def import_performances(self, performances): for p in performances: new_kd = WorkerPerformance(worker_id=self.id, performance=p) db.session.add(new_kd) - db.session.commit() + db.session.flush() def import_suspicions(self, suspicions): for s in suspicions: new_suspicion = WorkerSuspicions(worker_id=self.id, suspicion_id=int(s)) db.session.add(new_suspicion) - db.session.commit() + db.session.flush() # Should be extended by each specific horde @logger.catch(reraise=True) @@ -528,7 +528,7 @@ def set_blacklist(self, blacklist): for word in blacklist: blacklisted_word = WorkerBlackList(worker_id=self.id, word=word[0:15]) db.session.add(blacklisted_word) - db.session.commit() + db.session.flush() def refresh_model_cache(self): models_list = [m.model for m in self.models] @@ -564,11 +564,11 @@ def set_models(self, models): return # logger.debug([existing_model_names,models, existing_model_names == models]) db.session.query(WorkerModel).filter_by(worker_id=self.id).delete() - db.session.commit() + db.session.flush() for model_name in models: model = WorkerModel(worker_id=self.id, model=model_name) db.session.add(model) - db.session.commit() + db.session.flush() self.refresh_model_cache() def parse_models(self, models): diff --git a/horde/classes/kobold/worker.py b/horde/classes/kobold/worker.py index 857692e6..a52c8f71 100644 --- a/horde/classes/kobold/worker.py +++ b/horde/classes/kobold/worker.py @@ -57,6 +57,7 @@ def check_in(self, max_length, max_context_length, softprompts, **kwargs): f"{paused_string}Text Worker {self.name} checked-in, offering models {self.models} " f"at {self.max_length} max tokens and {self.max_context_length} max content length.", ) + db.session.flush() def refresh_softprompt_cache(self): softprompts_list = [s.softprompt for s in self.softprompts] @@ -100,11 +101,11 @@ def set_softprompts(self, softprompts): ], ) db.session.query(TextWorkerSoftprompts).filter_by(worker_id=self.id).delete() - db.session.commit() + db.session.flush() for softprompt_name in softprompts: softprompt = TextWorkerSoftprompts(worker_id=self.id, softprompt=softprompt_name) db.session.add(softprompt) - db.session.commit() + db.session.flush() self.refresh_softprompt_cache() def calculate_uptime_reward(self): diff --git a/horde/classes/stable/worker.py b/horde/classes/stable/worker.py index 132ab6ff..4dfe9609 100644 --- a/horde/classes/stable/worker.py +++ b/horde/classes/stable/worker.py @@ -52,7 +52,7 @@ def check_in(self, max_pixels, **kwargs): paused_string = "" if self.paused: paused_string = "(Paused) " - db.session.commit() + db.session.flush() logger.trace( f"{paused_string}Stable Worker {self.name} checked-in, offering models {self.get_model_names()} " f"at {self.max_pixels} max pixels", diff --git a/tests/test_image.py b/tests/test_image.py index 1b569bfb..a9476ad5 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -161,8 +161,8 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: assert pop_results["id"] is None, pop_results assert pop_results["skipped"].get("step_count") == 1, pop_results except AssertionError as err: - # requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) - # print("Request cancelled") + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") raise err # Test extra_slow_worker