diff --git a/doc/tutorials/trials.ipynb b/doc/tutorials/trials.ipynb index 8bfa645c1..d18b109c8 100644 --- a/doc/tutorials/trials.ipynb +++ b/doc/tutorials/trials.ipynb @@ -39,6 +39,8 @@ "from quantities import ms, mV, Hz\n", "from elephant.trials import TrialsFromBlock, TrialsFromLists\n", "from elephant.datasets import download_datasets\n", + "import matplotlib.pyplot as plt\n", + "\n", "\n", "# Helper function to create example spike trains and analog signals\n", "def create_example_data():\n", @@ -68,7 +70,7 @@ "Elephant version: 1.2.0b1. Data URL:https://gin.g-node.org/NeuralEnsemble/elephant-data/raw/v1.2.0b1/README.md, error: HTTP Error 404: Not Found.\n", "Using elephant-data latest instead (This is expected for elephant development versions).\n", " warnings.warn(f\"No corresponding version of elephant-data found.\\n\"\n", - "Downloading https://datasets.python-elephant.org/raw/master/tutorials/tutorial_unitary_event_analysis/data/dataset-1.nix to '/tmp/elephant/dataset-1.nix': 1.69MB [00:02, 804kB/s] \n" + "Downloading https://datasets.python-elephant.org/raw/master/tutorials/tutorial_unitary_event_analysis/data/dataset-1.nix to '/tmp/elephant/dataset-1.nix': 1.69MB [00:01, 1.29MB/s]\n" ] } ], @@ -104,7 +106,7 @@ "Number of trials: 36\n", "Number of spike trains in each trial: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n", "Number of analog signals in each trial: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", - "Trial 1 Segment: \n", + "Trial 1 Segment: \n", "All spike trains from trial 2: [" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot all spiketrains belonging to specific trials\n", + "trials = (3,4,5)\n", + "n_trials = len(trials)\n", + "\n", + "# Create a figure with subplots for each trial\n", + "fig, axes = plt.subplots(n_trials, 1, figsize=(12, 1 * n_trials), sharex=True)\n", + "\n", + "for trial_no, _ in enumerate(trials):\n", + " spiketrains = trials_from_block.get_spiketrains_from_trial_as_list(trial_no)\n", + "\n", + " for i, spiketrain in enumerate(spiketrains):\n", + " axes[trial_no].plot(spiketrain.times, [i] * len(spiketrain), '|', markersize=5, label=f'Neuron {i+1}')\n", + " \n", + " # Set labels and title for each subplot\n", + " axes[trial_no].set_ylabel('Neuron')\n", + " axes[trial_no].set_title(f'Spike Trains from trial {trial_no}')\n", + " \n", + " # Set y-axis ticks to match neuron numbers\n", + " axes[trial_no].set_yticks(range(len(spiketrains)))\n", + " axes[trial_no].set_yticklabels([f'Neuron {i+1}' for i in range(len(spiketrains))])\n", + " \n", + " # Add legend to each subplot\n", + " axes[trial_no].legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + "\n", + "# Set common x-label\n", + "fig.text(0.5, 0.04, 'Time (s)', ha='center')\n", + "\n", + "# Adjust layout and display the plot\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -146,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ {