remove batch from shape spec #18

Closed
wants to merge 6 commits into from
10 changes: 7 additions & 3 deletions nirtorch/to_nir.py
@@ -12,7 +12,7 @@ def extract_nir_graph(
     model_map: Callable[[nn.Module], nir.NIRNode],
     sample_data: Any,
     model_name: Optional[str] = "model",
-    ignore_submodules_of=None
+    ignore_submodules_of=None,
 ) -> nir.NIRNode:
     """Given a `model`, generate an NIR representation using the specified `model_map`.

@@ -49,7 +49,9 @@ def extract_nir_graph(

     # Convert the nodes and get indices
     nir_edges = []
-    nir_nodes = {"input": nir.Input(np.array(sample_data.shape))}
+    nir_nodes = {
+        "input": nir.Input(np.array(sample_data.shape[1:]))

Contributor:
Would it make sense to add a flag is_batched and only remove the first dimension if this flag is true? I think we would always have batched input, so leaving it like this would also be fine with me.

Collaborator Author:
In torch, the usual convention is to always have the batch dimension, so I would think it is safer to do this than to expect all other modules to add a flag indicating a batch dimension.

Contributor:
sounds good to me!

Collaborator:
I had to think about this for a bit. I don't think I understand the premise. Why do you have to modify the sample data? Can't the user just not include the batch dimension?

I'm asking because none of the PyTorch modules (linear, conv, ...) requires a batch dimension to evaluate them. Can't we just specify that whatever the user puts in, the user gets (with or without a batch dim)?

Collaborator Author:
Wasn't even aware this was possible! Alright, I have an alternative solution: perhaps we could look only at the necessary trailing dimensions and ignore the others?

Collaborator Author:
Ok, I see; we can't do that because we don't know what the dimensionality of the input or output is going to be in the first place.

@stevenabreu7, where do you suggest the is_batched flag should go?

Collaborator (@Jegp, Oct 20, 2023):
I'm not sure what problem we're addressing at the moment. Can I ask why the shape of the input isn't sufficient? Wouldn't this be solved by something like extract_nir_graph(..., data.squeeze())?

+    } # Remove the first dimension

     # Get all the NIR nodes
     for indx, node in enumerate(torch_graph.node_list):
@@ -85,7 +87,9 @@ def extract_nir_graph(
         if len(node.outgoing_nodes) == 0:
             out_name = "output"
             # Try to find shape of input to the Output node
-            output_node = nir.Output(torch_graph.module_output_types[node.elem])
+            output_node = nir.Output(
+                torch_graph.module_output_types[node.elem][1:]
+            ) # Ignore batch dimension
             nir_nodes[out_name] = output_node
             nir_edges.append((node.name, out_name))

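A minimal sketch, not part of this PR, of the two options debated in the thread above: always stripping the leading batch dimension (as sample_data.shape[1:] does in the diff) versus an explicit is_batched flag. The helper name _input_shape is hypothetical.

```python
import numpy as np


def _input_shape(sample_data, is_batched: bool = True) -> np.ndarray:
    """Shape to record in nir.Input: drop the leading batch dim when the data is batched."""
    shape = sample_data.shape
    return np.array(shape[1:] if is_batched else shape)


# With the convention adopted in this PR (inputs are always batched),
# a tensor of shape (1, 3, 4) yields an Input shape of (3, 4).
# The alternative raised by @Jegp is to leave the shape untouched and let
# callers strip the batch themselves, e.g. extract_nir_graph(..., data.squeeze()).
```
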
2 changes: 1 addition & 1 deletion tests/test_conversion.py
@@ -7,7 +7,7 @@

 def _torch_convert(module: nn.Module) -> nir.NIRNode:
     if isinstance(module, nn.Conv1d):
-        return nir.Conv1d(module.weight, 1, 1, 1, 1, module.bias)
+        return nir.Conv1d(None, module.weight, 1, 1, 1, 1, module.bias)
     elif isinstance(module, nn.Linear):
         return nir.Affine(module.weight, module.bias)
     else:
3 changes: 2 additions & 1 deletion tests/test_from_nir.py
@@ -17,7 +17,7 @@ def _torch_model_map(m: nir.NIRNode, device: str = "cpu") -> torch.nn.Module:
         lin.weight.data = torch.nn.Parameter(torch.tensor(m.weight).to(device))
         return lin
     elif isinstance(m, nir.Input) or isinstance(m, nir.Output):
-        return None
+        return torch.nn.Identity()
     else:
         raise NotImplementedError(f"Unsupported module {m}")

@@ -49,6 +49,7 @@ def test_extract_lin():
     assert torch.allclose(m(x), y)


+@pytest.mark.skip("Not yet supported")
 def test_extrac_recurrent():
     w = np.random.randn(1, 1)
     g = nir.NIRGraph(
4 changes: 2 additions & 2 deletions tests/test_graph.py
@@ -279,8 +279,8 @@ def test_output_type_when_single_node():
 def test_sequential_flatten():
     d = torch.empty(2, 3, 4)
     g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d)
-    g.nodes["input"].input_type["input"] == (2, 3, 4)
-    g.nodes["output"].output_type["output"] == (2, 3 * 4)
+    assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
+    assert tuple(g.nodes["output"].output_type["output"]) == (3 * 4, )


 @pytest.mark.skip(reason="Not supported yet")
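A hedged end-to-end sketch of the contract this PR lands on, mirroring test_sequential_flatten above: batched sample data goes in, and the extracted NIR graph records input and output types without the leading batch dimension. The expected shapes assume the behavior shown in the diff.

```python
import torch
import nir
from nirtorch.to_nir import extract_nir_graph

# Batched sample data: batch of 2, per-sample shape (3, 4).
data = torch.empty(2, 3, 4)

graph = extract_nir_graph(
    torch.nn.Flatten(1),
    lambda module: nir.Flatten(data.shape, 1),
    data,
)

# The graph's I/O types drop the batch dimension:
print(tuple(graph.nodes["input"].input_type["input"]))      # expected: (3, 4)
print(tuple(graph.nodes["output"].output_type["output"]))   # expected: (12,)
```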