-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcheck.py
65 lines (52 loc) · 1.57 KB
/
check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import inspect
import types

import timm
import torch
# Load the Vision Transformer model: the 'vit_base_patch16_224' architecture
# with pretrained weights, reconfigured for 32x32 inputs, 4x4 patches and a
# 10-class head.
# NOTE(review): the checkpoint name suggests 16x16 patches at 224px; this
# relies on timm adapting the patch/position embeddings on load -- confirm.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    img_size=32,
    patch_size=4,
    num_classes=10,
)

# Disabled scratch experiments (uncomment to use):
#
# Run a random 1x3x32x32 batch through the model and softmax over dim 1.
# x = torch.randn(1, 3, 32, 32)
# output = model(x)
# m = torch.nn.Softmax(dim=1)
# output = m(output)
# print(output.shape)
# print(output)
#
# Count the transformer blocks.
# len_blocks = len(model.blocks)
# print(len_blocks)
def Custom_forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for idx, block in enumerate(self.blocks):
x = block(x)
x = self.norm(x)
return x
# Dummy input batch matching the 32x32 image size the model was created with.
x = torch.randn(1, 3, 32, 32)

# Patch the model so its forward_features uses the custom implementation.
# The function must be *bound* to the model instance: calling
# Custom_forward_features(x) passes the tensor as `self` (failing on
# `x.patch_embed`), and would assign a tensor rather than a method anyway.
model.forward_features = types.MethodType(Custom_forward_features, model)

# Get the (now patched) forward_features method of the model.
forward_method = model.forward_features
# inspect.getsource unwraps bound methods, so this prints the source of
# Custom_forward_features.
forward_source_code = inspect.getsource(forward_method)
print(forward_source_code)
# Disabled inspection snippets (uncomment one to use it):
#
# Print every module's name and the shapes of the parameters it owns
# directly. recurse=False keeps nested parameters from being repeated under
# every ancestor module.
# for module_name, module in model.named_modules():
#     print(f"Module: {module_name}")
#     for param_name, param in module.named_parameters(recurse=False):
#         print(f"\tParameter: {param_name}, Shape: {param.shape}")
#
# Print the shape of every parameter by its fully qualified name.
# for name, param in model.named_parameters():
#     print(f"{name}: {param.shape}")
#
# List the timm model names matching '*vit_base*'. The loop variable is not
# named `model`, so re-enabling this does not clobber the model created above.
# all_vit_base_models = timm.list_models('*vit_base*')
# for model_name in all_vit_base_models:
#     print(model_name)