From c8eb8aa3a30e365324fea4525b9adce1380f4887 Mon Sep 17 00:00:00 2001 From: jiayisunx Date: Mon, 1 Jul 2024 16:06:46 +0800 Subject: [PATCH] add the meta registration for choose_tpp_linear_weight (#3020) --- intel_extension_for_pytorch/_meta_registrations.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/intel_extension_for_pytorch/_meta_registrations.py b/intel_extension_for_pytorch/_meta_registrations.py index ed11f5913..945f60407 100644 --- a/intel_extension_for_pytorch/_meta_registrations.py +++ b/intel_extension_for_pytorch/_meta_registrations.py @@ -521,6 +521,16 @@ def meta_tpp_linear_bias( return input.new_empty((*input.shape[:-1], out_features)) +@register_meta("choose_tpp_linear_weight") +def meta_choose_tpp_linear_weight(x, weight, weight_for_large_batch): + M = x.numel() // x.size(-1) + return ( + weight_for_large_batch + if weight_for_large_batch is not None and M >= 256 + else weight + ) + + @register_meta("tpp_linear_gelu") def meta_tpp_linear_gelu( input,