From ee0877dbf7ed0d08331ea98c89905bc6299b0001 Mon Sep 17 00:00:00 2001 From: Jens Egholm Pedersen Date: Thu, 7 Sep 2023 19:07:45 +0200 Subject: [PATCH] Improved argument parsing for shapes (#51) --- nir/ir.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/nir/ir.py b/nir/ir.py index 09322fa..d73222c 100644 --- a/nir/ir.py +++ b/nir/ir.py @@ -1,5 +1,6 @@ import typing from collections import Counter +from collections.abc import Sequence from dataclasses import dataclass import numpy as np @@ -11,6 +12,15 @@ # Shape is a dict mapping strings to shapes Shape = typing.Dict[str, np.ndarray] +def _parse_shape_argument(x: Shape, key: str): + if isinstance(x, np.ndarray): + return {key: x} + elif isinstance(x, Sequence): + return {key: np.array(x)} + elif isinstance(x, dict): + return x + else: + raise ValueError("Unknown shape argument", x) @dataclass class NIRNode: @@ -243,11 +253,8 @@ class Flatten(NIRNode): end_dim: int = -1 # Last dimension to flatten def __post_init__(self): - assert list(self.input_shape.keys()) == [ - "input" - ], "Flatten must have one input: `input`" - if isinstance(self.input_shape, np.ndarray): - self.input_shape = {"input": self.input_shape} + self.input_shape = _parse_shape_argument(self.input_shape, "input") + print(self.input_shape) concat = self.input_shape["input"][self.start_dim : self.end_dim].prod() self.output_shape = { "output": np.array( @@ -325,8 +332,7 @@ class Input(NIRNode): input_shape: Shape def __post_init__(self): - if isinstance(self.input_shape, np.ndarray): - self.input_shape = {"input": self.input_shape} + self.input_shape = _parse_shape_argument(self.input_shape, "input") self.output_shape = {"output": self.input_shape["input"]} @@ -431,8 +437,7 @@ class Output(NIRNode): output_shape: Shape def __post_init__(self): - if isinstance(self.output_shape, np.ndarray): - self.output_shape = {"output": self.output_shape} + self.output_shape = _parse_shape_argument(self.output_shape, "output") self.input_shape = {"input": self.output_shape["output"]}