diff --git a/neo/rawio/spikeglxrawio.py b/neo/rawio/spikeglxrawio.py index 2e8f896d8..c6d5fda45 100644 --- a/neo/rawio/spikeglxrawio.py +++ b/neo/rawio/spikeglxrawio.py @@ -229,14 +229,25 @@ def _parse_header(self): spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype) # deal with nb_segment and t_start/t_stop per segment - self._t_starts = {seg_index: 0.0 for seg_index in range(nb_segment)} + + self._t_starts = {stream_name: {} for stream_name in stream_names} self._t_stops = {seg_index: 0.0 for seg_index in range(nb_segment)} - for seg_index in range(nb_segment): - for stream_name in stream_names: + + for stream_name in stream_names: + for seg_index in range(nb_segment): info = self.signals_info_dict[seg_index, stream_name] + + frame_start = float(info["meta"]["firstSample"]) + sampling_frequency = info["sampling_rate"] + t_start = frame_start / sampling_frequency + + self._t_starts[stream_name][seg_index] = t_start t_stop = info["sample_length"] / info["sampling_rate"] self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop) + + + # fille into header dict self.header = {} self.header["nb_block"] = 1 @@ -282,7 +293,8 @@ def _segment_t_stop(self, block_index, seg_index): return self._t_stops[seg_index] def _get_signal_t_start(self, block_index, seg_index, stream_index): - return 0.0 + stream_name = self.header["signal_streams"][stream_index]["name"] + return self._t_starts[stream_name][seg_index] def _event_count(self, event_channel_idx, block_index=None, seg_index=None): timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None) diff --git a/neo/test/rawiotest/test_spikeglxrawio.py b/neo/test/rawiotest/test_spikeglxrawio.py index ba65cf83d..84f6d31ac 100644 --- a/neo/test/rawiotest/test_spikeglxrawio.py +++ b/neo/test/rawiotest/test_spikeglxrawio.py @@ -110,6 +110,65 @@ def test_nidq_digital_channel(self): atol = 0.001 assert np.allclose(on_diff, 1, atol=atol) + def test_t_start_reading(self): + """Test that t_start values are correctly read for all streams and segments.""" + + # Expected t_start values for each stream and segment + expected_t_starts = { + 'imec0.ap': { + 0: 15.319535472007237, + 1: 15.339535431281986, + 2: 21.284723325294053, + 3: 21.3047232845688 + }, + 'imec1.ap': { + 0: 15.319554693264516, + 1: 15.339521518106308, + 2: 21.284735282142822, + 3: 21.304702106984614 + }, + 'imec0.lf': { + 0: 15.3191688060872, + 1: 15.339168765361949, + 2: 21.284356659374016, + 3: 21.304356618648765 + }, + 'imec1.lf': { + 0: 15.319321358082725, + 1: 15.339321516521915, + 2: 21.284568614155827, + 3: 21.30456877259502 + } + } + + # Initialize the RawIO + rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4")) + rawio.parse_header() + + # Get list of stream names + stream_names = rawio.header["signal_streams"]["name"] + + # Test t_start for each stream and segment + for stream_name, expected_values in expected_t_starts.items(): + # Get stream index + stream_index = list(stream_names).index(stream_name) + + # Check each segment + for seg_index, expected_t_start in expected_values.items(): + actual_t_start = rawio.get_signal_t_start( + block_index=0, + seg_index=seg_index, + stream_index=stream_index + ) + + # Use numpy.testing for proper float comparison + np.testing.assert_allclose( + actual_t_start, + expected_t_start, + rtol=1e-9, + atol=1e-9, + err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}" + ) if __name__ == "__main__": unittest.main()