Skip to content

Commit

Permalink
flush
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Sep 13, 2024
1 parent ce6ada5 commit ef38f73
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 22 deletions.
34 changes: 17 additions & 17 deletions horde/classes/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@ def report_suspicion(self, amount=1, reason=Suspicions.WORKER_PROFANITY, formats
f"Last suspicion log: {reason.name}.\n"
f"Total Suspicion {self.get_suspicion()}",
)
db.session.commit()
db.session.flush()

def get_suspicion_reasons(self):
return set([s.suspicion_id for s in self.suspicions])

def reset_suspicion(self):
"""Clears the worker's suspicion and resets their reasons"""
db.session.query(WorkerSuspicions).filter_by(worker_id=self.id).delete()
db.session.commit()
db.session.flush()

def get_suspicion(self):
return len(self.suspicions)
Expand All @@ -226,7 +226,7 @@ def set_name(self, new_name):
if len(new_name) > 100:
return "Too Long"
self.name = sanitize_string(new_name)
db.session.commit()
db.session.flush()
return "OK"

def set_info(self, new_info):
Expand All @@ -237,12 +237,12 @@ def set_info(self, new_info):
if len(new_info) > 1000:
return "Too Long"
self.info = sanitize_string(new_info)
db.session.commit()
db.session.flush()
return "OK"

def set_team(self, new_team):
self.team_id = new_team.id
db.session.commit()
db.session.flush()
return "OK"

# This should be overwriten by each specific horde
Expand All @@ -254,11 +254,11 @@ def toggle_maintenance(self, is_maintenance_active, maintenance_msg=None):
self.maintenance_msg = self.default_maintenance_msg
if self.maintenance and maintenance_msg not in [None, ""]:
self.maintenance_msg = sanitize_string(maintenance_msg)
db.session.commit()
db.session.flush()

def toggle_paused(self, is_paused_active):
self.paused = is_paused_active
db.session.commit()
db.session.flush()

# This should be extended by each worker type
def check_in(self, **kwargs):
Expand Down Expand Up @@ -343,7 +343,7 @@ def record_contribution(self, raw_things, kudos, things_per_sec):
).delete(synchronize_session=False)
new_performance = WorkerPerformance(worker_id=self.id, performance=things_per_sec)
db.session.add(new_performance)
db.session.commit()
db.session.flush()
if things_per_sec / hv.thing_divisors[self.wtype] > hv.suspicion_thresholds[self.wtype]:
self.report_suspicion(
reason=Suspicions.UNREASONABLY_FAST,
Expand All @@ -356,10 +356,10 @@ def modify_kudos(self, kudos, action="generated"):
if not kudos_details:
kudos_details = WorkerStats(worker_id=self.id, action=action, value=round(kudos, 2))
db.session.add(kudos_details)
db.session.commit()
db.session.flush()
else:
kudos_details.value = round(kudos_details.value + kudos, 2)
db.session.commit()
db.session.flush()
logger.trace([kudos_details, kudos_details.value])

def log_aborted_job(self):
Expand Down Expand Up @@ -391,7 +391,7 @@ def log_aborted_job(self):
self.report_suspicion(reason=Suspicions.TOO_MANY_JOBS_ABORTED)
self.aborted_jobs = 0
self.uncompleted_jobs += 1
db.session.commit()
db.session.flush()

# def is_slow(self):

Expand Down Expand Up @@ -430,19 +430,19 @@ def import_kudos_details(self, kudos_details):
for key in kudos_details:
new_kd = WorkerStats(worker_id=self.id, action=key, value=kudos_details[key])
db.session.add(new_kd)
db.session.commit()
db.session.flush()

def import_performances(self, performances):
for p in performances:
new_kd = WorkerPerformance(worker_id=self.id, performance=p)
db.session.add(new_kd)
db.session.commit()
db.session.flush()

def import_suspicions(self, suspicions):
for s in suspicions:
new_suspicion = WorkerSuspicions(worker_id=self.id, suspicion_id=int(s))
db.session.add(new_suspicion)
db.session.commit()
db.session.flush()

# Should be extended by each specific horde
@logger.catch(reraise=True)
Expand Down Expand Up @@ -528,7 +528,7 @@ def set_blacklist(self, blacklist):
for word in blacklist:
blacklisted_word = WorkerBlackList(worker_id=self.id, word=word[0:15])
db.session.add(blacklisted_word)
db.session.commit()
db.session.flush()

def refresh_model_cache(self):
models_list = [m.model for m in self.models]
Expand Down Expand Up @@ -564,11 +564,11 @@ def set_models(self, models):
return
# logger.debug([existing_model_names,models, existing_model_names == models])
db.session.query(WorkerModel).filter_by(worker_id=self.id).delete()
db.session.commit()
db.session.flush()
for model_name in models:
model = WorkerModel(worker_id=self.id, model=model_name)
db.session.add(model)
db.session.commit()
db.session.flush()
self.refresh_model_cache()

def parse_models(self, models):
Expand Down
5 changes: 3 additions & 2 deletions horde/classes/kobold/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def check_in(self, max_length, max_context_length, softprompts, **kwargs):
f"{paused_string}Text Worker {self.name} checked-in, offering models {self.models} "
f"at {self.max_length} max tokens and {self.max_context_length} max content length.",
)
db.session.flush()

def refresh_softprompt_cache(self):
softprompts_list = [s.softprompt for s in self.softprompts]
Expand Down Expand Up @@ -100,11 +101,11 @@ def set_softprompts(self, softprompts):
],
)
db.session.query(TextWorkerSoftprompts).filter_by(worker_id=self.id).delete()
db.session.commit()
db.session.flush()
for softprompt_name in softprompts:
softprompt = TextWorkerSoftprompts(worker_id=self.id, softprompt=softprompt_name)
db.session.add(softprompt)
db.session.commit()
db.session.flush()
self.refresh_softprompt_cache()

def calculate_uptime_reward(self):
Expand Down
2 changes: 1 addition & 1 deletion horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def check_in(self, max_pixels, **kwargs):
paused_string = ""
if self.paused:
paused_string = "(Paused) "
db.session.commit()
db.session.flush()
logger.trace(
f"{paused_string}Stable Worker {self.name} checked-in, offering models {self.get_model_names()} "
f"at {self.max_pixels} max pixels",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert pop_results["id"] is None, pop_results
assert pop_results["skipped"].get("step_count") == 1, pop_results
except AssertionError as err:
# requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
# print("Request cancelled")
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
print("Request cancelled")
raise err

# Test extra_slow_worker
Expand Down

0 comments on commit ef38f73

Please sign in to comment.