Skip to content

Commit

Permalink
Improved argument parsing for shapes (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jegp authored Sep 7, 2023
1 parent b6cb9d0 commit ee0877d
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions nir/ir.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass

import numpy as np
Expand All @@ -11,6 +12,15 @@
# Shape is a dict mapping strings to shapes
Shape = typing.Dict[str, np.ndarray]

def _parse_shape_argument(x: Shape, key: str):
if isinstance(x, np.ndarray):
return {key: x}
elif isinstance(x, Sequence):
return {key: np.array(x)}
elif isinstance(x, dict):
return x
else:
raise ValueError("Unknown shape argument", x)

@dataclass
class NIRNode:
Expand Down Expand Up @@ -243,11 +253,8 @@ class Flatten(NIRNode):
end_dim: int = -1 # Last dimension to flatten

def __post_init__(self):
assert list(self.input_shape.keys()) == [
"input"
], "Flatten must have one input: `input`"
if isinstance(self.input_shape, np.ndarray):
self.input_shape = {"input": self.input_shape}
self.input_shape = _parse_shape_argument(self.input_shape, "input")
print(self.input_shape)
concat = self.input_shape["input"][self.start_dim : self.end_dim].prod()
self.output_shape = {
"output": np.array(
Expand Down Expand Up @@ -325,8 +332,7 @@ class Input(NIRNode):
input_shape: Shape

def __post_init__(self):
if isinstance(self.input_shape, np.ndarray):
self.input_shape = {"input": self.input_shape}
self.input_shape = _parse_shape_argument(self.input_shape, "input")
self.output_shape = {"output": self.input_shape["input"]}


Expand Down Expand Up @@ -431,8 +437,7 @@ class Output(NIRNode):
output_shape: Shape

def __post_init__(self):
if isinstance(self.output_shape, np.ndarray):
self.output_shape = {"output": self.output_shape}
self.output_shape = _parse_shape_argument(self.output_shape, "output")
self.input_shape = {"input": self.output_shape["output"]}


Expand Down

0 comments on commit ee0877d

Please sign in to comment.