mmca
Kye committed Oct 11, 2023
1 parent 1fe6412 commit c82a810
Showing 3 changed files with 108 additions and 16 deletions.
99 changes: 99 additions & 0 deletions zeta/nn/attention/multi_modal_causal_attention.py
@@ -0,0 +1,99 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn


class MultiModalCausalAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads=8,
        dropout=0.1,
    ):
        super().__init__()
        self.heads = heads
        # 1/sqrt(d_head) scaling for the attention scores
        self.scale = (dim // heads) ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def forward(self, visual_features, textual_features, mask=None):
        # both inputs are expected as (batch, tokens, dim)
        b, n, _, h = *visual_features.shape, self.heads

        qkv_visual = self.to_qkv(visual_features).chunk(3, dim=-1)
        qkv_textual = self.to_qkv(textual_features).chunk(3, dim=-1)

        q_visual, k_visual, v_visual = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv_visual
        )

        q_textual, k_textual, v_textual = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv_textual
        )

        dots_visual = torch.einsum("bhid,bhjd->bhij", q_visual, k_visual) * self.scale

        dots_textual = (
            torch.einsum(
                "bhid,bhjd->bhij",
                q_textual,
                k_textual,
            )
            * self.scale
        )

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert (
                mask.shape[-1] == dots_textual.shape[-1]
            ), "mask has incorrect dimensions"

            # build a (batch, n, n) pairwise mask, broadcast it across heads,
            # and keep the returned tensor (masked_fill is not in-place)
            mask = mask[:, None, :] * mask[:, :, None]
            dots_textual = dots_textual.masked_fill(~mask.unsqueeze(1), float("-inf"))

            del mask

        attn_visual = dots_visual.softmax(dim=-1)
        attn_textual = dots_textual.softmax(dim=-1)

        out_visual = torch.einsum(
            "bhij,bhjd->bhid",
            attn_visual,
            v_visual,
        )

        out_textual = torch.einsum(
            "bhij,bhjd->bhid",
            attn_textual,
            v_textual,
        )

        out_visual = rearrange(out_visual, "b h n d -> b n (h d)")

        out_textual = rearrange(out_textual, "b h n d -> b n (h d)")

        return self.to_out(out_visual), self.to_out(out_textual)


class SimpleMMCA(nn.Module):
    def __init__(
        self,
        dim,
        heads,
    ):
        super().__init__()

        # nn.MultiheadAttention defaults to (seq, batch, dim) inputs
        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)

        self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)

    def forward(self, v, t):
        # self-attention over the visual tokens
        v = self.self_attn(v, v, v)[0]

        # textual tokens: self-attention plus cross-attention to the visual tokens
        t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]

        return t
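The snippet below is a minimal usage sketch for the two modules above and is not part of the commit: the batch size, token counts, and model width are illustrative assumptions, and the tensors passed to SimpleMMCA are transposed because nn.MultiheadAttention defaults to (sequence, batch, dim) inputs.

import torch

from zeta.nn.attention.multi_modal_causal_attention import (
    MultiModalCausalAttention,
    SimpleMMCA,
)

dim, heads = 512, 8
visual = torch.randn(2, 196, dim)   # (batch, visual tokens, dim), assumed shapes
textual = torch.randn(2, 64, dim)   # (batch, textual tokens, dim), assumed shapes

mmca = MultiModalCausalAttention(dim=dim, heads=heads)
out_visual, out_textual = mmca(visual, textual)  # each stays (batch, tokens, dim)

simple = SimpleMMCA(dim=dim, heads=heads)
# transpose to (tokens, batch, dim) for the default nn.MultiheadAttention layout
fused_text = simple(visual.transpose(0, 1), textual.transpose(0, 1))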
2 changes: 1 addition & 1 deletion zeta/rl/__init__.py
@@ -1,3 +1,3 @@
-from zeta.rl.reward_model import *
+from zeta.rl.reward_model import RewardModel
from zeta.rl.actor_critic import ActorCritic, ppo
from zeta.rl.actor_critic import ActorCritic, ppo
23 changes: 8 additions & 15 deletions zeta/rl/actor_critic.py
@@ -2,32 +2,27 @@
from torch import nn
import torch.optim as optim


class ActorCritic(nn.Module):
-    def __init__(
-        self,
-        num_inputs,
-        num_outputs,
-        hidden_size
-    ):
+    def __init__(self, num_inputs, num_outputs, hidden_size):
        super(ActorCritic, self).__init__()
        self.critic = nn.Sequential(
-            nn.Linear(num_inputs, hidden_size),
-            nn.ReLU(),
-            nn.Linear(hidden_size, 1)
+            nn.Linear(num_inputs, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )
        self.actor = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_outputs),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        value = self.critic(x)
        probs = self.actor(x)
        dist = torch.distributions.Categorical(probs)
        return dist, value



def ppo(
    policy_net,
    value_net,
@@ -37,7 +32,7 @@ def ppo(
    actions,
    returns,
    advantages,
-    clip_param=0.2
+    clip_param=0.2,
):
    dist, _ = policy_net(states)
    old_probs = dist.log_prob(actions).detach()
@@ -61,8 +56,6 @@ def ppo(
    optimizer_policy.step()




# import torch
# import numpy as np

@@ -97,4 +90,4 @@ def ppo(
# # This ratio is used to compute the policy loss, which is then used to update the policy network.

# # The policy loss is computed in a way that encourages the new action probabilities to stay close to the old ones,
-# # which is the key idea behind PPO's objective of taking conservative policy updates.
+# # which is the key idea behind PPO's objective of taking conservative policy updates.
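Most of the ppo() body is collapsed in this diff, so the following is only a hedged sketch of the clipped surrogate objective that the comments above describe, written against the ActorCritic network defined earlier; the state and action dimensions, batch size, and clip value are assumptions.

import torch

net = ActorCritic(num_inputs=4, num_outputs=2, hidden_size=64)

states = torch.randn(8, 4)           # assumed batch of 8 four-dimensional states
actions = torch.randint(0, 2, (8,))  # actions taken under the old policy
advantages = torch.randn(8)          # precomputed advantage estimates

dist, value = net(states)
old_log_probs = dist.log_prob(actions).detach()  # frozen "old" policy log-probs
new_log_probs = dist.log_prob(actions)

# Ratio between the new and old action probabilities.
ratio = torch.exp(new_log_probs - old_log_probs)

# Clipped surrogate objective: taking the minimum keeps updates conservative,
# which is the key idea behind PPO.
clip_param = 0.2
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
policy_loss = -torch.min(surr1, surr2).mean()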
