Running inference on mobile #20

Open
daniel-sudz opened this issue Jul 11, 2023 · 4 comments · May be fixed by #22

daniel-sudz commented Jul 11, 2023

I was wondering if there would be any technical challenges that would make running inference on mobile impossible. The models themselves seem small but I'm not familiar enough to know if there are any other challenges/considerations.

Update 1

Currently I am stuck on the following error when trying to convert the model to TorchScript. It seems easy enough to fix, but there will probably be more things that don't work afterwards as well.

```
Module 'Head' has no attribute 'res_blocks' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type. Could not infer type of list element: Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.):
  File "/home/powerhorse/Desktop/daniel_tmp/benchmark/anchor/third_party/ace/ace_network.py", line 128
        res = self.head_skip(res) + x

        for res_block in self.res_blocks:
                         ~~~~~~~~~~~~~~~ <--- HERE
            x = F.relu(res_block[0](res))
            x = F.relu(res_block[1](x))
```
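A minimal repro of the failure (not the actual ACE `Head` class; the channel sizes here are made up): TorchScript cannot infer a type for a plain Python list of submodules, whereas the same layers held in an `nn.ModuleList` script fine.

```python
import torch
import torch.nn as nn

class PlainListHead(nn.Module):
    """Holds submodules in a plain Python list -- not scriptable."""
    def __init__(self):
        super().__init__()
        self.res_blocks = [nn.Conv2d(4, 4, 1), nn.Conv2d(4, 4, 1)]

    def forward(self, x):
        for block in self.res_blocks:
            x = block(x)
        return x

class ModuleListHead(nn.Module):
    """Same layers held in an nn.ModuleList -- scriptable."""
    def __init__(self):
        super().__init__()
        self.res_blocks = nn.ModuleList([nn.Conv2d(4, 4, 1), nn.Conv2d(4, 4, 1)])

    def forward(self, x):
        for block in self.res_blocks:
            x = block(x)
        return x

try:
    torch.jit.script(PlainListHead())
    print("plain list: scripted")
except Exception:
    print("plain list: failed to script")

scripted = torch.jit.script(ModuleListHead())
print("ModuleList: scripted")
```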
daniel-sudz commented Jul 12, 2023

It looks like this issue is a known problem with making the model scriptable: pytorch/pytorch#36061
https://github.com/nianticlabs/ace/blob/main/ace_network.py#L91

One alternative would be to do something like this:

```python
self.res_blocks = nn.ModuleList(
    nn.ModuleList([
        nn.Conv2d(self.head_channels, self.head_channels, 1, 1, 0),
        nn.Conv2d(self.head_channels, self.head_channels, 1, 1, 0),
        nn.Conv2d(self.head_channels, self.head_channels, 1, 1, 0),
    ]) for _ in range(num_head_blocks)
)
```

Though this would end up breaking this import:

```python
pattern = re.compile(r"^heads\.\d+c0\.weight$")
```
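To see why the import breaks, here is a sketch of how the `nn.ModuleList` version names its `state_dict` keys (the channel count and block count are made up; the exact key layout in `ace_network.py` may differ):

```python
import re
import torch.nn as nn

head_channels = 4
num_head_blocks = 2

# Proposed nn.ModuleList layout, attached to a bare container module.
new_head = nn.Module()
new_head.res_blocks = nn.ModuleList(
    nn.ModuleList([
        nn.Conv2d(head_channels, head_channels, 1, 1, 0),
        nn.Conv2d(head_channels, head_channels, 1, 1, 0),
        nn.Conv2d(head_channels, head_channels, 1, 1, 0),
    ]) for _ in range(num_head_blocks)
)

# The weight-loading regex from the repo expects keys like "heads.0c0.weight".
pattern = re.compile(r"^heads\.\d+c0\.weight$")

new_keys = list(new_head.state_dict().keys())
print(new_keys[:3])
# ModuleList keys are index-based ("res_blocks.0.0.weight", ...), so none
# of them match the old pattern:
print(any(pattern.match(k) for k in new_keys))  # False
```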

daniel-sudz commented Jul 12, 2023

@tcavallari I would really love to get inference running on mobile if it's possible, but it appears the model would need to be retrained, per the issue below (I'm assuming there are currently no plans to release the training code for the pretrained weights):

pytorch/pytorch#36061

> To make these projects scriptable, we must replace list of modules with nn.ModuleList which will make the pretrained models invalid, since the replacement will change the keys of state_dict, and may increase the number of model parameters.

tcavallari (Collaborator) commented:

Hello!

> To make these projects scriptable, we must replace list of modules with nn.ModuleList which will make the pretrained models invalid, since the replacement will change the keys of state_dict, and may increase the number of model parameters.

I don't think it's necessary to retrain the encoder. You can create a scriptable architecture using nn.ModuleList and insert the same convolution layers (keeping the same parameters, i.e. kernel size, stride, padding, etc.) inside it. The number of model parameters shouldn't change at all.

Then it should just be a matter of replacing the weights in the new state_dict with the weights in the pretrained one we provide with the repo. You just have to figure out which keys in the new dict map to which keys in the old dictionary, but it should be relatively easy.
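A sketch of that key-remapping idea. The mapping here is hypothetical: the real old and new key names depend on how `ace_network.py` registers its layers, so the dictionary below uses made-up keys just to show the mechanics.

```python
import torch

def remap_state_dict(old_state_dict, key_map):
    """Copy tensors into new keys without touching the values themselves."""
    return {key_map.get(k, k): v for k, v in old_state_dict.items()}

# Made-up example: the old flat key maps to the nested ModuleList key.
old_sd = {"heads.0c0.weight": torch.zeros(4, 4, 1, 1)}
key_map = {"heads.0c0.weight": "res_blocks.0.0.weight"}

new_sd = remap_state_dict(old_sd, key_map)
print(sorted(new_sd.keys()))  # ['res_blocks.0.0.weight']
```

The remapped dict would then be loaded into the new architecture with `new_model.load_state_dict(new_sd)`; since only the keys change, the parameter tensors (and counts) are identical.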

daniel-sudz commented:

Ok awesome, I think I managed to figure it out. I believe this change should be completely seamless in terms of backwards compatibility, so do you think we could merge it into main? I made a PR here in #22. Thank you so much for the guidance!
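For reference, once the head is scriptable, exporting it for mobile is roughly the following (sketch only; `TinyHead` is a stand-in for the fixed ACE head, and any scriptable `nn.Module` works the same way):

```python
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class TinyHead(nn.Module):
    """Stand-in for the scriptable head; made-up channel sizes."""
    def __init__(self):
        super().__init__()
        self.res_blocks = nn.ModuleList([nn.Conv2d(4, 4, 1), nn.Conv2d(4, 4, 1)])

    def forward(self, x):
        for block in self.res_blocks:
            x = torch.relu(block(x))
        return x

model = TinyHead().eval()
scripted = torch.jit.script(model)          # requires the ModuleList fix
mobile_model = optimize_for_mobile(scripted)
# Saves a .ptl file loadable by the PyTorch Mobile lite interpreter.
mobile_model._save_for_lite_interpreter("head.ptl")
print("saved")
```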
