From 05cd49afd2aa0b889a27fa7b0c353d9bed5735a9 Mon Sep 17 00:00:00 2001 From: Olivier Francon Date: Fri, 18 Sep 2020 02:14:01 +0200 Subject: [PATCH 1/2] #11 Create a plot to compare predictors --- requirements.txt | 1 + robojudge.ipynb | 230 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index ab73b47a..a48ffad4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ notebook==6.1.4 scikit-learn==0.23.2 tensorflow==2.3.0 keras==2.4.3 +plotly==4.9.0 diff --git a/robojudge.ipynb b/robojudge.ipynb index f82b8f3c..e89d7397 100644 --- a/robojudge.ipynb +++ b/robojudge.ipynb @@ -324,7 +324,21 @@ "metadata": {}, "outputs": [], "source": [ - "actual_df.head()" + "actual_df.head(8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# actual_as_pred_df = actual_df.copy()\n", + "# actual_as_pred_df[\"PredictorName\"] = \"Ground truth\"\n", + "# actual_as_pred_df[\"Prediction\"] = False\n", + "# actual_as_pred_df = actual_as_pred_df.rename(columns={\"ActualDailyNewCases\": \"PredictedDailyNewCases\",\n", + "# \"ActualDailyNewCases7DMA\": \"PredictedDailyNewCases7DMA\"})\n", + "# actual_as_pred_df.head(8)" ] }, { @@ -575,12 +589,224 @@ "cr_df[(cr_df.CountryName.isin(NORTH_AMERICA)) & (cr_df.RegionName == \"\")]" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Plots" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "import plotly.express as px" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prediction vs actual" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selected_country = \"Italy\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: add actual" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(ranking_df[ranking_df.CountryName == selected_country],\n", + " x=\"Date\",\n", + " y=\"PredictedDailyNewCases7DMA\",\n", + " color=\"PredictorName\",\n", + " title=\"Predicted daily new cases (7 days moving average)\")\n", + "fig.update_xaxes(\n", + " dtick=\"D1\", # Means 1 day\n", + " tickformat=\"%d\\n%b\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prediction vs actual (Scatter)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "country_df = ranking_df[ranking_df.CountryName == selected_country]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictors = list(country_df.PredictorName.unique())\n", + "predictors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "country_df[country_df[\"PredictorName\"] == 'Predictor #27']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go\n", + "\n", + "fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases per Day in {selected_country}',\n", + " y=0.9,\n", + " x=0.5,\n", + " xanchor='center',\n", + " yanchor='top'\n", + " ),\n", + " plot_bgcolor='#f2f2f2',\n", + " xaxis_title=\"Date\",\n", + " yaxis_title=\"New Cases\"\n", + " ))\n", + "\n", + "# Add 1 trace per predictor\n", + "for predictor_name in predictors:\n", + " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", + " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", + " y=pred_country_df.PredictedDailyNewCases7DMA,\n", + " name=predictor_name)\n", + " )\n", + "\n", + "# Add 1 trace for the true number of cases\n", + "country_actual_df = actual_df[(actual_df.CountryName == selected_country) &\n", + " (actual_df.Date >= start_date)]\n", + "fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", + " y=country_actual_df.ActualDailyNewCases7DMA,\n", + " name=\"Ground Truth\",\n", + " line=dict(color='orange', width=4, dash='dash'))\n", + " )\n", + "# Format x axis\n", + "fig.update_xaxes(\n", + "dtick=\"D1\", # Means 1 day\n", + "tickformat=\"%d\\n%b\")\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Daily diff in 7 days moving average" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(ranking_df[ranking_df.CountryName == \"Italy\"], x=\"Date\", y=\"Diff7DMA\", color='PredictorName')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filter by country" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go\n", + "\n", + "fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases per Day in {selected_country}',\n", + " y=0.9,\n", + " x=0.5,\n", + " xanchor='center',\n", + " yanchor='top'\n", + " ),\n", + " plot_bgcolor='#f2f2f2',\n", + " xaxis_title=\"Date\",\n", + " yaxis_title=\"New Cases\"\n", + " ))\n", + "\n", + "# Add 1 trace per predictor\n", + "for predictor_name in predictors:\n", + " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", + " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", + " y=pred_country_df.PredictedDailyNewCases7DMA,\n", + " name=predictor_name)\n", + " )\n", + "\n", + "# Add 1 trace for the true number of cases\n", + "country_actual_df = actual_df[(actual_df.CountryName == selected_country) &\n", + " (actual_df.Date >= start_date)]\n", + "fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", + " y=country_actual_df.ActualDailyNewCases7DMA,\n", + " name=\"Ground Truth\",\n", + " line=dict(color='orange', width=4, dash='dash'))\n", + " )\n", + "# Format x axis\n", + "fig.update_xaxes(\n", + "dtick=\"D1\", # Means 1 day\n", + "tickformat=\"%d\\n%b\")\n", + "\n", + "# Filter\n", + "default_country = \"Italy\"\n", + "buttons=[]\n", + "for country_name in country_df.CountryName.unique():\n", + " buttons.append(dict(method='update',\n", + " label=country_name,\n", + " args = [{'visible': [country_name==r for r in country_plot_names]},\n", + " {'title': \"Predicted New Cases Per Day in \" + country_name}]))\n", + "fig.update_layout(showlegend=True,\n", + " updatemenus=[{\"buttons\": buttons,\n", + " \"direction\": \"down\",\n", + " \"active\": default_country,\n", + " \"showactive\": True,\n", + " \"x\": 0.5,\n", + " \"y\": 1.15}])\n", + "\n", + "fig.show()" + ] } ], "metadata": { From a88a231fe5061ed5d2ed5d2b446a7c1f8ba98de3 Mon Sep 17 00:00:00 2001 From: Olivier Francon Date: Fri, 18 Sep 2020 03:08:44 +0200 Subject: [PATCH 2/2] #11 Filter by country when comparing submissions --- robojudge.ipynb | 155 +++++++++++++++++++++--------------------------- 1 file changed, 69 insertions(+), 86 deletions(-) diff --git a/robojudge.ipynb b/robojudge.ipynb index e89d7397..5dfa0391 100644 --- a/robojudge.ipynb +++ b/robojudge.ipynb @@ -602,7 +602,7 @@ "metadata": {}, "outputs": [], "source": [ - "import plotly.express as px" + "default_country = \"Italy\"" ] }, { @@ -618,7 +618,7 @@ "metadata": {}, "outputs": [], "source": [ - "selected_country = \"Italy\"" + "country_df = ranking_df[ranking_df.CountryName == default_country]" ] }, { @@ -627,50 +627,8 @@ "metadata": {}, "outputs": [], "source": [ - "# TODO: add actual" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.line(ranking_df[ranking_df.CountryName == selected_country],\n", - " x=\"Date\",\n", - " y=\"PredictedDailyNewCases7DMA\",\n", - " color=\"PredictorName\",\n", - " title=\"Predicted daily new cases (7 days moving average)\")\n", - "fig.update_xaxes(\n", - " dtick=\"D1\", # Means 1 day\n", - " tickformat=\"%d\\n%b\")\n", - "fig.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prediction vs actual (Scatter)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "country_df = ranking_df[ranking_df.CountryName == selected_country]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "predictors = list(country_df.PredictorName.unique())\n", - "predictors" + "predictor_names = list(country_df.PredictorName.unique())\n", + "country_names = list(ranking_df.CountryName.unique())" ] }, { @@ -690,7 +648,7 @@ "source": [ "import plotly.graph_objects as go\n", "\n", - "fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases per Day in {selected_country}',\n", + "fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases 7-day Moving Average in {default_country}',\n", " y=0.9,\n", " x=0.5,\n", " xanchor='center',\n", @@ -702,7 +660,7 @@ " ))\n", "\n", "# Add 1 trace per predictor\n", - "for predictor_name in predictors:\n", + "for predictor_name in predictor_names:\n", " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", " y=pred_country_df.PredictedDailyNewCases7DMA,\n", @@ -710,7 +668,7 @@ " )\n", "\n", "# Add 1 trace for the true number of cases\n", - "country_actual_df = actual_df[(actual_df.CountryName == selected_country) &\n", + "country_actual_df = actual_df[(actual_df.CountryName == default_country) &\n", " (actual_df.Date >= start_date)]\n", "fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", " y=country_actual_df.ActualDailyNewCases7DMA,\n", @@ -725,23 +683,6 @@ "fig.show()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Daily diff in 7 days moving average" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.line(ranking_df[ranking_df.CountryName == \"Italy\"], x=\"Date\", y=\"Diff7DMA\", color='PredictorName')\n", - "fig.show()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -757,7 +698,7 @@ "source": [ "import plotly.graph_objects as go\n", "\n", - "fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases per Day in {selected_country}',\n", + "fig = go.Figure(layout=dict(title=dict(text=f'Predicted 7-day Moving Average of New Cases in {default_country}',\n", " y=0.9,\n", " x=0.5,\n", " xanchor='center',\n", @@ -768,45 +709,87 @@ " yaxis_title=\"New Cases\"\n", " ))\n", "\n", - "# Add 1 trace per predictor\n", - "for predictor_name in predictors:\n", - " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", - " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", - " y=pred_country_df.PredictedDailyNewCases7DMA,\n", - " name=predictor_name)\n", - " )\n", + "# Keep track of trace visibility by country name\n", + "country_plot_names = []\n", "\n", + "# Add 1 trace per predictor, per country\n", + "for predictor_name in predictor_names:\n", + " for country_name in country_names:\n", + " country_df = ranking_df[ranking_df.CountryName == country_name]\n", + " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", + " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", + " y=pred_country_df.PredictedDailyNewCases7DMA,\n", + " name=predictor_name,\n", + " visible= (country_name == default_country))\n", + " )\n", + " country_plot_names.append(country_name)\n", + "\n", + "# For each country\n", "# Add 1 trace for the true number of cases\n", - "country_actual_df = actual_df[(actual_df.CountryName == selected_country) &\n", - " (actual_df.Date >= start_date)]\n", - "fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", - " y=country_actual_df.ActualDailyNewCases7DMA,\n", - " name=\"Ground Truth\",\n", - " line=dict(color='orange', width=4, dash='dash'))\n", - " )\n", + "for country_name in country_names:\n", + " country_actual_df = actual_df[(actual_df.CountryName == country_name) &\n", + " (actual_df.Date >= start_date)]\n", + " fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", + " y=country_actual_df.ActualDailyNewCases7DMA,\n", + " name=\"Ground Truth\",\n", + " visible= (country_name == default_country),\n", + " line=dict(color='orange', width=4, dash='dash'))\n", + " )\n", + " country_plot_names.append(country_name)\n", + "\n", "# Format x axis\n", "fig.update_xaxes(\n", "dtick=\"D1\", # Means 1 day\n", "tickformat=\"%d\\n%b\")\n", "\n", "# Filter\n", - "default_country = \"Italy\"\n", "buttons=[]\n", - "for country_name in country_df.CountryName.unique():\n", + "for country_name in country_names:\n", " buttons.append(dict(method='update',\n", " label=country_name,\n", " args = [{'visible': [country_name==r for r in country_plot_names]},\n", - " {'title': \"Predicted New Cases Per Day in \" + country_name}]))\n", + " {'title': \"Predicted 7-day Moving Average of New Cases in \" + country_name}]))\n", "fig.update_layout(showlegend=True,\n", " updatemenus=[{\"buttons\": buttons,\n", " \"direction\": \"down\",\n", - " \"active\": default_country,\n", + " \"active\": country_names.index(default_country),\n", " \"showactive\": True,\n", - " \"x\": 0.5,\n", + " \"x\": 0.1,\n", " \"y\": 1.15}])\n", "\n", "fig.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Daily diff in 7 days moving average" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(ranking_df[ranking_df.CountryName == \"Italy\"], x=\"Date\", y=\"Diff7DMA\", color='PredictorName')\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {