diff --git a/src/bloqade/compiler/codegen/python/emulator_ir.py b/src/bloqade/compiler/codegen/python/emulator_ir.py index 1917817db..f250a671e 100644 --- a/src/bloqade/compiler/codegen/python/emulator_ir.py +++ b/src/bloqade/compiler/codegen/python/emulator_ir.py @@ -338,6 +338,12 @@ def visit_field_RunTimeVector( self, node: field.RunTimeVector ) -> Dict[int, Decimal]: value = self.assignments[node.name] + for new_index, original_index in enumerate(self.original_index): + if original_index >= len(value): + raise ValueError( + f"Index {original_index} is out of bounds for the runtime vector {node.name}" + ) + return { new_index: Decimal(str(value[original_index])) for new_index, original_index in enumerate(self.original_index) @@ -347,6 +353,12 @@ def visit_field_RunTimeVector( def visit_field_AssignedRunTimeVector( self, node: field.AssignedRunTimeVector ) -> Dict[int, Decimal]: + for new_index, original_index in enumerate(self.original_index): + if original_index >= len(node.value): + raise ValueError( + f"Index {original_index} is out of bounds for the mask vector." + ) + return { new_index: Decimal(str(node.value[original_index])) for new_index, original_index in enumerate(self.original_index) @@ -357,7 +369,6 @@ def visit_field_ScaledLocations( self, node: field.ScaledLocations ) -> Dict[int, Decimal]: target_atoms = {} - for location in node.value.keys(): if location.value >= self.n_sites or location.value < 0: raise ValueError( diff --git a/src/bloqade/ir/analog_circuit.py b/src/bloqade/ir/analog_circuit.py index 6da9c3bb1..caa6fafbc 100644 --- a/src/bloqade/ir/analog_circuit.py +++ b/src/bloqade/ir/analog_circuit.py @@ -81,6 +81,9 @@ def figure(self, **assignments): # analysis the SpatialModulation information spmod_extracted_data: Dict[str, Tuple[List[int], List[float]]] = {} + def process_names(x): + return int(x.split("[")[-1].split("]")[0]) + for tab in fig_seq.tabs: pulse_name = tab.title field_plots = tab.child.children @@ -101,9 +104,7 @@ def figure(self, **assignments): for ch in channels: ch_data = Spmod_raw_data[Spmod_raw_data.d0 == ch] - sites = list( - map(lambda x: int(x.split("[")[-1].split("]")[0]), ch_data.d1) - ) + sites = list(map(process_names, ch_data.d1)) values = list(ch_data.px.astype(float)) key = f"{pulse_name}.{field_name}.{ch}" diff --git a/src/bloqade/ir/control/field.py b/src/bloqade/ir/control/field.py index 2a0b65ca5..2e362baf5 100644 --- a/src/bloqade/ir/control/field.py +++ b/src/bloqade/ir/control/field.py @@ -122,7 +122,13 @@ def figure(self, **assginment): return get_ir_figure(self, **assginment) def _get_data(self, **assignment): - return [self.name], ["vec"] + locs = [] + values = [] + for i, v in enumerate(self.value): + locs.append(f"{self.name or 'value'}[{i}]") + values.append(str(v)) + + return locs, values def show(self, **assignment): display_ir(self, **assignment) diff --git a/tests/test_batch.py b/tests/test_batch.py index 9fe95937e..d14ca80d8 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -172,7 +172,7 @@ def test_metadata_filter_scalar(): assert filtered_batch.tasks.keys() == {0, 1, 4} - with pytest.raises(ValueError): + with pytest.raises(Exception): filtered_batch = batch.filter_metadata(d=[1, 2, 16, 1j]) @@ -198,7 +198,7 @@ def test_metadata_filter_vector(): filters = dict(d=[1, 8], m=[[0, 1], [1, 0], (0, 0)]) - with pytest.raises(ValueError): + with pytest.raises(Exception): filtered_batch_all = batch.filter_metadata(**filters) diff --git a/tests/test_field.py b/tests/test_field.py index d428c4122..16c8b2f07 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -82,7 +82,7 @@ def test_assigned_runtime_vec(): ) assert x.print_node() == "AssignedRunTimeVector: sss" assert x.children() == cast([Decimal("1.0"), Decimal("2.0")]) - assert x._get_data() == (["sss"], ["vec"]) + assert x._get_data() == (["sss[0]", "sss[1]"], ["1.0", "2.0"]) mystdout = StringIO() p = PP(mystdout)