Skip to content

Commit

Permalink
BUG FIX
Browse files Browse the repository at this point in the history
  • Loading branch information
FantasticGNU committed Sep 1, 2023
1 parent 59fab40 commit ac4164b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
8 changes: 6 additions & 2 deletions code/datasets/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __getitem__(self, index):

conversation_abnormal = conversation_normal

return origin, conversation_normal, x, conversation_abnormal, class_name, mask
return origin, conversation_normal, x, conversation_abnormal, class_name, mask, img_path



Expand All @@ -201,20 +201,24 @@ def collate(self, instances):
texts = []
class_names = []
masks = []
img_paths = []
for instance in instances:
images.append(instance[0])
texts.append(instance[1])
class_names.append(instance[4])
masks.append(torch.zeros_like(instance[5]))
img_paths.append(instance[6])

images.append(instance[2])
texts.append(instance[3])
class_names.append(instance[4])
masks.append(instance[5])
img_paths.append(instance[6])

return dict(
images=images,
texts=texts,
class_names=class_names,
masks=masks
masks=masks,
img_paths=img_paths
)
8 changes: 6 additions & 2 deletions code/datasets/visa.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __getitem__(self, index):
conversation_abnormal = conversation_normal


return origin, conversation_normal, x, conversation_abnormal, class_name, mask
return origin, conversation_normal, x, conversation_abnormal, class_name, mask, img_path



Expand All @@ -182,21 +182,25 @@ def collate(self, instances):
texts = []
class_names = []
masks = []
img_paths = []
for instance in instances:
images.append(instance[0])
texts.append(instance[1])
class_names.append(instance[4])
masks.append(torch.zeros_like(instance[5]))
img_paths.append(instance[6])

images.append(instance[2])
texts.append(instance[3])
class_names.append(instance[4])
masks.append(instance[5])
img_paths.append(instance[6])


return dict(
images=images,
texts=texts,
class_names=class_names,
masks=masks
masks=masks,
img_paths=img_paths
)
16 changes: 8 additions & 8 deletions code/model/openllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,22 +443,18 @@ def forward(self, inputs):
anomaly_maps = []
for layer in range(len(patch_tokens)):
patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
# print(patch_tokens[layer].shape)
# anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2,-1))
B, L, C = anomaly_map.shape
H = int(np.sqrt(L))
anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
size=224, mode='bilinear', align_corners=True)
# anomaly_map_no_softmax = anomaly_map

anomaly_map = torch.softmax(anomaly_map, dim=1)
anomaly_maps.append(anomaly_map)
# anomaly_maps_ns.append(anomaly_map_no_softmax)
anomaly_maps.append(anomaly_map)

gt = inputs['masks']
gt = torch.stack(gt, dim=0).to(self.device)
gt = gt.squeeze()
# print(gt.max(), gt.min())
gt[gt > 0.3], gt[gt <= 0.3] = 1, 0


Expand All @@ -476,8 +472,12 @@ def forward(self, inputs):

normal_paths = []
for path in inputs['img_paths']:
normal_path = path.replace('test', 'train')
normal_path = find_first_file_in_directory("/".join(normal_path.split('/')[:-2])+'/good')
if 'visa' in image_paths.lower():
normal_path = path.replace('Anomaly', 'Normal')
normal_path = find_first_file_in_directory("/".join(normal_path.split('/')[:-1]))
else:
normal_path = path.replace('test', 'train')
normal_path = find_first_file_in_directory("/".join(normal_path.split('/')[:-2])+'/good')
normal_paths.append(normal_path)

print(normal_paths)
Expand Down

0 comments on commit ac4164b

Please sign in to comment.