Skip to content

Commit

Permalink
[Bug] Fix data mutation issue and improve prompts (#603)
Browse files Browse the repository at this point in the history
Co-authored-by: Maximilian Schulz <[email protected]>
  • Loading branch information
lingyielia and maxschulz-COL authored Jul 30, 2024
1 parent cd76b81 commit 77fc1ac
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->

### Changed

- Stabilized `plot` results by fixing several issues where the input dataframe was unintentionally mutated. ([#603](https://github.com/mckinsey/vizro/pull/603))

<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->

<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
4 changes: 2 additions & 2 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def _run_plot_tasks(
) -> PlotOutputs:
"""Task execution."""
chart_type_pipeline = self.pipeline_manager.chart_type_pipeline
chart_types = chart_type_pipeline.run(initial_args={"chain_input": user_input, "df": df})
chart_type = chart_type_pipeline.run(initial_args={"chain_input": user_input, "df": df})

# TODO: update to loop through the chart types so that multiple charts can be created
plot_pipeline = self.pipeline_manager.plot_pipeline
custom_chart_code = plot_pipeline.run(
initial_args={"chain_input": user_input, "df": df, "chart_types": chart_types}
initial_args={"chain_input": user_input, "df": df, "chart_type": chart_type}
)

# TODO add debug in pipeline after getting _debug_helper logic in component
Expand Down
17 changes: 5 additions & 12 deletions vizro-ai/src/vizro_ai/plot/components/chart_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.plot.components import VizroAIComponentBase
from vizro_ai.plot.schema_manager import SchemaManager
from vizro_ai.utils.helper import _get_df_info

# initialization of schema manager, and register schema needed
# preprocess: llm kwargs for function description schema + partial vars
Expand All @@ -35,12 +36,12 @@ class ChartSelection(BaseModel):


# 2. Define prompt
chart_type_prompt = "choose a best chart types for this df info:{df_schema}, {df_head} and user question {input}?"
chart_type_prompt = "choose a best chart type for this df info:{df_schema}, {df_sample} and user question {input}?"


# 3. Define Component
class GetChartSelection(VizroAIComponentBase):
"""Get Chart Types.
"""Get chart type.
Attributes
prompt (str): Prompt chart selection chains.
Expand All @@ -63,9 +64,9 @@ def _pre_process(self, df: pd.DataFrame, *args, **kwargs) -> Tuple[Dict, Dict]:
It should return llm_kwargs and partial_vars_map.
"""
df_schema, df_head = self._get_df_info(df)
df_schema, df_sample = _get_df_info(df)
llm_kwargs_to_use = openai_schema_manager.get_llm_kwargs("ChartSelection")
partial_vars = {"df_schema": df_schema, "df_head": df_head}
partial_vars = {"df_schema": df_schema, "df_sample": df_sample}
return llm_kwargs_to_use, partial_vars

def _post_process(self, response: Dict, *args, **kwargs) -> str:
Expand All @@ -87,14 +88,6 @@ def run(self, chain_input: str, df: pd.DataFrame = None) -> str:
"""
return super().run(chain_input=chain_input, df=df)

@staticmethod
def _get_df_info(df: pd.DataFrame) -> Tuple[str, str]:
"""Get the dataframe schema and head info as string."""
formatted_pairs = [f"{col_name}: {dtype}" for col_name, dtype in df.dtypes.items()]
schema_string = "\n".join(formatted_pairs)

return schema_string, df.head().to_markdown()

@staticmethod
def _chart_to_use(load_args) -> str:
"""Get Chart name as string or list of chart names as string."""
Expand Down
15 changes: 11 additions & 4 deletions vizro-ai/src/vizro_ai/plot/components/code_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ class CodeDebug(BaseModel):


# 2. Define prompt
debugging_prompt = (
"Return the full code snippet after fixing the bug in the code snippet {code_snippet}, this is the error message "
"{input},"
)
debugging_prompt = """
You are an expert Python and Pandas code reviewer and corrector.
Your task is to review Pandas code strings provided, identify any issues or improvements,
and return a corrected version of the code.
Return the full code snippet after fixing the bug in the code snippet:
{code_snippet}
IMPORTANT: Avoid adding fake data for the variable df. It will be provided by the user when executed.
This is the error message:
{input},
"""


# 3. Define Component
Expand Down
22 changes: 14 additions & 8 deletions vizro-ai/src/vizro_ai/plot/components/custom_chart_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,26 @@

@openai_schema_manager.register
class CustomChart(BaseModel):
"""Plotly code per user request that is suitable for chart types for given data."""
"""Plotly code per user request that is suitable for chart type for given data."""

custom_chart_code: str = Field(..., description="Modified and decorated code snippet to allow use in dashboards")


# 2. Define prompt
custom_chart_prompt = """
Please modify the following code {input} such that:
1. You wrap the entire chart code into function called 'custom_chart' that takes a single optional arg called
data_frame and returns only the fig object, ie `def custom_chart(data_frame): as first line
2. You ensure that the above function only returns the plotly fig object,
and that the variables are renamed such that all data is derived from 'data_frame'
3. Leave all imports as is above that function, and do NOT add anything else
"""
Your task is to correctly wrap the provided code as instructed. IMPORTANT: Do not mock the data.
Instruction:
1. You wrap the entire chart code into function called 'custom_chart' that takes a single optional arg called
data_frame and returns only the fig object, ie `def custom_chart(data_frame): as first line.
2. You ensure that the above function only returns the plotly fig object,
and that the variables are renamed such that all data is derived from 'data_frame'.
3. Leave all imports as is above that function, and do NOT add anything else.
Please modify the following code:
{input}
"""


class GetCustomChart(VizroAIComponentBase):
Expand Down
36 changes: 18 additions & 18 deletions vizro-ai/src/vizro_ai/plot/components/dataframe_craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.plot.components import VizroAIComponentBase
from vizro_ai.plot.schema_manager import SchemaManager
from vizro_ai.utils.helper import _get_df_info

logger = logging.getLogger(__name__)

Expand All @@ -31,14 +32,21 @@ class DataFrameCraft(BaseModel):


# 2. Define prompt
dataframe_prompt = """Context: You are working with a pandas DataFrame in Python named df.
DataFrame Details Schema: {df_schema}, Sample Data: {df_head}, User Query: {input}
Instructions: 1.Write code to manipulate the df DataFrame according to the user's query.
2.Do not create any new DataFrames; work only with df.
3.Ensure that any aggregated columns are named appropriately and re-indexed if necessary.
4.If a visualization is implied by the user's query, only write the necessary DataFrame manipulation
code for that visualization. 5.Do not include any plotting code.
6. Produce the code in a line-by-line format, not wrapped inside a function."""
dataframe_prompt = """
You are a software engineer working with a pandas DataFrame in Python named df.
Your task is to write code to manipulate the df DataFrame according to the user's query.
So user can get the desired output for create subsequent visualization.
DataFrame Details Schema: {df_schema}, Sample Data: {df_sample}, User Query: {input}
Instructions:
1. Write code to manipulate the df DataFrame according to the user's query.
2. Do not create any new DataFrames; work only with df.
3. Always make a hard copy of the DataFrame before manipulating it. Important: Do not modify the original DataFrame.
4. Ensure that any aggregated columns are named appropriately and re-indexed if necessary.
5. If a visualization is implied by the user's query, only write the necessary DataFrame manipulation
code for that visualization.
6. Do not include any plotting code.
7. Produce the code in a line-by-line format, not wrapped inside a function."""


# 3. Define Component
Expand Down Expand Up @@ -66,9 +74,9 @@ def _pre_process(self, df: pd.DataFrame, *args, **kwargs) -> Tuple[Dict, Dict]:
It should return llm_kwargs and partial_vars_map
"""
df_schema, df_head = self._get_df_info(df)
df_schema, df_sample = _get_df_info(df)
llm_kwargs_to_use = openai_schema_manager.get_llm_kwargs("DataFrameCraft")
partial_vars_map = {"df_schema": df_schema, "df_head": df_head}
partial_vars_map = {"df_schema": df_schema, "df_sample": df_sample}

return llm_kwargs_to_use, partial_vars_map

Expand All @@ -90,14 +98,6 @@ def run(self, chain_input: str, df: pd.DataFrame = None) -> str:
"""
return super().run(chain_input, df)

@staticmethod
def _get_df_info(df: pd.DataFrame) -> Tuple[str, str]:
"""Get the dataframe schema and head info as string."""
formatted_pairs = [f"{col_name}: {dtype}" for col_name, dtype in df.dtypes.items()]
schema_string = "\n".join(formatted_pairs)

return schema_string, df.head().to_markdown()

@staticmethod
def _format_dataframe_string(s: str) -> str:
"""Format the dataframe code snippet string."""
Expand Down
8 changes: 4 additions & 4 deletions vizro-ai/src/vizro_ai/plot/components/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class CodeExplanation(BaseModel):


# 2. Define prompt
code_explanation_prompt = (
"Given user question {input} and answer {code_snippet}, (less than 400 characters),"
"DO NOT just use one sentence for business insights, give detailed information"
)
code_explanation_prompt = """
Given user question {input} and answer {code_snippet} (less than 400 characters),
DO NOT just use one sentence for business insights, give detailed information.
"""


# 3. Define Component
Expand Down
57 changes: 44 additions & 13 deletions vizro-ai/src/vizro_ai/plot/components/visual_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Dict, Tuple

import pandas as pd

try:
from pydantic.v1 import BaseModel, Field
except ImportError: # pragma: no cov
Expand All @@ -12,25 +14,45 @@
from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.plot.components import VizroAIComponentBase
from vizro_ai.plot.schema_manager import SchemaManager
from vizro_ai.utils.helper import _get_df_info

# 1. Define schema
openai_schema_manager = SchemaManager()


@openai_schema_manager.register
class VizroCode(BaseModel):
"""Plotly code per user request that is suitable for chart types for given data."""
"""Plotly code per user request that is suitable for chart type for given data."""

visual_code: str = Field(..., description="code snippet for plot visuals using plotly")


# 2. Define prompt
visual_code_prompt = (
"Context: You are working with a pandas dataframe in Python. The name of the dataframe is `df`."
"Instructions: Given the code snippet {df_code}, generate Plotly visualization code to produce a {chart_types} "
"chart that addresses user query: {input}. "
"Please ensure the Plotly code aligns with the provided DataFrame details."
)
visual_code_prompt = """
Context: You are an AI assistant specialized in data visualization using Python, pandas, and Plotly.
Given:
- A pandas DataFrame named `df`
- DataFrame schema: {df_schema}
- Sample data (first few rows): {df_sample}
- Data preprocessing code: {df_code}
- User's visualization request: {input}
- Requested chart type: {chart_type}
Instructions:
1. Analyze the provided DataFrame information and preprocessing code.
2. Generate Plotly code to create a {chart_type} chart that addresses the user's query: {input}
3. Ensure the visualization accurately represents the data and aligns with the DataFrame structure.
4. Use appropriate Plotly Express functions when possible for simplicity.
5. If custom Plotly Graph Objects are necessary, provide clear explanations.
6. Include axis labels, title, and any other relevant chart components.
7. If color coding or additional visual elements would enhance the chart, incorporate them.
Output:
- Provide the complete Plotly code required to generate the requested visualization.
Note: Ensure all variable names and data references are consistent with the provided DataFrame (`df`).
"""


# 3. Define Component
Expand All @@ -53,13 +75,21 @@ def __init__(self, llm: BaseChatModel):
"""
super().__init__(llm)

def _pre_process(self, chart_types: str, df_code: str, *args, **kwargs) -> Tuple[Dict, Dict]:
def _pre_process(self, chart_type: str, df_code: str, df: pd.DataFrame, *args, **kwargs) -> Tuple[Dict, Dict]:
"""Preprocess for visual code.
It should return llm_kwargs and partial_vars_map.
"""
llm_kwargs_to_use = openai_schema_manager.get_llm_kwargs("VizroCode")
partial_vars_map = {"chart_types": chart_types, "df_code": df_code}

df_schema, df_sample = _get_df_info(df)

partial_vars_map = {
"chart_type": chart_type,
"df_code": df_code,
"df_schema": df_schema,
"df_sample": df_sample,
}

return llm_kwargs_to_use, partial_vars_map

Expand All @@ -72,19 +102,20 @@ def _post_process(self, response: Dict, df_code: str, *args, **kwargs) -> str:
return self._clean_visual_code(code_snippet)

@_log_time
def run(self, chain_input: str, df_code: str, chart_types: str) -> str:
def run(self, chain_input: str, df_code: str, chart_type: str, df: pd.DataFrame = None) -> str:
"""Run chain to get visual code.
Args:
chain_input: User input or intermediate question if needed.
df_code: Code snippet of dataframe.
chart_types: Chart types.
chart_type: chart type.
df: The dataframe for plotting.
Returns:
Visual code snippet.
"""
return super().run(chain_input=chain_input, df_code=df_code, chart_types=chart_types)
return super().run(chain_input=chain_input, df_code=df_code, chart_type=chart_type, df=df)

@staticmethod
def _add_df_string(code_string: str, df_code: str) -> str:
Expand Down Expand Up @@ -117,7 +148,7 @@ def _clean_visual_code(raw_code: str) -> str:
res = test_visual_code.run(
chain_input="choose a best chart for describe the composition of gdp in continent, "
"and horizontal line for avg gdp",
chart_types="bar",
chart_type="bar",
df_code=df_code,
)
print(res) # noqa: T201
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/plot/task_pipeline/_pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ def __init__(self, llm: BaseChatModel = None):

@property
def chart_type_pipeline(self):
"""Target chart types pipeline."""
"""Target chart type pipeline."""
pipeline = Pipeline(self.llm)
pipeline.add(GetChartSelection, input_keys=["df", "chain_input"], output_key="chart_types")
pipeline.add(GetChartSelection, input_keys=["df", "chain_input"], output_key="chart_type")
return pipeline

@property
def plot_pipeline(self):
"""Plot pipeline."""
pipeline = Pipeline(self.llm)
pipeline.add(GetDataFrameCraft, input_keys=["df", "chain_input"], output_key="df_code")
pipeline.add(GetVisualCode, input_keys=["chain_input", "chart_types", "df_code"], output_key="chain_input")
pipeline.add(GetVisualCode, input_keys=["chain_input", "chart_type", "df_code", "df"], output_key="chain_input")
pipeline.add(GetCustomChart, input_keys=["chain_input"], output_key="custom_chart_code")
return pipeline
Loading

0 comments on commit 77fc1ac

Please sign in to comment.