How to generate model parameters based on input "x"? #6
Hi Shyam! Thanks for providing the easy-to-use wrapper for hypernetworks, it's amazing!

One question: the parameters are generated by the function `generate_params` (hyper-nn/hypernn/torch/linear_hypernet.py, lines 56 to 61 at 2765728), so the starting points are always random embeddings, right? If I want to also use the input data `x` as an input to the hypernetwork, what should I do? Could you please outline it a bit?
Comments
Hey! Thanks for the interest in the library! I think what you're looking for is the DynamicHypernetwork functionality; I'll post an example here soon.
Thanks for the reference! One question: is there any special reason to use an RNN cell here (hyper-nn/hypernn/torch/dynamic_hypernet.py, lines 40 to 43 at 2765728)? Is it possible to replace the RNN cell with a simple linear layer? Any suggestions? Thanks!
Hey! Yeah, you can definitely replace the RNN with a linear layer. I added the RNN cell mainly to replicate the original work (https://blog.otoro.net/2016/09/28/hyper-networks/), but you can put basically whatever you want in it. Here's an example:

```python
from typing import Optional, Any, Tuple, Dict

import torch
import torch.nn as nn

# static hypernetwork base class
from hypernn.torch import TorchHyperNetwork
from hypernn.torch.utils import get_weight_chunk_dims


class DynamicLinearHypernetwork(TorchHyperNetwork):
    def __init__(
        self,
        inp_dims: int,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
        )
        self.inp_dims = inp_dims
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim
        if weight_chunk_dim is None:
            # pick a chunk size so that num_embeddings chunks cover every
            # parameter of the target network
            self.weight_chunk_dim = get_weight_chunk_dims(
                self.num_target_parameters, num_embeddings
            )
        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()
        # projects the input x to one gating value per embedding
        self.inp_embedder = nn.Linear(self.inp_dims, self.num_embeddings)

    def make_embedding_module(self) -> nn.Module:
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

    def generate_params(
        self, inp: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # one gating value per embedding (batch size 1 assumed here)
        embedded_inp = self.inp_embedder(inp).view(self.num_embeddings, -1)
        # scale the learned embeddings by the input-dependent gates
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        ) * embedded_inp
        # each scaled embedding generates one chunk of the target parameters
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}


# usage
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
)

INP_DIM = 32
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = DynamicLinearHypernetwork(
    inp_dims=INP_DIM,
    target_network=target_network,
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
)

inp = torch.zeros((1, 32))
out = hypernetwork(inp, generate_params_kwargs=dict(inp=inp))
print(out.shape)  # expected: torch.Size([1, 32])
```
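In case it's useful, here's a minimal training sketch on top of the example above. The random inputs, random regression targets, MSE loss, Adam optimizer, and step count are all placeholders for illustration, not part of hyper-nn; the only library call is the same `hypernetwork(...)` forward used above:

```python
import torch
import torch.nn as nn

# The generated parameters replace the target network's weights in the
# forward pass, so gradient updates land on the hypernetwork's own modules
# (input embedder, embedding table, weight generator).
optimizer = torch.optim.Adam(hypernetwork.parameters(), lr=1e-3)

for step in range(100):
    x = torch.randn(1, INP_DIM)   # placeholder input (batch size 1)
    target = torch.randn(1, 32)   # placeholder regression target
    out = hypernetwork(x, generate_params_kwargs=dict(inp=x))
    loss = nn.functional.mse_loss(out, target)
    optimizer.zero_grad()
    loss.backward()               # gradients flow back through generate_params
    optimizer.step()
```

Batch size 1 is deliberate here: `generate_params` reshapes the embedder output with `.view(self.num_embeddings, -1)`, so larger batches would fold the batch dimension into the embedding scaling.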