-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from PPierzc/main
Additional Evaluation Logic
- Loading branch information
Showing 27 changed files with 1,395 additions and 59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# NOTE(review): the nesting below was reconstructed from a flattened diff view
# (indentation was lost in extraction) -- confirm against the original file.
seed: 0
checkpoint_every: 10
use_pretrained: mpii-prod-xlarge:latest

tags:
  - mpii
  - human36m
group: prod

dataset:
  dirname: "/data/human36m/processed"
  mpii: true

train:
  optimizer:
    lr: 1.0e-5
    weight_decay: 0
  lr_scheduler:
    patience: 10
    cooldown: 5
    mode: "min"
    factor: 0.1
    threshold: 1.0e-2
    min_lr: 1.0e-6
  batch_size: 200
  epochs: 200

model:
  num_layers: 14
  context_features: 68
  hidden_features: 262
  relations:
    - x
    - c
    - r
    - x->x
    - x<-x
    - c->x
    - r->x

embedding:
  name: "sage"
  config:
    input_dim: 2
    hidden_dim: 177
    output_dim: 68
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import torch | ||
import numpy as np | ||
|
||
|
||
def mpjpe(pred, gt, dim=None, mean=True):
    """
    Compute the mean per joint position error (MPJPE).

    The per-joint error is the Euclidean distance between `pred` and `gt`, taken over the last axis.
    Used in Protocol-I for Human3.6M dataset evaluation.

    :param pred: the predicted 3D pose (torch.Tensor, numpy.ndarray or nested sequence of numbers)
    :param gt: the ground truth 3D pose, broadcastable against `pred`
    :param dim: the dimension (or dimensions) of the per-joint error to average over. If None, the average is taken
        over all dimensions
    :param mean: If True, returns the averaged error. If False, returns the un-averaged per-joint distances (`dim` is
        then ignored), defaults to True (optional)
    :return: the averaged error, or the per-joint distances when `mean` is False. The result is a torch.Tensor when
        the distances are a torch.Tensor, and a numpy value otherwise
    """
    # Plain Python sequences do not support arithmetic; promote them to arrays.
    # Tensor inputs are left untouched so torch keeps handling them.
    if not isinstance(pred, torch.Tensor) and not isinstance(gt, torch.Tensor):
        pred = np.asarray(pred)
        gt = np.asarray(gt)

    # Euclidean distance along the coordinate (last) axis.
    pjpe = ((pred - gt) ** 2).sum(-1) ** 0.5

    if not mean:
        return pjpe

    # torch names the reduction axis `dim`, numpy names it `axis`.
    if isinstance(pjpe, torch.Tensor):
        return pjpe.mean() if dim is None else pjpe.mean(dim=dim)

    # `axis=None` is numpy's own "reduce over everything" default.
    return np.mean(pjpe, axis=dim)
|
||
|
||
def pa_mpjpe(
    p_gt: torch.Tensor, p_pred: torch.Tensor, dim: int = None, mean: bool = True
):
    """
    Compute the Procrustes-aligned mean per joint position error (PA-MPJPE).

    Every predicted pose is aligned to the ground truth with a similarity transform (translation, uniform scale and
    rotation) before the Euclidean per-joint error is measured.
    Used in Protocol-II for Human3.6M dataset evaluation.
    Code adapted from:
    https://github.com/twehrbein/Probabilistic-Monocular-3D-Human-Pose-Estimation-with-Normalizing-Flows/

    NOTE(review): the repeat/permute calls below imply `p_gt` is laid out as (joints, 1, coords) and `p_pred` as
    (joints, samples, coords) -- confirm against the callers.

    :param p_gt: the ground truth 3D pose (tensor or array-like)
    :type p_gt: torch.Tensor
    :param p_pred: predicted 3D pose (tensor or array-like)
    :type p_pred: torch.Tensor
    :param dim: the dimension of the per-joint error to average over. If None, the joint axis (0) is used, which
        yields one error value per predicted sample
    :type dim: int
    :param mean: If True, returns the averaged error. If False, returns the un-averaged per-joint errors,
        defaults to True (optional)
    :return: the PA-MPJPE as a CPU tensor (per-joint errors when `mean` is False)
    """
    if not isinstance(p_pred, torch.Tensor):
        p_pred = torch.Tensor(p_pred)

    if not isinstance(p_gt, torch.Tensor):
        p_gt = torch.Tensor(p_gt)

    # The alignment is run on the CPU. Both inputs are moved *before* the
    # reference copy is taken so that the final error is computed between
    # tensors living on the same device.
    p_pred = p_pred.cpu()
    p_gt = p_gt.cpu()

    og_gt = p_gt.clone()

    # One ground-truth copy per predicted sample, then reorder both tensors so
    # that axis 1 becomes the batch axis and axis 0 (joints) becomes the last one.
    gt = p_gt.repeat(1, p_pred.shape[1], 1).permute(1, 2, 0).contiguous()
    pred = p_pred.permute(1, 2, 0).contiguous()

    # Remove the centroid of each pose.
    mu_gt = gt.mean(dim=2)
    mu_pred = pred.mean(dim=2)
    gt = gt - mu_gt[:, :, None]
    pred = pred - mu_pred[:, :, None]

    # Centred Frobenius norm of each pose.
    norm_gt = torch.sqrt((gt**2.0).sum(dim=(1, 2)))
    norm_pred = torch.sqrt((pred**2.0).sum(dim=(1, 2)))

    # Scale to equal (unit) norm.
    gt = gt / norm_gt[:, None, None]
    pred = pred / norm_pred[:, None, None]

    # Cross-covariance between ground truth and prediction; its SVD gives the
    # optimal rotation.
    A = torch.bmm(gt, pred.transpose(1, 2))
    U, s, V = torch.svd(A, some=True)
    T = torch.bmm(V, U.transpose(1, 2))

    # Where the first estimate has a negative determinant, flip the last
    # singular direction. Done out-of-place so the SVD outputs stay untouched.
    sign = torch.sign(torch.det(T))
    flip = torch.ones_like(s)
    flip[:, -1] = sign
    V = V * flip[:, None, :]
    s = s * flip
    T = torch.bmm(V, U.transpose(1, 2))

    # Sum of the (sign-corrected) singular values, combined with the
    # ground-truth norm to obtain the scale applied to the prediction.
    traceTA = s.sum(dim=1)
    scale = norm_gt * traceTA

    # Transformed coordinates: rotate, rescale and shift onto the ground truth.
    p_pred_projected = (
        scale[:, None, None] * torch.bmm(pred.transpose(1, 2), T) + mu_gt[:, None, :]
    )

    # Back to the input layout before measuring the error.
    return mpjpe(
        og_gt,
        p_pred_projected.permute(1, 0, 2),
        dim=0 if dim is None else dim,
        mean=mean,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
|
||
# Joint indices kept for evaluation (14 entries; 0, 7 and 9 are left out).
# NOTE(review): presumably indices into the 17-joint Human3.6M skeleton -- confirm.
human36m_joints_to_use = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16]
|
||
|
||
def pck(
    poses_gt: torch.Tensor,
    poses_pred: torch.Tensor,
    threshold: float = 150,
    return_distances: bool = False,
) -> torch.Tensor:
    """
    Compute, for every frame, the fraction of joints whose predicted position lies strictly within `threshold` of the
    ground truth position (percentage of correct keypoints).

    :param poses_gt: the ground truth poses with only the joints of interest (frames x joints x 3)
    :type poses_gt: torch.Tensor
    :param poses_pred: the predicted poses with only the joints of interest (frames x joints x 3)
    :type poses_pred: torch.Tensor
    :param threshold: The maximum (exclusive) distance between a predicted and a ground truth joint for the joint to
        count as correct, in the same unit as the poses, defaults to 150
    :type threshold: float (optional)
    :param return_distances: If True, returns the per-joint distances (frames x joints) instead of the per-frame
        fractions, defaults to False
    :type return_distances: bool (optional)
    :return: a tensor with one value in [0, 1] per frame, or the per-joint distances when `return_distances` is True
    """
    if not isinstance(poses_pred, torch.Tensor):
        poses_pred = torch.Tensor(poses_pred)

    if not isinstance(poses_gt, torch.Tensor):
        poses_gt = torch.Tensor(poses_gt)

    # Euclidean distance per joint, reduced over the coordinate (last) axis.
    distances = torch.sqrt(torch.sum((poses_gt - poses_pred) ** 2, dim=-1))

    if return_distances:
        return distances

    # Count the joints under the threshold in each frame and normalise by the
    # number of joints.
    n_correct_joints = torch.count_nonzero(distances < threshold, dim=1)
    return n_correct_joints / poses_gt.shape[1]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.