From 417cf7f871b7d7a2ae5870a76059c0728edd36b9 Mon Sep 17 00:00:00 2001 From: Promisery Date: Sat, 13 Jul 2024 11:11:42 +0800 Subject: [PATCH] Initialize weights of reg_token for ViT --- timm/models/vision_transformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 9a3ac5627d..fb95aebc5c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -590,6 +590,8 @@ def init_weights(self, mode: str = '') -> None: trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) + if self.reg_token is not None: + nn.init.normal_(self.reg_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m: nn.Module) -> None: