From bac77ff8d27a0eaabf4149f9c84b4a7421a9bdee Mon Sep 17 00:00:00 2001
From: Gagan Jain
Date: Mon, 11 Nov 2024 10:16:10 -0800
Subject: [PATCH] Logging KT's key order warning only once (#2548)

Summary:
This warning is very noisy; change it to print only once instead of on every call.

Reviewed By: TroyGarden

Differential Revision: D65700079
---
 torchrec/sparse/jagged_tensor.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 4b5359f0d..243a58ee0 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -3380,14 +3380,21 @@ def _kt_unflatten(
     return KeyedTensor(context[0], context[1], values[0])
 
 
+print_flatten_spec_warn = True
+
+
 def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
     _keys, _length_per_key = spec.context
     # please read https://fburl.com/workplace/8bei5iju for more context,
     # you can also consider use short_circuit_pytree_ebc_regroup with KTRegroupAsDict
-    logger.warning(
-        "KT's key order might change from spec from the torch.export, this could have perf impact. "
-        f"{kt.keys()} vs {_keys}"
-    )
+    global print_flatten_spec_warn
+    if print_flatten_spec_warn:
+        logger.warning(
+            "KT's key order might change from spec from the torch.export, this could have perf impact. "
+            f"{kt.keys()} vs {_keys}"
+        )
+        print_flatten_spec_warn = False
+
     res = permute_multi_embedding([kt], [_keys])
     return [res[0]]
 
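
For readers unfamiliar with the technique, the change above is a standard warn-once guard: a module-level flag starts out True and is flipped to False after the first emission, so repeat calls stay silent. Below is a minimal, self-contained sketch of that pattern; flatten_with_spec, _print_key_order_warn, and the example key lists are hypothetical stand-ins, and the real permute_multi_embedding work is elided.

    import logging

    logger = logging.getLogger(__name__)

    # Module-level flag, mirroring print_flatten_spec_warn in the patch:
    # True until the warning has been emitted once, then flipped to False.
    _print_key_order_warn = True


    def flatten_with_spec(actual_keys, spec_keys):
        # Hypothetical stand-in for _kt_flatten_spec; only the warn-once
        # guard is shown, the actual flatten/permute work is omitted.
        global _print_key_order_warn
        if _print_key_order_warn:
            logger.warning(
                "key order might change from the exported spec, "
                f"this could have a perf impact: {actual_keys} vs {spec_keys}"
            )
            _print_key_order_warn = False


    if __name__ == "__main__":
        logging.basicConfig(level=logging.WARNING)
        flatten_with_spec(["f1", "f2"], ["f2", "f1"])  # logs the warning once
        flatten_with_spec(["f1", "f2"], ["f2", "f1"])  # silent on repeat calls

An alternative with similar behavior is the standard warnings module, whose default filter deduplicates warnings.warn per call site; the patch keeps the existing logger-based style instead, at the cost of a module-level global. Note the flag is not updated atomically, so two threads racing through the guard could each log once, which is harmless for a deduplication of this kind.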