diff --git a/zeta/nn/attention/multi_modal_causal_attention.py b/zeta/nn/attention/multi_modal_causal_attention.py
new file mode 100644
index 00000000..1be2e00d
--- /dev/null
+++ b/zeta/nn/attention/multi_modal_causal_attention.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+
+class MultiModalCausalAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads=8,
+        dropout=0.1,
+    ):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim**-0.5
+
+        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))
+
+    def forward(self, visual_features, textual_features, mask=None):
+        b, n, _, h = *visual_features.shape, self.heads
+
+        qkv_visual = self.to_qkv(visual_features).chunk(3, dim=-1)
+        qkv_textual = self.to_qkv(textual_features).chunk(3, dim=-1)
+
+        q_visual, k_visual, v_visual = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv_visual
+        )
+
+        q_textual, k_textual, v_textual = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv_textual
+        )
+
+        dots_visual = torch.einsum("bhid,bhjd->bhij", q_visual, k_visual) * self.scale
+
+        dots_textual = (
+            torch.einsum(
+                "bhid,bhjd->bhij",
+                q_textual,
+                k_textual,
+            )
+            * self.scale
+        )
+
+        if mask is not None:
+            mask = F.pad(mask.flatten(1), (1, 0), value=True)
+            assert (
+                mask.shape[-1] == dots_textual.shape[-1]
+            ), "mask has incorrect dimensions"
+
+            mask = mask[:, None, :] * mask[:, :, None]
+            dots_textual = dots_textual.masked_fill(~mask[:, None, :, :], float("-inf"))
+
+            del mask
+
+        attn_visual = dots_visual.softmax(dim=-1)
+        attn_textual = dots_textual.softmax(dim=-1)
+
+        out_visual = torch.einsum(
+            "bhij,bhjd->bhid",
+            attn_visual,
+            v_visual,
+        )
+
+        out_textual = torch.einsum(
+            "bhij,bhjd->bhid",
+            attn_textual,
+            v_textual,
+        )
+
+        out_visual = rearrange(out_visual, "b h n d -> b n (h d)")
+
+        out_textual = rearrange(out_textual, "b h n d -> b n (h d)")
+
+        return self.to_out(out_visual), self.to_out(out_textual)
+
+
+class SimpleMMCA(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads,
+    ):
+        super().__init__()
+
+        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
+
+        self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
+
+    def forward(self, v, t):
+        # self attention for visual tokens
+        v = self.self_attn(v, v, v)[0]
+
+        # textual tokens: self attention plus cross attention to the visual tokens
+        t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]
+
+        return t
diff --git a/zeta/rl/__init__.py b/zeta/rl/__init__.py
index 1cc5ba16..07276d02 100644
--- a/zeta/rl/__init__.py
+++ b/zeta/rl/__init__.py
@@ -1,3 +1,3 @@
 from zeta.rl.reward_model import *
 from zeta.rl.reward_model import RewardModel
-from zeta.rl.actor_critic import ActorCritic, ppo
\ No newline at end of file
+from zeta.rl.actor_critic import ActorCritic, ppo
diff --git a/zeta/rl/actor_critic.py b/zeta/rl/actor_critic.py
index 1cd49f16..78a9fb98 100644
--- a/zeta/rl/actor_critic.py
+++ b/zeta/rl/actor_critic.py
@@ -2,18 +2,12 @@
 from torch import nn
 import torch.optim as optim
 
+
 class ActorCritic(nn.Module):
-    def __init__(
-        self,
-        num_inputs,
-        num_outputs,
-        hidden_size
-    ):
+    def __init__(self, num_inputs, num_outputs, hidden_size):
         super(ActorCritic, self).__init__()
         self.critic = nn.Sequential(
-            nn.Linear(num_inputs, hidden_size),
-            nn.ReLU(),
-            nn.Linear(hidden_size, 1)
+            nn.Linear(num_inputs, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
         )
         self.actor = nn.Sequential(
             nn.Linear(num_inputs, hidden_size),
@@ -21,13 +15,14 @@ def __init__(
             nn.Linear(hidden_size, num_outputs),
             nn.Softmax(dim=1),
         )
-
+
     def forward(self, x):
         value = self.critic(x)
         probs = self.actor(x)
         dist = torch.distributions.Categorical(probs)
         return dist, value
-
+
+
 def ppo(
     policy_net,
     value_net,
@@ -37,7 +32,7 @@ def ppo(
     actions,
     returns,
     advantages,
-    clip_param=0.2
+    clip_param=0.2,
 ):
     dist, _ = policy_net(states)
     old_probs = dist.log_prob(actions).detach()
@@ -61,8 +56,6 @@ def ppo(
 
     optimizer_policy.step()
 
-
-
 # import torch
 # import numpy as np
 
@@ -97,4 +90,4 @@
 
 # # This ratio is used to compute the policy loss, which is then used to update the policy network.
 # # The policy loss is computed in a way that encourages the new action probabilities to stay close to the old ones,
-# # which is the key idea behind PPO's objective of taking conservative policy updates.
\ No newline at end of file
+# # which is the key idea behind PPO's objective of taking conservative policy updates.
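For a quick smoke test of the new attention modules, a minimal sketch is below; `dim=512`, `heads=8`, and the tensor shapes are illustrative assumptions, not values required by the patch:

```python
import torch

from zeta.nn.attention.multi_modal_causal_attention import (
    MultiModalCausalAttention,
    SimpleMMCA,
)

# Dummy visual and textual token sequences, shaped (batch, seq_len, dim).
visual = torch.randn(2, 16, 512)
textual = torch.randn(2, 16, 512)

# Shared-QKV attention applied per modality; returns one tensor per modality.
mmca = MultiModalCausalAttention(dim=512, heads=8)
out_visual, out_textual = mmca(visual, textual)
print(out_visual.shape, out_textual.shape)  # torch.Size([2, 16, 512]) each

# SimpleMMCA wraps nn.MultiheadAttention, which defaults to
# (seq_len, batch, dim) inputs, so sequence-first tensors are used here.
v = torch.randn(16, 2, 512)
t = torch.randn(16, 2, 512)
simple = SimpleMMCA(dim=512, heads=8)
out = simple(v, t)
print(out.shape)  # torch.Size([16, 2, 512])
```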
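Similarly, a hedged sketch of driving the reformatted `ActorCritic`; the observation size, action count, hidden size, and batch size are made up for illustration, and `ppo` is not exercised since its full parameter list is only partially visible in the hunks:

```python
import torch

from zeta.rl.actor_critic import ActorCritic

# Illustrative sizes: 4-dimensional observations, 2 discrete actions.
model = ActorCritic(num_inputs=4, num_outputs=2, hidden_size=64)

states = torch.randn(8, 4)   # batch of 8 observations
dist, value = model(states)  # Categorical policy distribution and state value

actions = dist.sample()
log_probs = dist.log_prob(actions)
print(actions.shape, log_probs.shape, value.shape)
# torch.Size([8]) torch.Size([8]) torch.Size([8, 1])
```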