-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
78 lines (64 loc) · 2.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Dict, List, Union
import torch
import aloscene
def get_mask_queries(
    frames: "aloscene.Frame",
    m_outputs: Dict,
    model: torch.nn.Module,
    matcher: torch.nn.Module = None,
    filters: List = None,
    **kwargs
):
    """Keep a subset of decoder queries per batch element and zero-pad them to a common size.

    The per-batch boolean filters come from ``filters`` when given. Otherwise they are built
    from ``matcher`` (every query index returned as a prediction source is kept) or, when no
    matcher is given either, from ``model.get_outs_filter``.

    Parameters
    ----------
    frames : aloscene.Frame
        Input frames, only forwarded to ``matcher``.
    m_outputs : Dict
        Forward output. The last entry of ``m_outputs["dec_outputs"]`` is used and indexed as
        ``(batch, queries, *features)``.
    model : torch.nn.Module
        Model exposing ``get_outs_filter``; only used when ``filters`` and ``matcher`` are None.
    matcher : torch.nn.Module, optional
        Matcher between GT and pred elements, by default None. Called with
        ``m_outputs=..., frames=..., **kwargs`` and expected to yield one ``(src, tgt)`` pair per
        batch element, where ``src`` indexes the kept queries.
    filters : List, optional
        Boolean mask over the queries for each batch element, by default None.
    **kwargs
        Forwarded to ``model.get_outs_filter`` or to ``matcher``.

    Returns
    -------
    torch.Tensor, List
        Tensor of shape ``(batch, N, *features)``, where N is the largest number of kept queries
        in the batch (shorter rows are zero-padded; dtype and device follow the decoder outputs),
        and the list of boolean masks that was applied.
    """
    dec_outputs = m_outputs["dec_outputs"][-1]
    device = dec_outputs.device
    if filters is None:
        if matcher is None:
            filters = model.get_outs_filter(m_outputs=m_outputs, **kwargs)
        else:
            nq = dec_outputs.size(1)
            filters = [torch.zeros(nq, dtype=torch.bool, device=device) for _ in range(len(dec_outputs))]
            for b, (src, _) in enumerate(matcher(m_outputs=m_outputs, frames=frames, **kwargs)):
                filters[b][src] = True

    # Number of kept queries per batch element: one reduction per mask instead of Python's
    # element-by-element sum() over a tensor.
    fsizes = [int(torch.as_tensor(f).sum()) for f in filters]
    max_size = max(fsizes)
    feat_size = dec_outputs.shape[2:]

    padded = []
    for b, (idx, fs) in enumerate(zip(filters, fsizes)):
        kept = dec_outputs[b : b + 1, idx]
        # new_zeros inherits dtype and device from the decoder outputs, so the concatenation
        # does not silently promote e.g. float16 or integer outputs to float32.
        pad = dec_outputs.new_zeros((1, max_size - fs, *feat_size))
        padded.append(torch.cat([kept, pad], dim=1))
    return torch.cat(padded, dim=0), filters
def get_base_model_frame(frames: Union[list, aloscene.Frame], cat: str = "category") -> aloscene.Frame:
    """Return a clone of ``frames`` whose dict-valued labels are narrowed to one label set.

    The labels of the ``boxes2d`` and ``segmentation`` children are replaced by their entry
    under ``cat``. This happens only when the labels of the child's first element are a dict.

    Parameters
    ----------
    frames : Union[list, aloscene.Frame]
        A frame, or a list of frames that is first batched with ``aloscene.Frame.batch_list``.
    cat : str, optional
        Key of the label set to keep, by default "category".

    Returns
    -------
    aloscene.Frame
        The cloned (and possibly batched) frames with the selected labels.
    """
    batch = aloscene.Frame.batch_list(frames) if isinstance(frames, list) else frames
    # Labels are rewritten on a clone, presumably so the caller's frames are not modified.
    batch = batch.clone()

    def _keep_selected_labels(child):
        child.labels = child.labels[cat]

    # Boxes are handled before segmentation.
    # NOTE(review): assumes both children are present and indexable (not None) — TODO confirm.
    for child_name in ("boxes2d", "segmentation"):
        if isinstance(getattr(batch, child_name)[0].labels, dict):
            batch.apply_on_child(getattr(batch, child_name), _keep_selected_labels)
    return batch