remove diff in base.py
fxmarty committed Oct 18, 2023
1 parent 3d45e05 commit b93c33a
Showing 2 changed files with 26 additions and 60 deletions.
84 changes: 26 additions & 58 deletions optimum/onnxruntime/base.py
@@ -150,8 +150,6 @@ def __init__(
         else:
             self.use_past = False
 
-        self.use_cache = self.parent_model.use_cache
-
         self.use_fp16 = False
         for inp in session.get_inputs():
             if "past_key_values" in inp.name and inp.type == "tensor(float16)":
@@ -198,12 +196,6 @@ def __init__(
         ):
             self.value_sequence_length_idx = -1
 
-        # Some architectures fuse the key cache and value cache.
-        if self.parent_model.config.model_type == "gpt_bigcode":
-            self.fuse_kv = True
-        else:
-            self.fuse_kv = False
-
     def prepare_inputs_for_merged(
         self,
         input_ids: Union[None, torch.LongTensor, np.ndarray],
@@ -222,7 +214,6 @@ def prepare_inputs_for_merged
             use_cache_branch = use_cache_branch.to(self.device)
 
         # Generate dummy past for the first forward if uses a merged decoder
-        # TODO: the controlflow here is ugly and should be removed in favor of class inheritance.
         if self.parent_model.use_merged and past_key_values is None:
             batch_size = input_ids.shape[0]
 
@@ -249,28 +240,13 @@
                     key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
                 )
             elif self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
-                if self.fuse_kv:
-                    shape_key_and_value = (batch_size, 1, embed_size_per_head * 2)
-                    key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)
+                shape_key_and_value = (batch_size, 1, embed_size_per_head * 2)
+                key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)
 
-                    if use_torch is True:
-                        key_and_value = key_and_value.to(self.device)
+                if use_torch is True:
+                    key_and_value = key_and_value.to(self.device)
 
-                    past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
-                else:
-                    shape_key_and_value = (batch_size, 1, embed_size_per_head)
-                    key = constructor.zeros(shape_key_and_value, dtype=dtype)
-                    value = constructor.zeros(shape_key_and_value, dtype=dtype)
-
-                    if use_torch is True:
-                        key = key.to(self.device)
-                        value = value.to(self.device)
-
-                    past_key_values = tuple(
-                        key_or_value
-                        for _ in range(len(self.key_value_input_names) // 2)
-                        for key_or_value in [key, value]
-                    )
+                past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
             else:
                 shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
                 key_or_value = constructor.zeros(shape, dtype=dtype)
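For readers skimming the hunk above: after this change, every model type in MULTI_QUERY_ATTN_MODELS gets a dummy past whose key and value are fused along the last axis (the removed fuse_kv check had only enabled this for gpt_bigcode). A minimal standalone sketch of that construction follows; dummy_fused_past, num_kv_inputs and head_dim are illustrative names, not identifiers from the repository.

    # Illustrative sketch only, not optimum code: build a zero-filled fused KV cache
    # of shape (batch_size, 1, head_dim * 2), one tensor per key/value input name.
    import torch

    def dummy_fused_past(batch_size: int, head_dim: int, num_kv_inputs: int,
                         dtype: torch.dtype = torch.float32) -> tuple:
        # Key and value are fused on the last axis, hence the factor of 2; the dummy
        # past length is 1, mirroring the shape built in the hunk above.
        key_and_value = torch.zeros((batch_size, 1, head_dim * 2), dtype=dtype)
        return tuple(key_and_value for _ in range(num_kv_inputs))

    # Example: 2 sequences, 64-dim heads, 24 fused cache inputs.
    past = dummy_fused_past(batch_size=2, head_dim=64, num_kv_inputs=24)
    assert len(past) == 24 and past[0].shape == (2, 1, 128)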
@@ -349,9 +325,7 @@ def compute_past_key_values_output_shapes_mqa(
         if past_key_values is not None and use_cache_branch is not False:
             sequence_length += past_key_values[0].size(-2)
 
-        # TODO: this controlflow should be eliminated by inheritance.
-        kv_fuse_factor = 2 if self.fuse_kv else 1
-        key_and_value_shape = (batch_size, sequence_length, embed_size_per_head * kv_fuse_factor)
+        key_and_value_shape = (batch_size, sequence_length, embed_size_per_head * 2)
 
         return {name: key_and_value_shape for name in self.key_value_output_names}
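A worked instance of the simplified shape computation above, with illustrative numbers rather than repository values: the fused key/value output for multi-query models now always carries embed_size_per_head * 2 in its last dimension.

    # Illustrative numbers only: fused KV output shape after the change.
    batch_size, embed_size_per_head = 2, 64
    past_sequence_length, new_tokens = 10, 1
    sequence_length = past_sequence_length + new_tokens   # cache grows by the tokens processed now
    key_and_value_shape = (batch_size, sequence_length, embed_size_per_head * 2)
    assert key_and_value_shape == (2, 11, 128)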

@@ -430,19 +404,16 @@ def forward(
             io_binding.synchronize_outputs()
 
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2)
-            if self.use_cache:
-                past_key_values = ()
-                for name in self.key_value_output_names:
-                    past_key_values += (output_buffers[name].view(output_shapes[name]),)
-
-                # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (self-attention key and value per decoder layer)
-                if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    num_pkv = 2
-                    past_key_values = tuple(
-                        past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
-                    )
-            else:
-                past_key_values = None
+            past_key_values = ()
+            for name in self.key_value_output_names:
+                past_key_values += (output_buffers[name].view(output_shapes[name]),)
+
+            # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (self-attention key and value per decoder layer)
+            if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                num_pkv = 2
+                past_key_values = tuple(
+                    past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
+                )
 
             logits = output_buffers["logits"].view(output_shapes["logits"])
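A small self-contained sketch of the regrouping kept above, using placeholder strings in place of the cached tensors (not repository code): the flat tuple ordered key, value per layer is folded into per-layer pairs for model types outside MULTI_QUERY_ATTN_MODELS.

    # Illustrative sketch only: fold a flat (key_0, value_0, key_1, value_1, ...) tuple
    # into one (key, value) pair per decoder layer.
    flat_past = ("k0", "v0", "k1", "v1", "k2", "v2")   # placeholders for cached tensors
    num_pkv = 2                                        # key and value per layer
    past_key_values = tuple(flat_past[i : i + num_pkv] for i in range(0, len(flat_past), num_pkv))
    assert past_key_values == (("k0", "v0"), ("k1", "v1"), ("k2", "v2"))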

@@ -497,22 +468,19 @@ def forward(
             outputs = self.session.run(None, onnx_inputs)
 
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention)
-            if self.use_cache:
-                past_key_values = tuple(
-                    torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
-                    for key in self.key_value_output_names
-                )
-
-                # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
-                # per decoder layer
-                if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    num_pkv = 2
-                    past_key_values = tuple(
-                        past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
-                    )
-            else:
-                past_key_values = None
+            past_key_values = tuple(
+                torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
+                for key in self.key_value_output_names
+            )
+
+            # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
+            # per decoder layer
+            if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                num_pkv = 2
+                past_key_values = tuple(
+                    past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
+                )
 
             logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)
 
             loss = None
2 changes: 0 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -475,8 +475,6 @@ def _from_pretrained(
             raise ValueError(
                 f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False."
             )
-        else:
-            pass
 
         regular_file_names = []
         for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]:
