Skip to content

Commit

Permalink
Fix template vars
Browse files Browse the repository at this point in the history
  • Loading branch information
steventkrawczyk committed Aug 16, 2023
1 parent 832ff10 commit a566467
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
12 changes: 8 additions & 4 deletions prompttools/playground/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@


def render_prompts(templates, vars):
print(templates)
print(vars)
prompts = []
for i, template in enumerate(templates):
environment = jinja2.Environment()
jinja_template = environment.from_string(template)
prompts.append(jinja_template.render(**vars[i]))
for template in templates:
for var_set in vars:
environment = jinja2.Environment()
jinja_template = environment.from_string(template)
prompts.append(jinja_template.render(**var_set))
print(prompts)
return prompts


Expand Down
12 changes: 7 additions & 5 deletions prompttools/playground/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@

if mode == "Instruction":
placeholders = [[st.empty() for _ in range(instruction_count + 1)] for _ in range(prompt_count)]
# placeholders = []

cols = st.columns(instruction_count + 1)

Expand Down Expand Up @@ -150,9 +151,10 @@
key=f"prompt_{i}",
)
)
placeholders.append([])
for j in range(1, instruction_count + 1):
with cols[j]:
placeholders[i][j] = st.empty() # placeholders for the future output
placeholders[i].append(st.empty()) # placeholders for the future output
st.divider()

run_button, clear_button, share_button = st.columns([1, 1, 1], gap="small")
Expand Down Expand Up @@ -266,15 +268,15 @@
df = load_data(model_type, model, [instruction], prompts, temperature, api_key=api_key)
st.session_state.prompts = prompts
st.session_state.df = df
for i in range(len(prompts)):
for i in range(len(vars)):
for j in range(len(templates)):
placeholders[i][j + variable_count].markdown(df["response"][i + len(prompts) * j])
placeholders[i][j + variable_count].markdown(df["response"][i + len(vars) * j])
elif "df" in st.session_state and "prompts" in st.session_state and not clear:
df = st.session_state.df
prompts = st.session_state.prompts
for i in range(len(prompts)):
for i in range(len(vars)):
for j in range(len(templates)):
placeholders[i][j + variable_count].markdown(df["response"][i + len(prompts) * j])
placeholders[i][j + variable_count].markdown(df["response"][i + len(vars) * j])
elif mode == "Model Comparison":
placeholders = [[st.empty() for _ in range(model_count + 1)] for _ in range(prompt_count)]

Expand Down

0 comments on commit a566467

Please sign in to comment.