Skip to content

Commit

Permalink
Merge pull request #17 from leaf-ai/plotly
Browse files Browse the repository at this point in the history
Plotly
  • Loading branch information
ofrancon authored Sep 18, 2020
2 parents 0617d98 + a88a231 commit 650cbd6
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ notebook==6.1.4
scikit-learn==0.23.2
tensorflow==2.3.0
keras==2.4.3
plotly==4.9.0
211 changes: 210 additions & 1 deletion robojudge.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,21 @@
"metadata": {},
"outputs": [],
"source": [
"actual_df.head()"
"actual_df.head(8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# actual_as_pred_df = actual_df.copy()\n",
"# actual_as_pred_df[\"PredictorName\"] = \"Ground truth\"\n",
"# actual_as_pred_df[\"Prediction\"] = False\n",
"# actual_as_pred_df = actual_as_pred_df.rename(columns={\"ActualDailyNewCases\": \"PredictedDailyNewCases\",\n",
"# \"ActualDailyNewCases7DMA\": \"PredictedDailyNewCases7DMA\"})\n",
"# actual_as_pred_df.head(8)"
]
},
{
Expand Down Expand Up @@ -575,6 +589,201 @@
"cr_df[(cr_df.CountryName.isin(NORTH_AMERICA)) & (cr_df.RegionName == \"\")]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plots"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"default_country = \"Italy\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction vs actual"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"country_df = ranking_df[ranking_df.CountryName == default_country]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictor_names = list(country_df.PredictorName.unique())\n",
"country_names = list(ranking_df.CountryName.unique())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"country_df[country_df[\"PredictorName\"] == 'Predictor #27']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"\n",
"fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases 7-day Moving Average in {default_country}',\n",
" y=0.9,\n",
" x=0.5,\n",
" xanchor='center',\n",
" yanchor='top'\n",
" ),\n",
" plot_bgcolor='#f2f2f2',\n",
" xaxis_title=\"Date\",\n",
" yaxis_title=\"New Cases\"\n",
" ))\n",
"\n",
"# Add 1 trace per predictor\n",
"for predictor_name in predictor_names:\n",
" pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n",
" fig.add_trace(go.Scatter(x=pred_country_df.Date,\n",
" y=pred_country_df.PredictedDailyNewCases7DMA,\n",
" name=predictor_name)\n",
" )\n",
"\n",
"# Add 1 trace for the true number of cases\n",
"country_actual_df = actual_df[(actual_df.CountryName == default_country) &\n",
" (actual_df.Date >= start_date)]\n",
"fig.add_trace(go.Scatter(x=country_actual_df.Date,\n",
" y=country_actual_df.ActualDailyNewCases7DMA,\n",
" name=\"Ground Truth\",\n",
" line=dict(color='orange', width=4, dash='dash'))\n",
" )\n",
"# Format x axis\n",
"fig.update_xaxes(\n",
"dtick=\"D1\", # Means 1 day\n",
"tickformat=\"%d\\n%b\")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Filter by country"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"\n",
"fig = go.Figure(layout=dict(title=dict(text=f'Predicted 7-day Moving Average of New Cases in {default_country}',\n",
" y=0.9,\n",
" x=0.5,\n",
" xanchor='center',\n",
" yanchor='top'\n",
" ),\n",
" plot_bgcolor='#f2f2f2',\n",
" xaxis_title=\"Date\",\n",
" yaxis_title=\"New Cases\"\n",
" ))\n",
"\n",
"# Keep track of trace visibility by country name\n",
"country_plot_names = []\n",
"\n",
"# Add 1 trace per predictor, per country\n",
"for predictor_name in predictor_names:\n",
" for country_name in country_names:\n",
" country_df = ranking_df[ranking_df.CountryName == country_name]\n",
" pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n",
" fig.add_trace(go.Scatter(x=pred_country_df.Date,\n",
" y=pred_country_df.PredictedDailyNewCases7DMA,\n",
" name=predictor_name,\n",
" visible= (country_name == default_country))\n",
" )\n",
" country_plot_names.append(country_name)\n",
"\n",
"# For each country\n",
"# Add 1 trace for the true number of cases\n",
"for country_name in country_names:\n",
" country_actual_df = actual_df[(actual_df.CountryName == country_name) &\n",
" (actual_df.Date >= start_date)]\n",
" fig.add_trace(go.Scatter(x=country_actual_df.Date,\n",
" y=country_actual_df.ActualDailyNewCases7DMA,\n",
" name=\"Ground Truth\",\n",
" visible= (country_name == default_country),\n",
" line=dict(color='orange', width=4, dash='dash'))\n",
" )\n",
" country_plot_names.append(country_name)\n",
"\n",
"# Format x axis\n",
"fig.update_xaxes(\n",
"dtick=\"D1\", # Means 1 day\n",
"tickformat=\"%d\\n%b\")\n",
"\n",
"# Filter\n",
"buttons=[]\n",
"for country_name in country_names:\n",
" buttons.append(dict(method='update',\n",
" label=country_name,\n",
" args = [{'visible': [country_name==r for r in country_plot_names]},\n",
" {'title': \"Predicted 7-day Moving Average of New Cases in \" + country_name}]))\n",
"fig.update_layout(showlegend=True,\n",
" updatemenus=[{\"buttons\": buttons,\n",
" \"direction\": \"down\",\n",
" \"active\": country_names.index(default_country),\n",
" \"showactive\": True,\n",
" \"x\": 0.1,\n",
" \"y\": 1.15}])\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Daily diff in 7 days moving average"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = px.line(ranking_df[ranking_df.CountryName == \"Italy\"], x=\"Date\", y=\"Diff7DMA\", color='PredictorName')\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 650cbd6

Please sign in to comment.