Skip to content

Commit

Permalink
Merge branch 'fix/markers'
Browse files Browse the repository at this point in the history
  • Loading branch information
wjm41 committed Feb 27, 2022
2 parents a0c5838 + fc8d154 commit b5df13d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 42 deletions.
48 changes: 20 additions & 28 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@
"import molplotly\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -4459,7 +4449,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7fc93564f610>"
"<IPython.lib.display.IFrame at 0x7feaa9613490>"
]
},
"metadata": {},
Expand Down Expand Up @@ -4488,7 +4478,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -4506,7 +4496,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7fc9346330d0>"
"<IPython.lib.display.IFrame at 0x7feaa96d29e0>"
]
},
"metadata": {},
Expand Down Expand Up @@ -4542,7 +4532,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -4560,7 +4550,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x20f8e103520>"
"<IPython.lib.display.IFrame at 0x7feaa96d2e30>"
]
},
"metadata": {},
Expand Down Expand Up @@ -4598,7 +4588,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -4616,7 +4606,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x20f907deb80>"
"<IPython.lib.display.IFrame at 0x7feaaabd1420>"
]
},
"metadata": {},
Expand All @@ -4634,6 +4624,7 @@
" x=\"y_true\",\n",
" y=\"y_pred\",\n",
" size='Molecular Weight',\n",
" symbol='Minimum Degree',\n",
" color='dataset',\n",
" title='ESOL Regression (colored by random train/test split)',\n",
" labels={'y_pred': 'Predicted Solubility',\n",
Expand All @@ -4645,7 +4636,8 @@
" df=df_esol,\n",
" smiles_col='smiles',\n",
" title_col='Compound ID',\n",
" color_col='dataset')\n",
" color_col='dataset',\n",
" marker_col='Minimum Degree')\n",
"\n",
"app_train_test.run_server(mode='inline', port=8703, height=1000)\n"
]
Expand All @@ -4672,7 +4664,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -4690,7 +4682,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x20f90876b80>"
"<IPython.lib.display.IFrame at 0x7feaaad74970>"
]
},
"metadata": {},
Expand Down Expand Up @@ -4733,7 +4725,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -4751,7 +4743,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x20f909b2700>"
"<IPython.lib.display.IFrame at 0x7feaaad74760>"
]
},
"metadata": {},
Expand Down Expand Up @@ -4800,7 +4792,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -4835,7 +4827,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -4853,7 +4845,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x20f8b475b80>"
"<IPython.lib.display.IFrame at 0x7feaaadcfd90>"
]
},
"metadata": {},
Expand Down Expand Up @@ -4893,7 +4885,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -4932,7 +4924,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand All @@ -4950,7 +4942,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x20f91152070>"
"<IPython.lib.display.IFrame at 0x7feaab5afb20>"
]
},
"metadata": {},
Expand Down
107 changes: 93 additions & 14 deletions molplotly/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,71 @@ def str2bool(v):
return v.lower() in ("yes", "true", "t", "1")


def test_groups(fig, df_grouped):
"""Test if plotly figure curve names match up with pandas dataframe groups
Args:
fig (plotly figure): _description_
groups (pandas groupby object): _description_
Returns:
_type_: Bool describing whether or not groups is the correct dataframe grouping descbrining the data in fig
"""
str_groups = {}
for name, group in df_grouped:
# if isinstance(name, bool) or isinstance(name, int):
# str_groups[str(name)] = group
if isinstance(name, tuple):
str_groups[", ".join(str(x) for x in name)] = group
else:
str_groups[name] = group

for data in fig.data:
if data.name in str_groups:
if len(data.y) == len(str_groups[data.name]):
continue
else:
return False
return True


def find_grouping(fig, df_data, cols):

if len(cols) == 1:
df_grouped = df_data.groupby(cols)
if not test_groups(fig, df_grouped):
raise ValueError(
"marker_col is mispecified because the dataframe grouping names don't match the names in the plotly figure."
)

elif len(cols) == 2: # color_col and marker_col

df_grouped_x = df_data.groupby(cols)
df_grouped_y = df_data.groupby([cols[1], cols[0]])

if test_groups(fig, df_grouped_x):
df_grouped = df_grouped_x

elif test_groups(fig, df_grouped_y):
df_grouped = df_grouped_y
else:
raise ValueError(
"color_col and marker_col are mispecified because their dataframe grouping names don't match the names in the plotly figure."
)
else:
raise ValueError("Too many columns specified for grouping.")

str_groups = {}
for name, group in df_grouped:
if isinstance(name, tuple):
str_groups[", ".join(str(x) for x in name)] = group
else:
str_groups[name] = group

curve_dict = {index: str_groups[x["name"]] for index, x in enumerate(fig.data)}
return df_grouped, curve_dict


def add_molecules(
fig,
df,
Expand All @@ -31,6 +96,7 @@ def add_molecules(
caption_cols=None,
caption_transform={},
color_col=None,
marker_col=None,
wrap=True,
wraplen=20,
width=150,
Expand Down Expand Up @@ -77,22 +143,25 @@ def add_molecules(
the font size used in the hover box - the font of the title line is fontsize+2 (default 12)
"""
fig.update_traces(hoverinfo="none", hovertemplate=None)

df_data = df.copy()
if color_col is not None:
df_data[color_col] = df_data[color_col].astype(str)
if marker_col is not None:
df_data[marker_col] = df_data[marker_col].astype(str)
colors = {0: "black"}

if len(fig.data) != 1:
if color_col is not None:
colors = {index: x.marker["color"] for index, x in enumerate(fig.data)}
if df[color_col].dtype == bool:
curve_dict = {
index: str2bool(x["name"]) for index, x in enumerate(fig.data)
}
elif df[color_col].dtype == int:
curve_dict = {index: int(x["name"]) for index, x in enumerate(fig.data)}
else:
curve_dict = {index: x["name"] for index, x in enumerate(fig.data)}
else:
if color_col is None and marker_col is None:
raise ValueError(
"color_col needs to be specified if there is more than one plotly curve in the figure!"
"More than one plotly curve in figure - color_col and/or marker_col needs to be specified."
)
if color_col is None:
df_grouped, curve_dict = find_grouping(fig, df_data, [marker_col])
elif marker_col is None:
df_grouped, curve_dict = find_grouping(fig, df_data, [color_col])
else:
df_grouped, curve_dict = find_grouping(
fig, df_data, [color_col, marker_col]
)

app = JupyterDash(__name__)
Expand Down Expand Up @@ -143,8 +212,16 @@ def display_hover(hoverData, value):
num = pt["pointNumber"]
curve_num = pt["curveNumber"]

# print(hoverData)
# print(pt)

if len(fig.data) != 1:
df_curve = df[df[color_col] == curve_dict[curve_num]].reset_index(drop=True)
# TODO replace with query
# df_curve = df_grouped.get_group(curve_dict[curve_num]).reset_index(
# drop=True
# )
df_curve = curve_dict[curve_num].reset_index(drop=True)
# df_curve = df[df[color_col] == curve_dict[curve_num]]
df_row = df_curve.iloc[num]
else:
df_row = df.iloc[num]
Expand Down Expand Up @@ -197,6 +274,8 @@ def display_hover(hoverData, value):
title = textwrap.fill(title, width=wraplen)
else:
title = title[:wraplen] + "..."

# TODO colorbar color titles
hoverbox_elements.append(
html.H4(
f"{title}",
Expand Down

0 comments on commit b5df13d

Please sign in to comment.