Migrate Codebase to Latest PyTorch Lightning and NumPy #29

Open: wants to merge 5 commits into main
8 changes: 8 additions & 0 deletions requirements.txt
@@ -0,0 +1,8 @@
pyquaternion==0.9.9
nuscenes-devkit==1.1.11
opencv-python==4.8.0.74
pytorch-lightning==2.4.0
fvcore==0.1.5.post20221221
efficientnet_pytorch==0.7.1
timm==1.0.8
scikit_image==0.24.0
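Since stp3/metrics.py below imports from torchmetrics directly, that package could also be pinned here explicitly instead of relying on the version pulled in transitively by pytorch-lightning; a possible line, with the version bound being an assumption rather than something this PR specifies:

torchmetrics>=1.0.0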
157 changes: 93 additions & 64 deletions stp3/metrics.py
@@ -3,9 +3,9 @@
import torch
import torch.nn as nn
import numpy as np
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
from pytorch_lightning.metrics.functional.reduction import reduce
from torchmetrics import Metric
from torchmetrics.functional import stat_scores
from torchmetrics.utilities import reduce
from skimage.draw import polygon

from stp3.utils.tools import gen_dx_bx
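The stat_scores import above replaces the old stat_scores_multiple_classes, but its calling convention differs in recent torchmetrics releases: an explicit task argument is required and the result is a single stacked tensor rather than separate per-statistic outputs. A minimal sketch of that convention, assuming a torchmetrics 1.x install (the exact version shipped alongside pytorch-lightning 2.4 is an assumption):

import torch
from torchmetrics.functional import stat_scores

prediction = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
# With task='multiclass' and average=None the result has shape (num_classes, 5),
# with columns [tp, fp, tn, fn, support].
stats = stat_scores(preds=prediction, target=target, task='multiclass',
                    num_classes=3, average=None)
tps, fps, _, fns, sups = stats.unbind(dim=1)  # per-class tensors of shape (3,)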
@@ -19,30 +19,38 @@ def __init__(
n_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: str = 'none',
compute_on_step: bool = False,
reduction: str = 'none'
):
super().__init__(compute_on_step=compute_on_step)
super().__init__()

self.n_classes = n_classes
self.ignore_index = ignore_index
self.absent_score = absent_score
self.reduction = reduction

# Initialize states for the metric computation
self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum')

def update(self, prediction: torch.Tensor, target: torch.Tensor):
tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes)
# Calculate statistics for each class
tps, fps, _, fns, sups = stat_scores(
preds=prediction,
target=target,
average=None,
num_classes=self.n_classes
)

# Update state variables
self.true_positive += tps
self.false_positive += fps
self.false_negative += fns
self.support += sups

def compute(self):
# Initialize scores tensor
scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32)

for class_idx in range(self.n_classes):
@@ -54,20 +62,21 @@ def compute(self):
fn = self.false_negative[class_idx]
sup = self.support[class_idx]

# If this class is absent in the target (no support) AND absent in the pred (no true or false
# positives), then use the absent_score for this class.
# Assign absent_score if the class is absent in both target and prediction
if sup + tp + fp == 0:
scores[class_idx] = self.absent_score
continue

# Calculate IoU score
denominator = tp + fp + fn
score = tp.to(torch.float) / denominator
scores[class_idx] = score

# Remove the ignored class index from the scores.
# Exclude the ignored class index from scores
if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes):
scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index+1:]])
scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index + 1:]])

# Reduce scores according to the specified reduction method
return reduce(scores, reduction=self.reduction)
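For reference, a minimal usage sketch of the migrated IoU metric as it would be driven from an evaluation loop (tensor shapes and values are illustrative assumptions, not taken from the repo):

import torch

iou_metric = IntersectionOverUnion(n_classes=2, reduction='none')
pred = torch.randint(0, 2, (4, 200, 200))   # predicted class index per pixel
gt = torch.randint(0, 2, (4, 200, 200))     # ground-truth class index per pixel
iou_metric.update(pred, gt)
print(iou_metric.compute())                 # per-class IoU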


@@ -76,22 +85,22 @@ def __init__(
self,
n_classes: int,
temporally_consistent: bool = True,
vehicles_id: int = 1,
compute_on_step: bool = False,
vehicles_id: int = 1
):
super().__init__(compute_on_step=compute_on_step)
super().__init__()

self.n_classes = n_classes
self.temporally_consistent = temporally_consistent
self.vehicles_id = vehicles_id
self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative']

# Initialize states for the metric computation
self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum')
self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')

def update(self, pred_instance, gt_instance):
def update(self, pred_instance: torch.Tensor, gt_instance: torch.Tensor):
"""
Update state with predictions and targets.

@@ -133,12 +142,7 @@ def compute(self):
sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive))
rq = self.true_positive / denominator

return {'pq': pq,
'sq': sq,
'rq': rq,
# If 0, it means there wasn't any detection.
# 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
}
return {'pq': pq, 'sq': sq, 'rq': rq}
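For readers checking these values, they follow the standard panoptic quality decomposition (stated here for clarity; the formula is not part of the diff):

\mathrm{PQ} = \mathrm{SQ} \times \mathrm{RQ}, \qquad
\mathrm{SQ} = \frac{\sum_{(p, g) \in \mathit{TP}} \mathrm{IoU}(p, g)}{|\mathit{TP}|}, \qquad
\mathrm{RQ} = \frac{|\mathit{TP}|}{|\mathit{TP}| + \tfrac{1}{2}|\mathit{FP}| + \tfrac{1}{2}|\mathit{FN}|}

which matches the code above: sq is the accumulated IoU of matched segments divided by the true-positive count, rq is the true-positive count divided by the denominator tp + fp/2 + fn/2, and PQ is their product under the standard definition.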

def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping):
"""
@@ -163,44 +167,31 @@ def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt
n_all_things = n_instances + n_classes # Classes + instances.
n_things_and_void = n_all_things + 1

# Now 1 is background; 0 is void (not used). 2 is vehicle semantic class but since it overlaps with
# instances, it is not present.
# and the rest are instance ids starting from 3
prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things)
target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things)

# Compute ious between all stuff and things
# hack for bincounting 2 arrays together
x = prediction + n_things_and_void * target
bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2)
if bincount_2d.shape[0] != n_things_and_void ** 2:
raise ValueError('Incorrect bincount size.')
conf = bincount_2d.reshape((n_things_and_void, n_things_and_void))
# Drop void class
conf = conf[1:, 1:]

# Confusion matrix contains intersections between all combinations of classes
union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf
iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float())

# In the iou matrix, first dimension is target idx, second dimension is pred idx.
# Mapping will contain a tuple that maps prediction idx to target idx for segments matched by iou.
mapping = (iou > 0.5).nonzero(as_tuple=False)

# Check that classes match.
is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]]
mapping = mapping[is_matching]
tp_mask = torch.zeros_like(conf, dtype=torch.bool)
tp_mask[mapping[:, 0], mapping[:, 1]] = True

# First ids correspond to "stuff" i.e. semantic seg.
# Instance ids are offset accordingly
for target_id, pred_id in mapping:
cls_id = pred_to_cls[pred_id]

if self.temporally_consistent and cls_id == self.vehicles_id:
if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item():
# Not temporally consistent
result['false_negative'][target_to_cls[target_id]] += 1
result['false_positive'][pred_to_cls[pred_id]] += 1
unique_id_mapping[target_id.item()] = pred_id.item()
@@ -211,18 +202,14 @@ def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt
unique_id_mapping[target_id.item()] = pred_id.item()

for target_id in range(n_classes, n_all_things):
# If this is a true positive do nothing.
if tp_mask[target_id, n_classes:].any():
continue
# If this target instance didn't match with any predictions and was present set it as false negative.
if target_to_cls[target_id] != -1:
result['false_negative'][target_to_cls[target_id]] += 1

for pred_id in range(n_classes, n_all_things):
# If this is a true positive do nothing.
if tp_mask[n_classes:, pred_id].any():
continue
# If this predicted instance didn't match with any prediction, set that predictions as false positive.
if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any():
result['false_positive'][pred_to_cls[pred_id]] += 1
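The matching above builds the full (target, prediction) confusion matrix with a single bincount over a combined index rather than a nested loop over segment ids; a small self-contained sketch of that trick, with illustrative sizes:

import torch

n_ids = 4                                    # segment ids, including void at index 0
target = torch.tensor([1, 1, 2, 3, 3])       # ground-truth id per pixel
prediction = torch.tensor([1, 2, 2, 3, 1])   # predicted id per pixel

# Encode each (target, prediction) pair as one integer, bincount once, then reshape.
x = prediction + n_ids * target
conf = torch.bincount(x, minlength=n_ids ** 2).reshape(n_ids, n_ids)
# conf[i, j] counts pixels with ground-truth id i and predicted id j.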

@@ -238,9 +225,8 @@ def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_cla
instance = instance - 1 + n_classes

segmentation = segmentation.clone().view(-1)
segmentation_mask = segmentation < n_classes # Remove void pixels.
segmentation_mask = segmentation < n_classes

# Build an index from instance id to class id.
instance_id_to_class_tuples = torch.cat(
(
instance[instance_mask & segmentation_mask].unsqueeze(1),
@@ -255,69 +241,89 @@ def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_cla
)

segmentation[instance_mask] = instance[instance_mask]
segmentation += 1 # Shift all legit classes by 1.
segmentation[~segmentation_mask] = 0 # Shift void class to zero.
segmentation += 1
segmentation[~segmentation_mask] = 0

return segmentation, instance_id_to_class


class PlanningMetric(Metric):
def __init__(
self,
cfg,
n_future=4,
compute_on_step: bool = False,
n_future=4
):
super().__init__(compute_on_step=compute_on_step)
super().__init__()

# Generate grid dx, bx parameters
dx, bx, _ = gen_dx_bx(cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND)
dx, bx = dx[:2], bx[:2]

# Set parameters as nn.Parameter to keep them immutable during training
self.dx = nn.Parameter(dx, requires_grad=False)
self.bx = nn.Parameter(bx, requires_grad=False)

# Calculate bird's eye view dimensions
_, _, self.bev_dimension = calculate_birds_eye_view_parameters(
cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND
)
self.bev_dimension = self.bev_dimension.numpy()

# Ego vehicle dimensions
self.W = cfg.EGO.WIDTH
self.H = cfg.EGO.HEIGHT

# Number of future time steps to evaluate
self.n_future = n_future

# Initialize metric states
self.add_state("obj_col", default=torch.zeros(self.n_future), dist_reduce_fx="sum")
self.add_state("obj_box_col", default=torch.zeros(self.n_future), dist_reduce_fx="sum")
self.add_state("L2", default=torch.zeros(self.n_future),dist_reduce_fx="sum")
self.add_state("L2", default=torch.zeros(self.n_future), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")


def evaluate_single_coll(self, traj, segmentation):
'''
gt_segmentation
traj: torch.Tensor (n_future, 2)
segmentation: torch.Tensor (n_future, 200, 200)
Evaluate collision for a single trajectory against segmentation.

Parameters:
traj: torch.Tensor (n_future, 2)
segmentation: torch.Tensor (n_future, 200, 200)

Returns:
collision: torch.Tensor (n_future,) indicating collision at each time step
'''
# Define polygon representing the vehicle's bounding box
pts = np.array([
[-self.H / 2. + 0.5, self.W / 2.],
[self.H / 2. + 0.5, self.W / 2.],
[self.H / 2. + 0.5, -self.W / 2.],
[-self.H / 2. + 0.5, -self.W / 2.],
])

# Transform vehicle coordinates into BEV grid coordinates
pts = (pts - self.bx.cpu().numpy()) / (self.dx.cpu().numpy())
pts[:, [0, 1]] = pts[:, [1, 0]]

# Generate polygon in BEV grid
rr, cc = polygon(pts[:,1], pts[:,0])
rc = np.concatenate([rr[:,None], cc[:,None]], axis=-1)

# Adjust trajectory to grid
n_future, _ = traj.shape
trajs = traj.view(n_future, 1, 2)
trajs[:,:,[0,1]] = trajs[:,:,[1,0]] # can also change original tensor
trajs[:,:,[0,1]] = trajs[:,:,[1,0]] # Swap x, y axes
trajs = trajs / self.dx
trajs = trajs.cpu().numpy() + rc # (n_future, 32, 2)
trajs = trajs.cpu().numpy() + rc

# Clip coordinates to valid range
r = trajs[:,:,0].astype(np.int32)
r = np.clip(r, 0, self.bev_dimension[0] - 1)

c = trajs[:,:,1].astype(np.int32)
c = np.clip(c, 0, self.bev_dimension[1] - 1)

# Check collision at each future time step
collision = np.full(n_future, False)
for t in range(n_future):
rr = r[t]
@@ -332,11 +338,20 @@ def evaluate_coll(self, trajs, gt_trajs, segmentation):

def evaluate_coll(self, trajs, gt_trajs, segmentation):
'''
trajs: torch.Tensor (B, n_future, 2)
gt_trajs: torch.Tensor (B, n_future, 2)
segmentation: torch.Tensor (B, n_future, 200, 200)
Evaluate collision for batch of trajectories against segmentation.

Parameters:
trajs: torch.Tensor (B, n_future, 2)
gt_trajs: torch.Tensor (B, n_future, 2)
segmentation: torch.Tensor (B, n_future, 200, 200)

Returns:
obj_coll_sum: torch.Tensor (n_future,) total collisions with objects
obj_box_coll_sum: torch.Tensor (n_future,) total box collisions
'''
B, n_future, _ = trajs.shape

# Adjust trajectories to account for coordinate system differences
trajs = trajs * torch.tensor([-1, 1], device=trajs.device)
gt_trajs = gt_trajs * torch.tensor([-1, 1], device=gt_trajs.device)

@@ -367,17 +382,25 @@ def evaluate_coll(self, trajs, gt_trajs, segmentation):

def compute_L2(self, trajs, gt_trajs):
'''
trajs: torch.Tensor (B, n_future, 3)
gt_trajs: torch.Tensor (B, n_future, 3)
Compute L2 distance between predicted and ground truth trajectories.

Parameters:
trajs: torch.Tensor (B, n_future, 3)
gt_trajs: torch.Tensor (B, n_future, 3)

Returns:
L2: torch.Tensor (B, n_future) L2 distances at each time step
'''

return torch.sqrt(((trajs[:, :, :2] - gt_trajs[:, :, :2]) ** 2).sum(dim=-1))

def update(self, trajs, gt_trajs, segmentation):
'''
trajs: torch.Tensor (B, n_future, 3)
gt_trajs: torch.Tensor (B, n_future, 3)
segmentation: torch.Tensor (B, n_future, 200, 200)
Update metric states with batch of predictions and ground truths.

Parameters:
trajs: torch.Tensor (B, n_future, 3)
gt_trajs: torch.Tensor (B, n_future, 3)
segmentation: torch.Tensor (B, n_future, 200, 200)
'''
assert trajs.shape == gt_trajs.shape
L2 = self.compute_L2(trajs, gt_trajs)
@@ -386,11 +409,17 @@ def update(self, trajs, gt_trajs, segmentation):
self.obj_col += obj_coll_sum
self.obj_box_col += obj_box_coll_sum
self.L2 += L2.sum(dim=0)
self.total +=len(trajs)
self.total += len(trajs)

def compute(self):
'''
Compute final metric results after aggregation.

Returns:
dict with keys 'obj_col', 'obj_box_col', and 'L2'
'''
return {
'obj_col': self.obj_col / self.total,
'obj_box_col': self.obj_box_col / self.total,
'L2' : self.L2 / self.total
}
'L2': self.L2 / self.total
}
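One note on the removed compute_on_step argument (context only, not part of the diff): the flag no longer exists on the torchmetrics Metric base class in newer releases, which is why the constructors above dropped it. If a per-batch value is still needed during training, calling the metric object directly (its forward) both updates the accumulated state and returns the value for the current batch; a hedged usage sketch with assumed inputs:

metric = PlanningMetric(cfg, n_future=4)              # cfg and the input tensors are assumptions
batch_result = metric(trajs, gt_trajs, segmentation)  # forward: updates state, returns this batch's metrics
final_result = metric.compute()                       # aggregate over everything seen so far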