How to generate model parameters based on input "x"? #6

Open · leitro opened this issue Sep 4, 2023 · 3 comments

Comments

leitro commented Sep 4, 2023

Hi Shyam! Thanks for providing the easy-to-use wrapper for hypernetworks, it's amazing!

One question: the parameters are generated by the generate_params function:

def generate_params(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
    embedding = self.embedding_module(
        torch.arange(self.num_embeddings, device=self.device)
    )
    generated_params = self.weight_generator(embedding).view(-1)
    return generated_params, {"embedding": embedding}

So the starting points are always the same learned embeddings, independent of the input, right? If I want to also use the input data x as an input to the hypernetwork, what should I do? Could you kindly outline an approach?

shyamsn97 (Owner) commented

Hey! Thanks for the interest in the library! I think what you're looking for is the DynamicHypernetwork functionality; I'll post a full example here soon.
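
Roughly, the difference from the static version above is that generate_params takes x as an argument and uses it to modulate the embeddings before the weight generator runs. A minimal sketch (inp_projection here is just a stand-in for whatever conditioning module you pick, e.g. an nn.Linear from the input size to num_embeddings):

def generate_params(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
    embedding = self.embedding_module(
        torch.arange(self.num_embeddings, device=self.device)
    )
    # modulate the learned embeddings with the input before generating weights
    # (inp_projection is a hypothetical conditioning module, not a library API)
    embedding = embedding * self.inp_projection(x).view(self.num_embeddings, 1)
    generated_params = self.weight_generator(embedding).view(-1)
    return generated_params, {"embedding": embedding}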

leitro commented Sep 16, 2023

Thanks for the reference! One question: is there a specific reason for using rnn_cell to process the input here?

hidden_state = self.rnn_cell(x, hidden_state)
indices = torch.arange(self.num_embeddings, device=self.device)
embedding = self.embedding(indices) * hidden_state.view(self.num_embeddings, 1)
return embedding, hidden_state

Is it possible to replace the RNN cell with a simple linear layer here? Any suggestions? Thanks!

shyamsn97 (Owner) commented

Hey! Yeah, you can definitely replace the RNN with a linear layer. I added the RNN cell mainly to replicate the original work (https://blog.otoro.net/2016/09/28/hyper-networks/), but you can basically put whatever you want in there. Here's an example:

from typing import Optional, Any, Tuple, Dict

import torch
import torch.nn as nn

# base (static) hypernetwork class and a helper for sizing weight chunks
from hypernn.torch import TorchHyperNetwork
from hypernn.torch.utils import get_weight_chunk_dims

class DynamicLinearHypernetwork(TorchHyperNetwork):
    def __init__(
        self,
        inp_dims: int,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
        )
        self.inp_dims = inp_dims
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim

        if weight_chunk_dim is None:
            self.weight_chunk_dim = get_weight_chunk_dims(
                self.num_target_parameters, num_embeddings
            )
    
        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()
        # project the input x to one scaling factor per embedding row
        self.inp_embedder = nn.Linear(self.inp_dims, self.num_embeddings)

    def make_embedding_module(self) -> nn.Module:
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

    def generate_params(
        self, inp: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # (num_embeddings, 1): one input-dependent scaling factor per
        # embedding row (note: this view assumes a single input, batch size 1)
        embedded_inp = self.inp_embedder(inp).view(self.num_embeddings, -1)
        # scale the learned embeddings by the input-dependent factors
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        ) * embedded_inp
        # flatten everything into a single parameter vector for the target net
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}

# usage
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

INP_DIM = 32
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = DynamicLinearHypernetwork(
    inp_dims = INP_DIM,
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)
inp = torch.zeros((1, 32))  # dummy input, batch size 1

# the same input is fed to the hypernetwork (via generate_params_kwargs)
# and then through the freshly generated target network
out = hypernetwork(inp, generate_params_kwargs=dict(inp=inp))
print(out.shape)
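
Since generate_params returns the flat parameter vector together with an aux dict (see the class above), you can also inspect the generated weights directly:

params, aux = hypernetwork.generate_params(inp)
print(params.shape)            # num_embeddings * weight_chunk_dim entries
print(aux["embedding"].shape)  # (NUM_EMBEDDINGS, EMBEDDING_DIM)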
