Skip to content

Commit

Permalink
[Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored Apr 5, 2024
1 parent 18de883 commit 54951ac
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions vllm/model_executor/models/olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
# this model must need this dependency
from hf_olmo import OLMoConfig
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand All @@ -62,17 +63,6 @@
from vllm.sequence import SamplerOutput


class SwiGLU(nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x

@property
def output_multiplier(self) -> float:
return 0.5


class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as
Expand Down Expand Up @@ -174,17 +164,16 @@ def __init__(
bias=False)

# Feed-forward input projection.
self.ff_proj = ColumnParallelLinear(
self.ff_proj = MergedColumnParallelLinear(
config.d_model,
self.hidden_size,
[self.hidden_size // 2] * 2,
bias=config.include_bias,
linear_method=linear_method,
)

# Activation function.
# self.act = SiluAndMul()
# self.act.output_multiplier = 0.5
self.act = SwiGLU()
self.act = SiluAndMul()
self.act.output_multiplier = 0.5
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

# Feed-forward output projection.
Expand Down Expand Up @@ -374,8 +363,12 @@ def load_weights(
if ".att" in name:
name = name.replace(".att", ".attn.att")
# mlp
if ".ff" in name and "transformer.ff_out" not in name:
name = name.replace(".ff", ".mlp.ff")
if ".ff_proj" in name:
name = name.replace(".ff_proj", ".mlp.ff_proj")
# Reverse the weight for the MergeColumnParallelLinear
loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
if ".ff_out" in name and "transformer.ff_out" not in name:
name = name.replace(".ff_out", ".mlp.ff_out")
# there is no bias in olmo
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
Expand Down

0 comments on commit 54951ac

Please sign in to comment.