Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save price to saved run #177

Merged
merged 12 commits into from
Oct 12, 2023
2 changes: 1 addition & 1 deletion Procfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ admin: poetry run python manage.py runserver 127.0.0.1:8000

dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless true

celery: poetry run celery -A celeryapp worker
celery: poetry run celery -A celeryapp worker -P threads -c 16

ui: cd ../gooey-ui/; PORT=3000 npm run dev
1 change: 1 addition & 0 deletions bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ class SavedRunAdmin(admin.ModelAdmin):
"created_at",
"run_time",
"updated_at",
"price",
]
list_filter = ["workflow"]
search_fields = ["workflow", "example_id", "run_id", "uid"]
Expand Down
56 changes: 56 additions & 0 deletions bots/migrations/0044_savedrun_price_alter_savedrun_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Generated by Django 4.2.5 on 2023-10-10 02:58

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("bots", "0043_alter_savedrun_workflow"),
]

operations = [
migrations.AddField(
model_name="savedrun",
name="price",
field=models.IntegerField(default=0),
),
migrations.AlterField(
model_name="savedrun",
name="workflow",
field=models.IntegerField(
choices=[
(1, "Doc Search"),
(2, "Doc Summary"),
(3, "Google GPT"),
(4, "Copilot"),
(5, "Lipysnc + TTS"),
(6, "Text to Speech"),
(7, "Speech Recognition"),
(8, "Lipsync"),
(9, "Deforum Animation"),
(10, "Compare Text2Img"),
(11, "Text2Audio"),
(12, "Img2Img"),
(13, "Face Inpainting"),
(14, "Google Image Gen"),
(15, "Compare AI Upscalers"),
(16, "SEO Summary"),
(17, "Email Face Inpainting"),
(18, "Social Lookup Email"),
(19, "Object Inpainting"),
(20, "Image Segmentation"),
(21, "Compare LLM"),
(22, "Chyron Plant"),
(23, "Letter Writer"),
(24, "Smart GPT"),
(25, "AI QR Code"),
(26, "Doc Extract"),
(27, "Related QnA Maker"),
(28, "Related QnA Maker Doc"),
(29, "Embeddings"),
(30, "Bulk Runner"),
],
default=4,
),
),
]
3 changes: 3 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class SavedRun(models.Model):
related_name="children",
)

price = models.IntegerField(default=0)
workflow = models.IntegerField(
choices=Workflow.choices, default=Workflow.VIDEO_BOTS
)
Expand Down Expand Up @@ -193,6 +194,8 @@ def to_dict(self) -> dict:
ret[StateKeys.hidden] = self.hidden
if self.is_flagged:
ret["is_flagged"] = self.is_flagged
if self.price:
ret["price"] = self.price
devxpy marked this conversation as resolved.
Show resolved Hide resolved
return ret

def set(self, state: dict):
Expand Down
2 changes: 1 addition & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def save(done=False):
# run completed
except StopIteration:
run_time += time() - start_time
page.deduct_credits(st.session_state)
sr.price = page.deduct_credits(st.session_state)
break
# render errors nicely
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,12 +1047,13 @@ def check_credits(self) -> bool:
assert self.request.user, "request.user must be set to check credits"
return self.request.user.balance >= self.get_price_roundoff(st.session_state)

def deduct_credits(self, state: dict):
def deduct_credits(self, state: dict) -> int:
assert self.request, "request must be set to deduct credits"
assert self.request.user, "request.user must be set to deduct credits"

amount = self.get_price_roundoff(state)
self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}")
return amount

def get_price_roundoff(self, state: dict) -> int:
# don't allow fractional pricing for now, min 1 credit
Expand Down