diff --git a/qiskit_ibm_runtime/execution_span/__init__.py b/qiskit_ibm_runtime/execution_span/__init__.py index 2e747ea45..7637ae82d 100644 --- a/qiskit_ibm_runtime/execution_span/__init__.py +++ b/qiskit_ibm_runtime/execution_span/__init__.py @@ -30,12 +30,14 @@ .. autosummary:: :toctree: ../stubs/ + DoubleSliceSpan ExecutionSpan ExecutionSpans ShapeType SliceSpan """ +from .double_slice_span import DoubleSliceSpan from .execution_span import ExecutionSpan, ShapeType from .execution_spans import ExecutionSpans from .slice_span import SliceSpan diff --git a/qiskit_ibm_runtime/execution_span/double_slice_span.py b/qiskit_ibm_runtime/execution_span/double_slice_span.py new file mode 100644 index 000000000..2e9bc0b0c --- /dev/null +++ b/qiskit_ibm_runtime/execution_span/double_slice_span.py @@ -0,0 +1,79 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""DoubleSliceSpan""" + +from __future__ import annotations + +from datetime import datetime +from typing import Iterable + +import math +import numpy as np +import numpy.typing as npt + +from .execution_span import ExecutionSpan, ShapeType + + +class DoubleSliceSpan(ExecutionSpan): + """An :class:`~.ExecutionSpan` for data stored in a sliceable format. + + This type of execution span references pub result data by assuming that it is a sliceable + portion of the data where the shots are the outermost slice and the rest of the data is flattened. + Therefore, for each pub dependent on this span, the constructor accepts two :class:`slice` objects, + along with the corresponding shape of the data to be sliced; in contrast to + :class:`~.SliceSpan`, this class does not assume that *all* shots for a particular set of parameter + values are contiguous in the array of data. + + Args: + start: The start time of the span, in UTC. + stop: The stop time of the span, in UTC. + data_slices: A map from pub indices to ``(shape_tuple, slice, slice)``. + """ + + def __init__( + self, + start: datetime, + stop: datetime, + data_slices: dict[int, tuple[ShapeType, slice, slice]], + ): + super().__init__(start, stop) + self._data_slices = data_slices + + def __eq__(self, other: object) -> bool: + return isinstance(other, DoubleSliceSpan) and ( + self.start == other.start + and self.stop == other.stop + and self._data_slices == other._data_slices + ) + + @property + def pub_idxs(self) -> list[int]: + return sorted(self._data_slices) + + @property + def size(self) -> int: + size = 0 + for shape, args_sl, shots_sl in self._data_slices.values(): + size += len(range(math.prod(shape[:-1]))[args_sl]) * len(range(shape[-1])[shots_sl]) + return size + + def mask(self, pub_idx: int) -> npt.NDArray[np.bool_]: + shape, args_sl, shots_sl = self._data_slices[pub_idx] + mask = np.zeros(shape, dtype=np.bool_) + mask.reshape(np.prod(shape[:-1]), shape[-1])[(args_sl, shots_sl)] = True + return mask + + def filter_by_pub(self, pub_idx: int | Iterable[int]) -> "DoubleSliceSpan": + pub_idx = {pub_idx} if isinstance(pub_idx, int) else set(pub_idx) + slices = {idx: val for idx, val in self._data_slices.items() if idx in pub_idx} + return DoubleSliceSpan(self.start, self.stop, slices) diff --git a/release-notes/unreleased/1982.feat.rst b/release-notes/unreleased/1982.feat.rst new file mode 100644 index 000000000..31f634561 --- /dev/null +++ b/release-notes/unreleased/1982.feat.rst @@ -0,0 +1 @@ +Added :class:`.DoubleSliceSpan`, an :class:`ExecutionSpan` for batching with two slices. diff --git a/test/unit/test_execution_span.py b/test/unit/test_execution_span.py index 52e773254..960ccddc5 100644 --- a/test/unit/test_execution_span.py +++ b/test/unit/test_execution_span.py @@ -18,7 +18,7 @@ import numpy as np import numpy.testing as npt -from qiskit_ibm_runtime.execution_span import SliceSpan, ExecutionSpans +from qiskit_ibm_runtime.execution_span import SliceSpan, DoubleSliceSpan, ExecutionSpans from ..ibm_test_case import IBMTestCase @@ -126,6 +126,102 @@ def test_filter_by_pub(self): self.assertEqual(self.span2.filter_by_pub(1), SliceSpan(self.start2, self.stop2, {})) +@ddt.ddt +class TestDoubleSliceSpan(IBMTestCase): + """Class for testing DoubleSliceSpan.""" + + def setUp(self) -> None: + super().setUp() + self.start1 = datetime(2024, 10, 11, 4, 31, 30) + self.stop1 = datetime(2024, 10, 11, 4, 31, 34) + self.slices1 = { + 2: ((1, 100), slice(1), slice(4, 9)), + 0: ((3, 5, 10), slice(10, 13), slice(2, 5)), + } + self.span1 = DoubleSliceSpan(self.start1, self.stop1, self.slices1) + + self.start2 = datetime(2024, 10, 16, 11, 9, 20) + self.stop2 = datetime(2024, 10, 16, 11, 9, 30) + self.slices2 = { + 0: ((5, 100), slice(3, 5), slice(20, 40)), + 1: ((1, 5, 3), slice(2, 5), slice(3)), + } + self.span2 = DoubleSliceSpan(self.start2, self.stop2, self.slices2) + + def test_limits(self): + """Test the start and stop properties""" + self.assertEqual(self.span1.start, self.start1) + self.assertEqual(self.span1.stop, self.stop1) + self.assertEqual(self.span2.start, self.start2) + self.assertEqual(self.span2.stop, self.stop2) + + def test_equality(self): + """Test the equality method.""" + self.assertEqual(self.span1, self.span1) + self.assertEqual(self.span1, DoubleSliceSpan(self.start1, self.stop1, self.slices1)) + self.assertNotEqual(self.span1, "aoeu") + self.assertNotEqual(self.span1, self.span2) + + def test_duration(self): + """Test the duration property""" + self.assertEqual(self.span1.duration, 4) + self.assertEqual(self.span2.duration, 10) + + def test_repr(self): + """Test the repr method""" + expect = "start='2024-10-11 04:31:30', stop='2024-10-11 04:31:34', size=14" + self.assertEqual(repr(self.span1), f"DoubleSliceSpan(<{expect}>)") + + def test_size(self): + """Test the size property""" + self.assertEqual(self.span1.size, 1 * 5 + 3 * 3) + self.assertEqual(self.span2.size, 2 * 20 + 3 * 3) + + def test_pub_idxs(self): + """Test the pub_idxs property""" + self.assertEqual(self.span1.pub_idxs, [0, 2]) + self.assertEqual(self.span2.pub_idxs, [0, 1]) + + def test_mask(self): + """Test the mask() method""" + mask1 = np.zeros((1, 100), dtype=bool) + mask1[0][4:9] = True + npt.assert_array_equal(self.span1.mask(2), mask1) + + mask2 = [[[0, 0, 0], [0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]]] + npt.assert_array_equal(self.span2.mask(1), mask2) + + @ddt.data( + (0, True, True), + ([0, 1], True, True), + ([0, 1, 2], True, True), + ([1, 2], True, True), + ([1], False, True), + (2, True, False), + ([0, 2], True, True), + ) + @ddt.unpack + def test_contains_pub(self, idx, span1_expected_res, span2_expected_res): + """The the contains_pub method""" + self.assertEqual(self.span1.contains_pub(idx), span1_expected_res) + self.assertEqual(self.span2.contains_pub(idx), span2_expected_res) + + def test_filter_by_pub(self): + """The the filter_by_pub method""" + self.assertEqual(self.span1.filter_by_pub([]), DoubleSliceSpan(self.start1, self.stop1, {})) + self.assertEqual(self.span2.filter_by_pub([]), DoubleSliceSpan(self.start2, self.stop2, {})) + + self.assertEqual( + self.span1.filter_by_pub([1, 0]), + DoubleSliceSpan(self.start1, self.stop1, {0: self.slices1[0]}), + ) + + self.assertEqual( + self.span1.filter_by_pub(2), + DoubleSliceSpan(self.start1, self.stop1, {2: self.slices1[2]}), + ) + + @ddt.ddt class TestExecutionSpans(IBMTestCase): """Class for testing ExecutionSpans."""