Skip to content

Commit

Permalink
Add option to colour wires in diagrams with frames (#177)
Browse files Browse the repository at this point in the history
* Add options for setting the line widths for boxes and wires
* Update the output of the Tikz drawing backend

---------

Co-authored-by: Neil John D. Ortega <[email protected]>
  • Loading branch information
Ragunath1729 and neiljdo authored Nov 22, 2024
1 parent db8e7eb commit 9ce7774
Show file tree
Hide file tree
Showing 8 changed files with 1,081 additions and 697 deletions.
157 changes: 153 additions & 4 deletions lambeq/backend/drawing/drawable.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class WireEndpoint:

x: float
y: float

noun_id: int = 0 # New attribute for wire noun
parent: Optional['BoxNode'] = None

@property
Expand Down Expand Up @@ -778,6 +778,7 @@ class DrawableDiagramWithFrames(DrawableDiagram):
frame, carrying all information necessary to render it.
"""
noun_id_counter: int = 1

def _make_space(self,
scan: list[int],
Expand Down Expand Up @@ -841,6 +842,120 @@ def _make_space(self,

return x, y

def _add_box_with_nouns(
self,
scan: list[int],
box: grammar.Box,
off: int,
x_pos: float,
y_pos: float,
input_nouns: list[int]
) -> tuple[list[int], int, list[int]]:
"""Add a box to the graph, creating necessary wire endpoints.
Returns
-------
list : int
The new scan of wire endpoints after adding the box
box_ind : int
The index of the newly added `BoxNode`
input_nouns : list[int]
The new order of input_nouns after adding the box
"""
node = BoxNode(box, x_pos, y_pos)

box_ind = self._add_boxnode(node)
num_input = len(box.dom)
input_nouns = input_nouns or []
for i in range(num_input):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)

# Create a node representing each element in the box's domain
for i, obj in enumerate(box.dom):
idx = off + i
nbr_idx = scan[off + i]
noun_id = (
input_nouns[idx] if (input_nouns and idx < len(input_nouns))
else self.get_noun_id()
) # generate new noun_id if needed

wire_end = WireEndpoint(WireEndpointType.DOM,
obj=obj,
x=self.wire_endpoints[nbr_idx].x,
y=y_pos + HALF_BOX_HEIGHT,
noun_id=noun_id)

wire_idx = self._add_wire_end(wire_end)
node.add_dom_wire(wire_idx)
self._add_wire(nbr_idx, wire_idx)

scan_insert = []
if isinstance(box, grammar.Swap):
# If Swap, exchange the noun_ids
if input_nouns and len(box.dom) > 1:
dom_idx_1 = off
dom_idx_2 = off + 1
input_nouns[dom_idx_1], input_nouns[dom_idx_2] = (
input_nouns[dom_idx_2], input_nouns[dom_idx_1]
)
elif isinstance(node.obj, grammar.Spider):
# If Spider, expand or shrink the noun_ids based on type
if len(box.dom) == 1 and len(box.cod) > 1:
dom_noun = (input_nouns[off] if input_nouns
and off < len(input_nouns)
else self.get_noun_id())
expanded_colors = [dom_noun] * len(box.cod)
input_nouns = (input_nouns[:off] + expanded_colors
+ input_nouns[off + len(box.dom):])
elif len(box.dom) > 1 and len(box.cod) == 1:
cod_noun = (input_nouns[off] if input_nouns
and off < len(input_nouns)
else self.get_noun_id())
input_nouns = (input_nouns[:off] + [cod_noun]
+ input_nouns[off + len(box.dom):])

num_output = off + len(box.cod)
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)

# Create a node representing each element in the box's codomain
for i, obj in enumerate(box.cod):
# If the box is a quantum gate, retain x coordinate of wires
if box.category == quantum and len(box.dom) == len(box.cod):
nbr_idx = scan[off + i]
x = self.wire_endpoints[nbr_idx].x
else:
x = x_pos + X_SPACING * (i - len(box.cod[1:]) / 2)
y = y_pos - HALF_BOX_HEIGHT
idx = off + i
noun_id = (input_nouns[idx] if input_nouns
and idx < len(input_nouns)
else self.get_noun_id())
wire_end = WireEndpoint(WireEndpointType.COD,
obj=obj,
x=x,
y=y,
noun_id=noun_id)

wire_idx = self._add_wire_end(wire_end)
scan_insert.append(wire_idx)
node.add_cod_wire(wire_idx)

# Replace node's dom with its cod in scan
return (scan[:off] + scan_insert + scan[off + len(box.dom):],
box_ind, input_nouns)

def _make_space_for_frame(self,
scan: list[int],
off: int,
Expand Down Expand Up @@ -949,6 +1064,20 @@ def calculate_bounds(self) -> tuple[float, float, float, float]:

return min(all_xs), min(all_ys), max(all_xs), max(all_ys)

def get_noun_id(self) -> int:
"""Get the latest available numerical ID for the noun wire.
Returns
-------
noun_id : int
The latest noun wire ID.
"""
# Increment and return the next available ID
noun_id = self.noun_id_counter
self.noun_id_counter += 1
return noun_id

@classmethod
def from_diagram(cls,
diagram: grammar.Diagram,
Expand Down Expand Up @@ -976,12 +1105,19 @@ def from_diagram(cls,
drawable = cls()

scan = []
# Generate unique noun_ids for input wires
num_input = len(diagram.dom)
input_nouns = []
for _ in range(num_input):
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.dom):
wire_end = WireEndpoint(WireEndpointType.INPUT,
obj=obj,
x=X_SPACING * i,
y=1)
y=1,
noun_id=input_nouns[i])
wire_end_idx = drawable._add_wire_end(wire_end)
scan.append(wire_end_idx)

Expand All @@ -993,7 +1129,8 @@ def from_diagram(cls,
# TODO: Debug issues with y coord
x, y = drawable._make_space(scan, box, off, foliated=foliated)

scan, box_ind = drawable._add_box(scan, box, off, x, y)
scan, box_ind, input_nouns = drawable._add_box_with_nouns(
scan, box, off, x, y, input_nouns)
box_height = BOX_HEIGHT
# Add drawables for the inside of the frame
if isinstance(box, grammar.Frame):
Expand All @@ -1004,12 +1141,23 @@ def from_diagram(cls,
max_box_half_height = max(max_box_half_height, (box_height / 2))
min_y = min(min_y, y)

num_output = len(diagram.cod)
# Match output nouns with input nouns as much as possible
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.cod):
wire_end = WireEndpoint(
WireEndpointType.OUTPUT,
obj=obj,
x=drawable.wire_endpoints[scan[i]].x,
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT,
noun_id=input_nouns[i]
)
wire_end_idx = drawable._add_wire_end(wire_end)
drawable._add_wire(scan[i], wire_end_idx)
Expand Down Expand Up @@ -1384,6 +1532,7 @@ def _merge_with(self, drawable: 'DrawableDiagramWithFrames') -> None:
last_wire_endpoint = len(self.wire_endpoints)

for wire_endpoint in drawable.wire_endpoints:
wire_endpoint.noun_id = 0
self.wire_endpoints.append(wire_endpoint)

for box in drawable.boxes:
Expand Down
49 changes: 41 additions & 8 deletions lambeq/backend/drawing/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,18 @@
DEFAULT_ASPECT,
DEFAULT_MARGINS,
DrawingBackend,
FRAME_COLORS)
FRAME_COLORS,
WIRE_COLORS)
from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry
from lambeq.backend.drawing.mat_backend import MatBackend
from lambeq.backend.drawing.mat_backend import (
BOX_LINEWIDTH as MAT_BOX_LINEWIDTH, MatBackend,
WIRE_LINEWIDTH as MAT_WIRE_LINEWIDTH
)
from lambeq.backend.drawing.text_printer import PregroupTextPrinter
from lambeq.backend.drawing.tikz_backend import TikzBackend
from lambeq.backend.drawing.tikz_backend import (
BOX_LINEWIDTH as TIKZ_BOX_LINEWIDTH, TikzBackend,
WIRE_LINEWIDTH as TIKZ_WIRE_LINEWIDTH
)
from lambeq.backend.grammar import Box, Diagram


Expand Down Expand Up @@ -113,6 +120,9 @@ def draw(diagram: Diagram, **params) -> None:
params['coloring_mode'] = params.get(
'coloring_mode', ColoringMode.TYPE.value,
)
params['color_wires'] = params.get(
'color_wires', diagram.has_frames,
)
if drawable is None:
drawable = drawable_cls.from_diagram(diagram,
params.get('foliated', False))
Expand All @@ -125,9 +135,18 @@ def draw(diagram: Diagram, **params) -> None:
backend: DrawingBackend = params.pop('backend')
elif params.get('to_tikz', False):
backend = TikzBackend(
use_tikzstyles=params.get('use_tikzstyles', None))
use_tikzstyles=params.get('use_tikzstyles', None),
box_linewidth=params.get('box_linewidth', TIKZ_BOX_LINEWIDTH),
wire_linewidth=params.get('wire_linewidth',
TIKZ_WIRE_LINEWIDTH),
)
else:
backend = MatBackend(figsize=params.get('figsize', None))
backend = MatBackend(
figsize=params.get('figsize', None),
box_linewidth=params.get('box_linewidth', MAT_BOX_LINEWIDTH),
wire_linewidth=params.get('wire_linewidth',
MAT_WIRE_LINEWIDTH),
)

min_size = 0.01
max_v = max([v for point in ([point.coordinates for point in
Expand Down Expand Up @@ -463,6 +482,14 @@ def _get_box_color(box: grammar.Diagrammable,
return color


def _get_wire_color(wire_id):
if wire_id == 0:
return '#000000'
else:
wire_color = WIRE_COLORS[(wire_id - 1) % len(WIRE_COLORS)]
return wire_color


def _draw_pregroup_state(backend: DrawingBackend,
drawable_box: BoxNode,
**params) -> DrawingBackend:
Expand Down Expand Up @@ -527,9 +554,15 @@ def _draw_wires(backend: DrawingBackend,
for src_idx, tgt_idx in drawable_diagram.wires:
source = drawable_diagram.wire_endpoints[src_idx]
target = drawable_diagram.wire_endpoints[tgt_idx]

backend.draw_wire(
source.coordinates, target.coordinates)
wire_color_id = 0
if params.get('color_wires'):
# Determine the color based on the type of the source
if source.kind in {WireEndpointType.INPUT}:
wire_color_id = source.noun_id
else:
wire_color_id = target.noun_id
backend.draw_wire(source.coordinates, target.coordinates,
color_id=wire_color_id, **params)

if (params.get('draw_type_labels', True) and source.kind in
{WireEndpointType.INPUT, WireEndpointType.COD}):
Expand Down
42 changes: 41 additions & 1 deletion lambeq/backend/drawing/drawing_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@
DEFAULT_MARGINS = (.05, .1)
DEFAULT_ASPECT = 'equal'

WIRE_COLORS: list[str] = [
'#9c540e', '#f4a940', '#066ee2', '#d03b2d', '#7fd68b',
'#574cfa', '#49a141', '#a629b3', '#271296', '#ff6347',
'#adff2f', '#7446f2', '#007765', '#b60539', '#ff00ff',
'#c330b9', '#73b8fd', '#ff1493', '#00bfff', '#ffb6c1',
'#740127', '#e2074c', '#0252a1', '#fea431', '#205356',
'#450d06', '#d17800', '#3831a0', '#ff4500', '#d8bfd8'
]
WIRE_COLORS_NAMES: dict[str, str] = {
'#ffffff' : '#ffffff',
'#000000' : '#000000'
}
for color in WIRE_COLORS:
WIRE_COLORS_NAMES[color] = color

FRAME_COLORS: list[str] = [
'#fbe8e7', '#fee1ba', '#fff9e5', '#e8f8ea', '#dcfbf5',
Expand Down Expand Up @@ -146,7 +160,9 @@ def draw_wire(self,
bend_out: bool = False,
bend_in: bool = False,
is_leg: bool = False,
style: str | None = None) -> None:
style: str | None = None,
color_id: int = 0,
**params) -> None:
"""
Draws a wire from source to target, possibly with a curve
Expand Down Expand Up @@ -184,6 +200,30 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:
"""

def _get_wire_color(self, wire_id : int, **params) -> str:
"""
Retrieves a color that uniquely represent a given wire ID.
Parameters
----------
wire_id : int
The noun identifier of the wire for which the color is
being retrieved.
**params:
Additional parameters.
Returns
-------
wire_color : str
The hex color of the wire, represented as a string.
"""
if not params.get('color_wires') or wire_id == 0:
return '#000000'
else:
wire_color = WIRE_COLORS[(wire_id - 1) % len(WIRE_COLORS)]
return wire_color

@abstractmethod
def output(self,
path: str | None = None,
Expand Down
Loading

0 comments on commit 9ce7774

Please sign in to comment.