diff --git a/docker-compose.yml b/docker-compose.yml index 713af7c..a025076 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -78,7 +78,7 @@ services: volumes: - .:/src/propose - ./tests:/tests - entrypoint: [ "pytest" ] + entrypoint: [ "pytest", "/tests/"] notebook_server: <<: *common diff --git a/experiments/human36m/mpii-prod-xlarge_lr_decr.yaml b/experiments/human36m/mpii-prod-xlarge_lr_decr.yaml new file mode 100644 index 0000000..a59a296 --- /dev/null +++ b/experiments/human36m/mpii-prod-xlarge_lr_decr.yaml @@ -0,0 +1,46 @@ +seed: 0 +checkpoint_every: 10 +use_pretrained: mpii-prod-xlarge:latest + +tags: + - mpii + - human36m +group: prod + +dataset: + dirname: "/data/human36m/processed" + mpii: true + +train: + optimizer: + lr: 1.0e-5 + weight_decay: 0 + lr_scheduler: + patience: 10 + cooldown: 5 + mode: "min" + factor: 0.1 + threshold: 1.0e-2 + min_lr: 1.0e-6 + batch_size: 200 + epochs: 200 + +model: + num_layers: 14 + context_features: 68 + hidden_features: 262 + relations: + - x + - c + - r + - x->x + - x<-x + - c->x + - r->x + +embedding: + name: "sage" + config: + input_dim: 2 + hidden_dim: 177 + output_dim: 68 \ No newline at end of file diff --git a/propose/datasets/human36m/Human36mDataset.py b/propose/datasets/human36m/Human36mDataset.py index 06eba53..b167444 100644 --- a/propose/datasets/human36m/Human36mDataset.py +++ b/propose/datasets/human36m/Human36mDataset.py @@ -18,14 +18,15 @@ def tensor_to_graph(inputs, context, root, edges, context_edges, root_edges): """ - Convert a tensor to a graph. - :param inputs: tensor of shape (batch_size, num_nodes, num_features) - :param context: tensor of shape (batch_size, num_nodes, num_context_features) - :param root: tensor of shape (batch_size, num_nodes) - :param edges: tensor of shape (batch_size, num_edges, 2) - :param context_edges: tensor of shape (batch_size, num_context_edges, 2) - :param root_edges: tensor of shape (batch_size, num_root_edges, 2) - :return: HeteroData + It takes in the inputs, context, root, and edges, and returns a HeteroData object + + :param inputs: the input tensor + :param context: the context nodes + :param root: the root node + :param edges: the edges between the nodes in the graph + :param context_edges: the edges from the context to the inputs + :param root_edges: the edges from the root node to the other nodes + :return: A hetero data object. """ data = HeteroData() @@ -44,6 +45,13 @@ def tensor_to_graph(inputs, context, root, edges, context_edges, root_edges): def tensor_to_human36m_graph(inputs, context, context_edges): + """ + It takes the input tensors, and converts them to a graph + + :param inputs: the input tensor, which is a tensor of shape (num_frames, num_joints, 3) + :param context: the context of the graph, which is the same as the input to the model + :param context_edges: the edges that are used to compute the context + """ pose = Human36mPose(np.zeros((1, 17, 3))) edges = torch.LongTensor(pose.edges).T @@ -270,9 +278,18 @@ def __init__( self.base_data.append(base_data) def __len__(self): + """ + The function returns the length of the data attribute of the object + :return: The length of the data. 
+ """ return len(self.data) def __getitem__(self, item): + """ + The function returns the data, base data, and a dictionary of the action, camera, subject, occlusion, and center3d + + :param item: the index of the item we want to get + """ if self.return_matrix: return ( self.data[item]["x"]["x"], @@ -299,6 +316,16 @@ def __getitem__(self, item): @classmethod def remove_root_edges(cls, edges, context_edges, num_context_samples): + """ + We remove the root edges from the full edges, and then we subtract 1 from the full edges and context edges to + make them zero-indexed + + :param cls: the class of the object + :param edges: the edges of the full graph + :param context_edges: the edges that are in the context graph + :param num_context_samples: The number of samples in the context + :return: The edges are being returned with the root edges removed. + """ full_edges = edges[:, torch.where(edges[0] != 0)[0]] context_edges = context_edges[:, torch.where(context_edges[1] != 0)[0]] root_edges = edges[:, torch.where(edges[0] == 0)[0]] @@ -311,6 +338,14 @@ def remove_root_edges(cls, edges, context_edges, num_context_samples): return full_edges, root_edges, context_edges def _sample_context(self, gaussfit, num_context_samples): + """ + Given a gaussian fit, sample from the gaussian distribution and return the samples + + :param gaussfit: the output of the neural network, which is a 16x6 tensor. The first column is the probability of + the gaussian, the next two are the mean, and the last three are the covariance matrix + :param num_context_samples: number of samples to draw from the context distribution + :return: The samples are being returned. + """ mean = torch.stack([gaussfit[:, 1], gaussfit[:, 2]], dim=1) cov = torch.stack([gaussfit[:, 3], gaussfit[:, 5]], dim=1).unsqueeze( 2 @@ -321,6 +356,14 @@ def _sample_context(self, gaussfit, num_context_samples): return samples.view(samples.shape[0] * samples.shape[1], samples.shape[2]) def _add_variance(self, pose2d, gaussfit): + """ + It takes in a pose2d and a gaussfit, and if use_variance is true, it returns a concatenation of pose2d and the + square of the third and sixth columns of gaussfit. Otherwise, it just returns pose2d + + :param pose2d: the 2D pose + :param gaussfit: the output of the gaussian fitting function + :return: The pose2d is being returned. + """ if self.use_variance: res = torch.cat( [ @@ -504,6 +547,10 @@ def __init__( self.base_data.append(base_data) def __len__(self): + """ + The function returns the length of the data attribute of the object + :return: The length of the data. + """ return len(self.data) def __getitem__(self, item): @@ -527,6 +574,14 @@ def __getitem__(self, item): ) # returns: full data, base data def remove_root_edges(self, edges, context_edges): + """ + It takes in the edges and context edges, and returns the full edges, root edges, and context edges + + :param edges: the edges of the graph, in the form of a 2xN tensor, where N is the number of edges. The first + row is the source node, the second row is the destination node + :param context_edges: the edges that are in the context of the current node + :return: The full_edges, root_edges, and context_edges are being returned. 
+ """ full_edges = edges[:, torch.where(edges[0] != 0)[0]] context_edges = context_edges[:, 1:] root_edges = edges[:, torch.where(edges[0] == 0)[0]] diff --git a/propose/evaluation/__init__.py b/propose/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/propose/evaluation/mpjpe.py b/propose/evaluation/mpjpe.py new file mode 100644 index 0000000..6f52000 --- /dev/null +++ b/propose/evaluation/mpjpe.py @@ -0,0 +1,117 @@ +import torch +import numpy as np + + +def mpjpe(pred, gt, dim=None, mean=True): + """ + `mpjpe` is the mean per joint position error, which is the mean of the Euclidean distance between the predicted 3D + joint positions and the ground truth 3D joint positions + + Used in Protocol-I for Human3.6M dataset evaluation. + + :param pred: the predicted 3D pose + :param gt: ground truth + :param dim: the dimension to average over. If None, the average is taken over all dimensions + :param mean: If True, returns the mean of the MPJPE across all frames. If False, returns the MPJPE for each frame, + defaults to True (optional) + :return: The mean of the pjpe + """ + pjpe = ((pred - gt) ** 2).sum(-1) ** 0.5 + + if not mean: + return pjpe + + # if pjpe is torch.Tensor use dim if numpy.array use axis + if isinstance(pjpe, torch.Tensor): + if dim is None: + return pjpe.mean() + return pjpe.mean(dim=dim) + + if dim is None: + return np.mean(pjpe) + + return np.mean(pjpe, axis=dim) + + +def pa_mpjpe( + p_gt: torch.TensorType, p_pred: torch.TensorType, dim: int = None, mean: bool = True +): + """ + PA-MPJPE is the Procrustes mean per joint position error, which is the mean of the Euclidean distance between the + predicted 3D joint positions and the ground truth 3D joint positions, after projecting the ground truth onto the + predicted 3D skeleton. + + Used in Protocol-II for Human3.6M dataset evaluation. + + Code adapted from: + https://github.com/twehrbein/Probabilistic-Monocular-3D-Human-Pose-Estimation-with-Normalizing-Flows/ + + :param p_gt: the ground truth 3D pose + :type p_gt: torch.TensorType + :param p_pred: predicted 3D pose + :type p_pred: torch.TensorType + :param dim: the dimension to average over. If None, the average is taken over all dimensions + :type dim: int + :param mean: If True, returns the mean of the MPJPE across all frames. If False, returns the MPJPE for each frame, + defaults to True (optional) + :return: The transformed coordinates. + """ + if not isinstance(p_pred, torch.Tensor): + p_pred = torch.Tensor(p_pred) + + if not isinstance(p_gt, torch.Tensor): + p_gt = torch.Tensor(p_gt) + + og_gt = p_gt.clone() + + p_gt = p_gt.repeat(1, p_pred.shape[1], 1) + + p_gt = p_gt.permute(1, 2, 0).contiguous() + p_pred = p_pred.permute(1, 2, 0).contiguous() + + # Moving the tensors to the CPU as the following code is more efficient on the CPU + p_pred = p_pred.cpu() + p_gt = p_gt.cpu() + + mu_gt = p_gt.mean(dim=2) + mu_pred = p_pred.mean(dim=2) + + p_gt = p_gt - mu_gt[:, :, None] + p_pred = p_pred - mu_pred[:, :, None] + + ss_gt = (p_gt**2.0).sum(dim=(1, 2)) + ss_pred = (p_pred**2.0).sum(dim=(1, 2)) + + # centred Frobenius norm + norm_gt = torch.sqrt(ss_gt) + norm_pred = torch.sqrt(ss_pred) + + # scale to equal (unit) norm + p_gt /= norm_gt[:, None, None] + p_pred /= norm_pred[:, None, None] + + # optimum rotation matrix of Y + A = torch.bmm(p_gt, p_pred.transpose(1, 2)) + + U, s, V = torch.svd(A, some=True) + + # Computing the rotation matrix. 
+ T = torch.bmm(V, U.transpose(1, 2)) + + detT = torch.det(T) + sign = torch.sign(detT) + V[:, :, -1] *= sign[:, None] + s[:, -1] *= sign + T = torch.bmm(V, U.transpose(1, 2)) + + # Computing the trace of T·A, i.e. the sum of the sign-corrected singular values. + traceTA = s.sum(dim=1) + + # transformed coords + scale = norm_gt * traceTA + + p_pred_projected = ( + scale[:, None, None] * torch.bmm(p_pred.transpose(1, 2), T) + mu_gt[:, None, :] + ) + + return mpjpe(og_gt, p_pred_projected.permute(1, 0, 2), dim=0) diff --git a/propose/evaluation/pck.py b/propose/evaluation/pck.py new file mode 100644 index 0000000..1e6ea60 --- /dev/null +++ b/propose/evaluation/pck.py @@ -0,0 +1,39 @@ +import torch + +human36m_joints_to_use = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16] + + +def pck( + poses_gt: torch.Tensor, + poses_pred: torch.Tensor, + threshold: float = 150, + return_distances: bool = False, ) -> torch.Tensor: + """ + It computes the fraction of predicted joints that lie within a threshold distance of the corresponding ground + truth joints, for each pose + + :param poses_gt: the ground truth poses with only the joints of interest (frames x joints x 3) + :type poses_gt: torch.Tensor + :param poses_pred: the predicted poses with only the joints of interest (frames x joints x 3) + :type poses_pred: torch.Tensor + :param threshold: The threshold for the distance between the predicted and ground truth pose, defaults to 150 + :type threshold: float (optional) + :param return_distances: If True, returns the distances between the predicted and ground truth pose, defaults to False + :type return_distances: bool (optional) + """ + if not isinstance(poses_pred, torch.Tensor): + poses_pred = torch.Tensor(poses_pred) + + if not isinstance(poses_gt, torch.Tensor): + poses_gt = torch.Tensor(poses_gt) + + distances = torch.sqrt(torch.sum((poses_gt - poses_pred) ** 2, dim=-1)) + + if return_distances: + return distances + + n_correct_joints = torch.count_nonzero(distances < threshold, dim=1) + correct_poses = n_correct_joints / poses_gt.shape[1] + + return correct_poses diff --git a/propose/models/flows/CondGraphFlow.py b/propose/models/flows/CondGraphFlow.py index e3bd3ec..ffb40e9 100644 --- a/propose/models/flows/CondGraphFlow.py +++ b/propose/models/flows/CondGraphFlow.py @@ -97,6 +97,17 @@ def from_pretrained(cls, artifact_name): flow = cls.build_model(artifact.metadata) artifact_dir = artifact.download() - flow.load_state_dict(torch.load(artifact_dir + "/model.pt")) + + device = "cuda" if torch.cuda.is_available() else "cpu" + flow.load_state_dict( + torch.load(artifact_dir + "/model.pt", map_location=torch.device(device)) + ) return flow + + def set_device(self): + if torch.cuda.is_available(): + self.to("cuda:0") + return True + + return False diff --git a/propose/poses/metadata/__init__.py b/propose/poses/metadata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/propose/training/supervised.py b/propose/training/supervised.py index d90aa14..9b6aed6 100644 --- a/propose/training/supervised.py +++ b/propose/training/supervised.py @@ -5,7 +5,7 @@ from torch_geometric.loader.dataloader import Collater -from propose.utils.mpjpe import mpjpe +from propose.evaluation.mpjpe import mpjpe def supervised_trainer( diff --git a/propose/utils/imports.py b/propose/utils/imports.py new file mode 100644 index 0000000..cdcc503 --- /dev/null +++ b/propose/utils/imports.py @@ -0,0 +1,37 @@ +from importlib import import_module + + +def split_module_name(abs_class_name): + """ + It takes a fully qualified class name (e.g.
`"foo.bar.Baz"`) and returns a tuple of the module path (e.g. `"foo.bar"`) + and the class name (e.g. `"Baz"`) + + :param abs_class_name: The absolute name of the class + :return: The absolute module path and the class name. + """ + abs_module_path = ".".join(abs_class_name.split(".")[:-1]) + class_name = abs_class_name.split(".")[-1] + return abs_module_path, class_name + + +def dynamic_import(abs_module_path, class_name): + """ + It dynamically imports a class from a module + + :param abs_module_path: The absolute path to the module you want to import + :param class_name: The name of the class you want to instantiate + :return: The class object + """ + module_object = import_module(abs_module_path) + target_class = getattr(module_object, class_name) + return target_class + + +def module_import(path): + """ + It takes a string like `"foo.bar.baz"` and returns the module object `baz` from the package `foo.bar` + + :param path: The path to the module you want to import + :return: The module object. + """ + return dynamic_import(*split_module_name(path)) diff --git a/propose/utils/mpjpe.py b/propose/utils/mpjpe.py deleted file mode 100644 index aa9114e..0000000 --- a/propose/utils/mpjpe.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -import numpy as np - - -def mpjpe(pred, gt, dim=None): - pjpe = ((pred - gt) ** 2).sum(-1) ** 0.5 - - # if pjpe is torch.Tensor use dim if numpy.array use axis - if isinstance(pjpe, torch.Tensor): - if dim is None: - return pjpe.mean() - return pjpe.mean(dim=dim) - - if dim is None: - return np.mean(pjpe) - return np.mean(pjpe, axis=dim) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/eval.py b/scripts/eval.py index 138aabe..6c3ec0f 100644 --- a/scripts/eval.py +++ b/scripts/eval.py @@ -1,15 +1,13 @@ -from pathlib import Path -from propose.datasets.human36m.preprocess import pickle_poses, pickle_cameras +from propose.utils.imports import dynamic_import import argparse -from eval.human36m import human36m - import os import yaml from pathlib import Path + parser = argparse.ArgumentParser(description="Arguments for running the scripts") parser.add_argument( @@ -33,6 +31,13 @@ help="Experiment config file", ) +parser.add_argument( + "--script", + default="eval.human36m.human36m", + type=str, + help="Experiment script", +) + if __name__ == "__main__": args = parser.parse_args() @@ -59,11 +64,8 @@ if "experiment_name" not in config: config["experiment_name"] = args.experiment - if not args.wandb: - raise Exception("Wandb is required for evaluation experiments") - if args.human36m: - human36m(use_wandb=args.wandb, config=config) + dynamic_import(args.script, "run")(use_wandb=args.wandb, config=config) else: print( "Not running any scripts as no arguments were passed. Run with --help for more information." 
diff --git a/scripts/eval/human36m/__init__.py b/scripts/eval/human36m/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/eval/human36m/calibration.py b/scripts/eval/human36m/calibration.py new file mode 100644 index 0000000..9d45c30 --- /dev/null +++ b/scripts/eval/human36m/calibration.py @@ -0,0 +1,154 @@ +from propose.datasets.human36m.Human36mDataset import Human36mDataset +from torch_geometric.loader import DataLoader + +from propose.utils.reproducibility import set_random_seed + +from propose.models.flows import CondGraphFlow + +import torch + +import os + +import time +from tqdm import tqdm +import numpy as np + +import wandb + +import seaborn as sns +import matplotlib.pyplot as plt + + +def calibration(flow, test_dataloader): + total = 0 + iter_dataloader = iter(test_dataloader) + pbar = tqdm(range(len(test_dataloader))) + + quantiles = np.arange(0, 1.05, 0.05) + quantile_counts = np.zeros((len(quantiles), 1)) + q_val = [] + + for _ in pbar: + batch, _, action = next(iter_dataloader) + batch.cuda() + samples = flow.sample(200, batch) + + true_pose = ( + batch["x"] + .x.cpu() + .numpy() + .reshape(-1, 16, 1, 3)[ + :, np.insert(action["occlusion"].bool().numpy(), 9, False) + ] + ) + sample_poses = ( + samples["x"] + .x.detach() + .cpu() + .numpy() + .reshape(-1, 16, 200, 3)[ + :, np.insert(action["occlusion"].bool().numpy(), 9, False) + ] + ) + + sample_mean = ( + torch.Tensor(sample_poses).median(-2).values.numpy()[..., np.newaxis, :] + ) + errors = ((sample_mean / 0.0036 - sample_poses / 0.0036) ** 2).sum(-1) ** 0.5 + true_error = ((sample_mean / 0.0036 - true_pose / 0.0036) ** 2).sum(-1) ** 0.5 + + q_vals = np.quantile(errors, quantiles, 2).squeeze(1) + q_val.append(q_vals) + + v = np.nanmean((q_vals > true_error.squeeze()).astype(int), axis=1)[ + :, np.newaxis + ] + if not np.isnan(v).any(): + total += 1 + quantile_counts += v + + quantile_freqs = quantile_counts / total + + return quantiles, quantile_freqs, q_val + + +def calibration_experiment(flow, config, **kwargs): + test_dataset = Human36mDataset( + **config["dataset"], + **kwargs, + ) + test_dataloader = DataLoader( + test_dataset, batch_size=1, shuffle=True, pin_memory=False, num_workers=0 + ) + + return calibration(flow, test_dataloader) + + +def run(use_wandb, config): + set_random_seed(config["seed"]) + + config["dataset"]["dirname"] = config["dataset"]["dirname"] + "/test" + + if use_wandb: + wandb.init( + project="propose_human36m", + entity=os.environ["WANDB_USER"], + config=config, + job_type="evaluation", + name=f"{config['experiment_name']}_calibration_{time.strftime('%d/%m/%Y::%H:%M:%S')}", + tags=config["tags"] if "tags" in config else None, + group=config["group"] if "group" in config else None, + ) + + flow = CondGraphFlow.from_pretrained( + f'ppierzc/propose_human36m/{config["experiment_name"]}:latest' + ) + + config["cuda_accelerated"] = flow.set_device() + flow.eval() + + # Test + quantiles, quantile_freqs, q_val = calibration_experiment( + flow, + config, + occlusion_fractions=[], + test=True, + ) + + sns.set_context("talk") + with sns.axes_style("whitegrid"): + plt.figure(figsize=(5, 5), dpi=150) + plt.fill_between( + quantiles, + np.mean(quantile_freqs, axis=1) + np.std(quantile_freqs, axis=1), + np.mean(quantile_freqs, axis=1) - np.std(quantile_freqs, axis=1), + color="#1E88E5", + alpha=0.5, + zorder=-5, + rasterized=True, + ) + plt.plot([0, 1], [0, 1], ls="--", c="tab:gray") + plt.plot( + quantiles, + np.median(quantile_freqs, axis=1), + c="#1E88E5", + alpha=1, + label="cGNF 
all", + ) + plt.xticks(np.arange(0, 1.2, 0.2)) + plt.yticks(np.arange(0, 1.2, 0.2)) + plt.xlabel("Quantile") + plt.ylabel("Frequency") + plt.text(0.03, 0.07, "reference line", rotation=45, c="k", fontsize=15) + plt.xlim(0, 1) + plt.ylim(0, 1) + plt.title("Calibration") + plt.legend(frameon=False) + + plt.gca().set_rasterization_zorder(-1) + + if use_wandb: + img = wandb.Image(plt) + wandb.log({"calibration": img}) + + plt.close() diff --git a/scripts/eval/human36m.py b/scripts/eval/human36m/human36m.py similarity index 53% rename from scripts/eval/human36m.py rename to scripts/eval/human36m/human36m.py index bff607e..9aefe94 100644 --- a/scripts/eval/human36m.py +++ b/scripts/eval/human36m/human36m.py @@ -2,12 +2,11 @@ from torch_geometric.loader import DataLoader from propose.utils.reproducibility import set_random_seed -from propose.utils.mpjpe import mpjpe +from propose.evaluation.mpjpe import mpjpe, pa_mpjpe +from propose.evaluation.pck import pck, human36m_joints_to_use from propose.models.flows import CondGraphFlow -import torch - import os import time @@ -19,11 +18,20 @@ def evaluate(flow, test_dataloader, temperature=1.0): mpjpes = [] + pa_mpjpes = [] + single_mpjpes = [] + single_pa_mpjpes = [] + pck_scores = [] + mean_pck_scores = [] iter_dataloader = iter(test_dataloader) - for _ in tqdm(range(len(test_dataloader))): + + pbar = tqdm(range(len(test_dataloader))) + + for _ in pbar: batch, _, action = next(iter_dataloader) - batch.cuda() + batch.to(flow.device) + samples = flow.sample(200, batch, temperature=temperature) true_pose = batch["x"].x.cpu().numpy().reshape(-1, 16, 1, 3) @@ -32,17 +40,59 @@ def evaluate(flow, test_dataloader, temperature=1.0): true_pose = np.insert(true_pose, 0, 0, axis=1) sample_poses = np.insert(sample_poses, 0, 0, axis=1) + pck_score = pck( + true_pose[:, human36m_joints_to_use] / 0.0036, + sample_poses[:, human36m_joints_to_use] / 0.0036, + ) + + has_correct_pose = pck_score.max().unsqueeze(0).numpy() + mean_correct_pose = pck_score.mean().unsqueeze(0).numpy() + m = mpjpe(true_pose / 0.0036, sample_poses / 0.0036, dim=1) + m_single = m[..., 0] m = np.min(m, axis=-1) + pa_m = ( + pa_mpjpe(true_pose[0] / 0.0036, sample_poses[0] / 0.0036, dim=0) + .unsqueeze(0) + .numpy() + ) + + pa_m_single = pa_m[..., 0] + pa_m = np.min(pa_m, axis=-1) + m = m.tolist() + pa_m = pa_m.tolist() + m_single = m_single.tolist() mpjpes += [m] + pa_mpjpes += [pa_m] + single_mpjpes += [m_single] + single_pa_mpjpes += [pa_m_single] + + pck_scores += [has_correct_pose] + mean_pck_scores += [mean_correct_pose] + + pbar.set_description( + f"MPJPE: {np.concatenate(mpjpes).mean():.4f}, " + f"PA MPJPE: {np.concatenate(pa_mpjpes).mean():.4f}, " + f"Single MPJPE: {np.concatenate(single_mpjpes).mean():.4f} " + f"Single PA MPJPE: {np.concatenate(single_pa_mpjpes).mean():.4f} " + f"PCK: {np.concatenate(pck_scores).mean():.4f} " + f"Mean PCK: {np.concatenate(mean_pck_scores).mean():.4f} " + ) - return mpjpes + return ( + mpjpes, + pa_mpjpes, + single_mpjpes, + single_pa_mpjpes, + pck_scores, + mean_pck_scores, + ) -def mpjpe_experiment(flow, config, **kwargs): +def mpjpe_experiment(flow, config, name="test", **kwargs): test_dataset = Human36mDataset( **config["dataset"], **kwargs, @@ -50,12 +100,28 @@ def mpjpe_experiment(flow, config, **kwargs): test_dataloader = DataLoader( test_dataset, batch_size=1, shuffle=True, pin_memory=False, num_workers=0 ) - test_res = evaluate(flow, test_dataloader) - - return np.concatenate(test_res).mean(), test_dataset, test_dataloader - - -def 
human36m(use_wandb: bool = False, config: dict = None): + ( + test_res, + test_res_pa, + test_res_single, + test_res_pa_single, + test_res_pck, + test_res_mean_pck, + ) = evaluate(flow, test_dataloader) + + res = { + f"{name}/test_res": np.concatenate(test_res).mean(), + f"{name}/test_res_pa": np.concatenate(test_res_pa).mean(), + f"{name}/test_res_single": np.concatenate(test_res_single).mean(), + f"{name}/test_res_pa_single": np.concatenate(test_res_pa_single).mean(), + f"{name}/test_res_pck": np.concatenate(test_res_pck).mean(), + f"{name}/test_res_mean_pck": np.concatenate(test_res_mean_pck).mean(), + } + + return res, test_dataset, test_dataloader + + +def run(use_wandb: bool = False, config: dict = None): """ Train a CondGraphFlow on the Human36m dataset. :param use_wandb: Whether to use wandb for logging. @@ -71,20 +137,16 @@ def human36m(use_wandb: bool = False, config: dict = None): entity=os.environ["WANDB_USER"], config=config, job_type="evaluation", - name=f"{config['experiment_name']}_{time.strftime('%d/%m/%Y::%H:%M:%S')}", + name=f"{config['experiment_name']}_human36m_{time.strftime('%d/%m/%Y::%H:%M:%S')}", tags=config["tags"] if "tags" in config else None, group=config["group"] if "group" in config else None, ) flow = CondGraphFlow.from_pretrained( - f'ppierzc/propose_human36m/{config["experiment_name"]}:latest' + f'ppierzc/propose_human36m/{config["experiment_name"]}:v20' ) - config["cuda_accelerated"] = False - if torch.cuda.is_available(): - flow.to("cuda:0") - config["cuda_accelerated"] = True - + config["cuda_accelerated"] = flow.set_device() flow.eval() # Test @@ -93,10 +155,11 @@ def human36m(use_wandb: bool = False, config: dict = None): config, occlusion_fractions=[], test=True, + name="test", ) if use_wandb: - wandb.log({"test/best_mpjpe": test_res}) + wandb.log(test_res) # Hard hard_res, hard_dataset, hard_dataloader = mpjpe_experiment( flow, config, occlusion_fractions=[], hardsubset=True, + name="hard", ) if use_wandb: - wandb.log({"hard/best_mpjpe": hard_res}) + wandb.log(hard_res) # Occlusion Only mpjpes = [] diff --git a/scripts/eval/human36m/per_joint_error.py b/scripts/eval/human36m/per_joint_error.py new file mode 100644 index 0000000..35ffa3d --- /dev/null +++ b/scripts/eval/human36m/per_joint_error.py @@ -0,0 +1,151 @@ +from propose.datasets.human36m.Human36mDataset import Human36mDataset +from torch_geometric.loader import DataLoader +from propose.poses.human36m import Human36mPose + +from propose.utils.reproducibility import set_random_seed +from propose.evaluation.mpjpe import mpjpe + +from propose.models.flows import CondGraphFlow + +import os + +import time +from tqdm import tqdm +import numpy as np + +import wandb + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + + +def evaluate(flow, test_dataloader, temperature=1.0): + mpjpes_not_occluded = [] + mpjpes_occluded = [] + + iter_dataloader = iter(test_dataloader) + for _ in tqdm(range(len(test_dataloader))): + batch, _, action = next(iter_dataloader) + occluded_joints = action["occlusion"].bool().numpy() + + batch = batch.to(flow.device) + samples = flow.sample(200, batch, temperature=temperature) + + true_pose = batch["x"].x.cpu().numpy().reshape(-1, 16, 1, 3) + sample_poses = samples["x"].x.detach().cpu().numpy().reshape(-1, 16, 200, 3) + + true_pose = np.insert(true_pose, 0, 0, axis=1) + sample_poses = np.insert(sample_poses, 0, 0, axis=1) + + m = mpjpe(true_pose / 0.0036, sample_poses / 0.0036, mean=False) + m = np.min(m, axis=-1) + + m = np.delete(m, 0, axis=1) + m = np.delete(m, 8, axis=1) + + # if occluded, add values to mpjpes_occluded with the unoccluded entries as nan + m_occluded = m.copy() + m_occluded[~occluded_joints] = np.nan + mpjpes_occluded.append(m_occluded) + + # if not occluded, add values to mpjpes_not_occluded with the occluded entries as nan + m_not_occluded = m.copy() + m_not_occluded[occluded_joints] = np.nan + mpjpes_not_occluded.append(m_not_occluded) + + return mpjpes_not_occluded, mpjpes_occluded + + +def mpjpe_experiment(flow, config, **kwargs): + test_dataset = Human36mDataset(**config["dataset"], **kwargs) + test_dataloader = DataLoader( + test_dataset, batch_size=1, shuffle=True, pin_memory=False, num_workers=0 + ) + mpjpes_not_occluded, mpjpes_occluded = evaluate(flow, test_dataloader) + + return np.concatenate(mpjpes_not_occluded).T, np.concatenate(mpjpes_occluded).T + + +def run(use_wandb: bool = False, config: dict = None): + """ + Evaluate the per-joint error of a pretrained CondGraphFlow on the Human36m dataset. + :param use_wandb: Whether to use wandb for logging. + :param config: A dictionary of configuration parameters. + """ + set_random_seed(config["seed"]) + + config["dataset"]["dirname"] = config["dataset"]["dirname"] + "/test" + + if use_wandb: + wandb.init( + project="propose_human36m", + entity=os.environ["WANDB_USER"], + config=config, + job_type="evaluation", + name=f"{config['experiment_name']}_pje_{time.strftime('%d/%m/%Y::%H:%M:%S')}", + tags=config["tags"] if "tags" in config else None, + group=config["group"] if "group" in config else None, + ) + + flow = CondGraphFlow.from_pretrained( + f'ppierzc/propose_human36m/{config["experiment_name"]}:latest' + ) + + config["cuda_accelerated"] = flow.set_device() + flow.eval() + + pose = Human36mPose(np.zeros((16, 2))) + marker_names = pose.marker_names[1:] + del marker_names[8] + + # Test + mpjpes_not_occluded, mpjpes_occluded = mpjpe_experiment( + flow, + config, + occlusion_fractions=[], + test=True, + ) + + df_occluded = pd.DataFrame( + {key: value for key, value in zip(marker_names, mpjpes_occluded)} + ) + + df_not_occluded = pd.DataFrame( + {key: value for key, value in zip(marker_names, mpjpes_not_occluded)} + ) + + df = ( + pd.concat( + [df_not_occluded, df_occluded], keys=["not_occluded", "occluded"], axis=1 + ) + .stack() + .stack() + .to_frame() + .reset_index() + ) + + plt.figure(figsize=(15, 5)) + sns.barplot(data=df, x="level_1", y=0, hue="level_2") + plt.xticks(rotation=90) + plt.ylabel("MPJPE") + plt.xlabel("Joint") + plt.legend(title="Occluded?") + plt.tight_layout() + + output = { + "img": wandb.Image(plt.gcf(), caption="MPJPE"), + "occluded": { + key: list(filter(lambda x: x, value)) + for key, value in zip(marker_names, mpjpes_occluded) + }, + "not_occluded": { + key: list(filter(lambda x: x, value)) + for key, value in zip(marker_names, mpjpes_not_occluded) + }, + } + + if use_wandb: + wandb.log(output) + + plt.close() diff --git a/scripts/eval/human36m/single.py b/scripts/eval/human36m/single.py new file mode 100644 index 0000000..85f0f5f --- /dev/null +++ b/scripts/eval/human36m/single.py @@ -0,0 +1,203 @@ +from propose.datasets.human36m.Human36mDataset import Human36mDataset +from torch_geometric.loader import DataLoader + +from propose.utils.reproducibility import set_random_seed +from propose.evaluation.mpjpe import mpjpe, pa_mpjpe +from propose.evaluation.pck import pck, human36m_joints_to_use + +from propose.models.flows import CondGraphFlow + +import os + +import time +from tqdm import tqdm +import numpy as np + +import wandb
+ + +def evaluate(flow, test_dataloader, temperature=1.0): + single_mpjpes = [] + single_pa_mpjpes = [] + pck_scores = [] + mean_pck_scores = [] + + iter_dataloader = iter(test_dataloader) + + pbar = tqdm(range(len(test_dataloader))) + + for _ in pbar: + batch, _, action = next(iter_dataloader) + batch.to(flow.device) + + samples = flow.mode_sample(batch) + + true_pose = batch["x"].x.cpu().numpy().reshape(-1, 16, 1, 3) + sample_poses = samples["x"].x.detach().cpu().numpy().reshape(-1, 16, 1, 3) + + true_pose = np.insert(true_pose, 0, 0, axis=1) + sample_poses = np.insert(sample_poses, 0, 0, axis=1) + + pck_score = pck( + true_pose[:, human36m_joints_to_use] / 0.0036, + sample_poses[:, human36m_joints_to_use] / 0.0036, + ) + + has_correct_pose = pck_score.max().unsqueeze(0).numpy() + mean_correct_pose = pck_score.mean().unsqueeze(0).numpy() + + m = mpjpe(true_pose / 0.0036, sample_poses / 0.0036, dim=1) + m_single = m[..., 0] + + pa_m = ( + pa_mpjpe(true_pose[0] / 0.0036, sample_poses[0] / 0.0036, dim=0) + .unsqueeze(0) + .numpy() + ) + + pa_m_single = pa_m[..., 0] + + m_single = m_single.tolist() + + single_mpjpes += [m_single] + single_pa_mpjpes += [pa_m_single] + + pck_scores += [has_correct_pose] + mean_pck_scores += [mean_correct_pose] + + pbar.set_description( + f"Single MPJPE: {np.concatenate(single_mpjpes).mean():.4f} " + f"Single PA MPJPE: {np.concatenate(single_pa_mpjpes).mean():.4f} " + f"PCK: {np.concatenate(pck_scores).mean():.4f} " + f"Mean PCK: {np.concatenate(mean_pck_scores).mean():.4f} " + ) + + return ( + single_mpjpes, + single_pa_mpjpes, + pck_scores, + mean_pck_scores, + ) + + +def mpjpe_experiment(flow, config, name="test", **kwargs): + test_dataset = Human36mDataset( + **config["dataset"], + **kwargs, + ) + test_dataloader = DataLoader( + test_dataset, batch_size=1, shuffle=True, pin_memory=False, num_workers=0 + ) + ( + test_res_single, + test_res_pa_single, + test_res_pck, + test_res_mean_pck, + ) = evaluate(flow, test_dataloader) + + res = { + f"{name}/test_res_single": np.concatenate(test_res_single).mean(), + f"{name}/test_res_pa_single": np.concatenate(test_res_pa_single).mean(), + f"{name}/test_res_pck": np.concatenate(test_res_pck).mean(), + f"{name}/test_res_mean_pck": np.concatenate(test_res_mean_pck).mean(), + } + + return res, test_dataset, test_dataloader + + +def run(use_wandb: bool = False, config: dict = None): + """ + Evaluate a pretrained CondGraphFlow on the Human36m dataset using a single (mode) sample per pose. + :param use_wandb: Whether to use wandb for logging. + :param config: A dictionary of configuration parameters.
+ """ + set_random_seed(config["seed"]) + + config["dataset"]["dirname"] = config["dataset"]["dirname"] + "/test" + + if use_wandb: + wandb.init( + project="propose_human36m", + entity=os.environ["WANDB_USER"], + config=config, + job_type="evaluation", + name=f"{config['experiment_name']}_single_{time.strftime('%d/%m/%Y::%H:%M:%S')}", + tags=config["tags"] if "tags" in config else None, + group=config["group"] if "group" in config else None, + ) + + flow = CondGraphFlow.from_pretrained( + f'ppierzc/propose_human36m/{config["experiment_name"]}:v20' + ) + + config["cuda_accelerated"] = flow.set_device() + flow.eval() + + # Test + test_res, test_dataset, test_dataloader = mpjpe_experiment( + flow, + config, + occlusion_fractions=[], + test=True, + name="test", + ) + + if use_wandb: + wandb.log(test_res) + + # Hard + hard_res, hard_dataset, hard_dataloader = mpjpe_experiment( + flow, + config, + occlusion_fractions=[], + hardsubset=True, + name="hard", + ) + + if use_wandb: + wandb.log(hard_res) + + hard_dataset = Human36mDataset( + **config["dataset"], + occlusion_fractions=[], + hardsubset=True, + ) + + # Occlusion Only + mpjpes = [] + for i in tqdm(range(len(hard_dataset))): + batch = hard_dataset[i][0] + batch.cuda() + samples = flow.mode_sample(batch.cuda()) + + true_pose = ( + batch["x"] + .x.cpu() + .numpy() + .reshape(-1, 16, 1, 3)[:, np.insert(hard_dataset.occlusions[i], 9, False)] + ) + sample_poses = ( + samples["x"] + .x.detach() + .cpu() + .numpy() + .reshape(-1, 16, 1, 3)[:, np.insert(hard_dataset.occlusions[i], 9, False)] + ) + + m = mpjpe(true_pose / 0.0036, sample_poses / 0.0036, dim=1) + m = np.min(m, axis=-1) + + m = m.tolist() + + mpjpes += [m] + + occl_res = np.nanmean(mpjpes) + if use_wandb: + wandb.log({"occl/best_mpjpe": occl_res}) + + print("MPJPE for best") + print("---") + # print(f"H36M: {test_res}") + # print(f"H36MA: {hard_res}") + print(f"Occl.: {occl_res}") + print("---") diff --git a/scripts/eval/human36m/temperature.py b/scripts/eval/human36m/temperature.py new file mode 100644 index 0000000..26c6ee1 --- /dev/null +++ b/scripts/eval/human36m/temperature.py @@ -0,0 +1,171 @@ +from propose.datasets.human36m.Human36mDataset import Human36mDataset +from torch_geometric.loader import DataLoader + +from propose.utils.reproducibility import set_random_seed +from propose.evaluation.mpjpe import mpjpe, pa_mpjpe +from propose.evaluation.pck import pck, human36m_joints_to_use + +from propose.models.flows import CondGraphFlow + +import os + +import time +from tqdm import tqdm +import numpy as np + +import wandb + + +def evaluate(flow, test_dataloader, temperature=1.0, limit=1000): + mpjpes = [] + pa_mpjpes = [] + single_mpjpes = [] + single_pa_mpjpes = [] + pck_scores = [] + mean_pck_scores = [] + + iter_dataloader = iter(test_dataloader) + + if limit is None: + pbar = tqdm(range(len(test_dataloader))) + else: + pbar = tqdm(range(limit)) + + for _ in pbar: + batch, _, action = next(iter_dataloader) + batch.to(flow.device) + + samples = flow.sample(200, batch, temperature=temperature) + + true_pose = batch["x"].x.cpu().numpy().reshape(-1, 16, 1, 3) + sample_poses = samples["x"].x.detach().cpu().numpy().reshape(-1, 16, 200, 3) + + true_pose = np.insert(true_pose, 0, 0, axis=1) + sample_poses = np.insert(sample_poses, 0, 0, axis=1) + + pck_score = pck( + true_pose[:, human36m_joints_to_use] / 0.0036, + sample_poses[:, human36m_joints_to_use] / 0.0036, + ) + + has_correct_pose = pck_score.max().unsqueeze(0).numpy() + mean_correct_pose = pck_score.mean().unsqueeze(0).numpy() + 
+ m = mpjpe(true_pose / 0.0036, sample_poses / 0.0036, dim=1) + m_single = m[..., 0] + m = np.min(m, axis=-1) + + pa_m = ( + pa_mpjpe(true_pose[0] / 0.0036, sample_poses[0] / 0.0036, dim=0) + .unsqueeze(0) + .numpy() + ) + + pa_m_single = pa_m[..., 0] + pa_m = np.min(pa_m, axis=-1) + + m = m.tolist() + pa_m = pa_m.tolist() + m_single = m_single.tolist() + + mpjpes += [m] + pa_mpjpes += [pa_m] + single_mpjpes += [m_single] + single_pa_mpjpes += [pa_m_single] + + pck_scores += [has_correct_pose] + mean_pck_scores += [mean_correct_pose] + + pbar.set_description( + f"MPJPE: {np.concatenate(mpjpes).mean():.4f}, " + f"PA MPJPE: {np.concatenate(pa_mpjpes).mean():.4f}, " + f"Single MPJPE: {np.concatenate(single_mpjpes).mean():.4f} " + f"Single PA MPJPE: {np.concatenate(single_pa_mpjpes).mean():.4f} " + f"PCK: {np.concatenate(pck_scores).mean():.4f} " + f"Mean PCK: {np.concatenate(mean_pck_scores).mean():.4f} " + ) + + return ( + mpjpes, + pa_mpjpes, + single_mpjpes, + single_pa_mpjpes, + pck_scores, + mean_pck_scores, + ) + + +def mpjpe_experiment(flow, config, name="test", temperature=1.0, **kwargs): + test_dataset = Human36mDataset( + **config["dataset"], + **kwargs, + ) + test_dataloader = DataLoader( + test_dataset, batch_size=1, shuffle=True, pin_memory=False, num_workers=0 + ) + ( + test_res, + test_res_pa, + test_res_single, + test_res_pa_single, + test_res_pck, + test_res_mean_pck, + ) = evaluate(flow, test_dataloader, temperature=temperature) + + res = { + f"{name}/test_res": np.concatenate(test_res).mean(), + f"{name}/test_res_pa": np.concatenate(test_res_pa).mean(), + f"{name}/test_res_single": np.concatenate(test_res_single).mean(), + f"{name}/test_res_pa_single": np.concatenate(test_res_pa_single).mean(), + f"{name}/test_res_pck": np.concatenate(test_res_pck).mean(), + f"{name}/test_res_mean_pck": np.concatenate(test_res_mean_pck).mean(), + } + + return res, test_dataset, test_dataloader + + +def run(use_wandb: bool = False, config: dict = None): + """ + Evaluate a pretrained CondGraphFlow on the Human36m dataset across a range of sampling temperatures. + :param use_wandb: Whether to use wandb for logging. + :param config: A dictionary of configuration parameters.
+ """ + set_random_seed(config["seed"]) + + config["dataset"]["dirname"] = config["dataset"]["dirname"] + "/test" + + if use_wandb: + wandb.init( + project="propose_human36m", + entity=os.environ["WANDB_USER"], + config=config, + job_type="evaluation", + name=f"{config['experiment_name']}_temperature_{time.strftime('%d/%m/%Y::%H:%M:%S')}", + tags=config["tags"] if "tags" in config else None, + group=config["group"] if "group" in config else None, + ) + + flow = CondGraphFlow.from_pretrained( + f'ppierzc/propose_human36m/{config["experiment_name"]}:latest' + ) + + config["cuda_accelerated"] = flow.set_device() + flow.eval() + + temperatures = np.arange(0.1, 1.1, 0.1) + + for temperature in temperatures: + # Test + test_res, test_dataset, test_dataloader = mpjpe_experiment( + flow, + config, + occlusion_fractions=[], + test=True, + name="test", + temperature=temperature, + ) + + test_res["temperature"] = temperature + + if use_wandb: + wandb.log(test_res) diff --git a/scripts/train.py b/scripts/train.py index cdeebfe..c10a99b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -37,6 +37,13 @@ help="Which run to resume", ) +parser.add_argument( + "--resume_id", + default="", + type=str, + help="Id of run which to resume", +) + parser.add_argument( "--experiment", default="mpii-prod.yaml", @@ -76,6 +83,7 @@ if args.wandb: wandb.init( + id=args.resume_id if args.resume_id else None, project="propose_human36m", entity=os.environ["WANDB_USER"], config=config, diff --git a/scripts/train/human36m.py b/scripts/train/human36m.py index 9ed934c..0617433 100644 --- a/scripts/train/human36m.py +++ b/scripts/train/human36m.py @@ -62,6 +62,7 @@ def human36m(use_wandb: bool = False, config: dict = None): ) if use_wandb and wandb.run.resumed: + wandb.restore("checkpoint.pt", root="/tmp") checkpoint = torch.load("/tmp/checkpoint.pt") flow.load_state_dict(checkpoint["model"]) diff --git a/setup.py b/setup.py index 07421ca..7868aca 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,12 @@ description="Probabilistic Pose Estimation", author="Paweł A. 
Pierzchlewicz", author_email="ppierzc@gmail.com", - packages=find_packages(exclude=[]), + packages=find_packages(), install_requires=[], + include_package_data=True, + package_data={ + "": [ + "*.yaml", + ] + }, ) diff --git a/tests/datasets/rat7m/rat7m_loaders_test.py b/tests/datasets/rat7m/rat7m_loaders_test.py index 9af59db..11d6889 100644 --- a/tests/datasets/rat7m/rat7m_loaders_test.py +++ b/tests/datasets/rat7m/rat7m_loaders_test.py @@ -16,7 +16,7 @@ import numpy as np -path = "/tests/mock/data/mocap-mock.mat" +path = "./tests/mock/data/mocap-mock.mat" def test_rat7m_mocap_loaded(): diff --git a/tests/evaluation/__init__.py b/tests/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/evaluation/mpjpe_test.py b/tests/evaluation/mpjpe_test.py new file mode 100644 index 0000000..d824a4d --- /dev/null +++ b/tests/evaluation/mpjpe_test.py @@ -0,0 +1,182 @@ +from propose.evaluation.mpjpe import mpjpe, pa_mpjpe +from propose.utils.reproducibility import set_random_seed + +import torch +import numpy.testing as npt + +from unittest import TestCase + + +class MPJPETests(TestCase): + def test_mpjpe(self): + error = mpjpe(torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])).item() + + self.assertAlmostEqual(error, 5.196152210235596) + + def test_pa_mpjpe(self): + set_random_seed(1) + p_pred = torch.randn((1, 17, 3)).repeat(200, 1, 1).permute(1, 0, 2) + p_true = torch.randn((1, 17, 3)).permute(1, 0, 2) + + error = pa_mpjpe(p_true, p_pred, dim=0).mean().item() + + self.assertAlmostEqual(error, 1.525496244430542) + + def test_against_wehrbein(self): + """ + Test against Wehrbein et al. implementation. + Their implementation has different input dimensions to our pipeline, so we test whether our adaptation works. + """ + set_random_seed(1) + p_pred = torch.randn((200, 17, 3)) / 0.0036 + p_true = torch.randn((1, 17, 3)) / 0.0036 + + r1 = wehrbein_pampjpe( + p_true.repeat(200, 1, 1), p_pred, return_sum=False, joints=17 + ) + + p_pred = p_pred.permute(1, 0, 2) + p_true = p_true.permute(1, 0, 2) + + r2 = pa_mpjpe(p_true, p_pred, dim=0) + + # Wehrbein et al. implementation has a bug with shape. + # See self.test_wehrbein_is_wrong_proof() + npt.assert_array_equal(r1.shape, r2.shape) + npt.assert_array_equal(r2.shape, torch.Tensor([200]).numpy()) + npt.assert_raises(AssertionError, npt.assert_allclose, r1, r2) + + def test_wehrbein_is_wrong_proof(self): + """ + Test whether the implementation in Wehrbein et al. is wrong. + At some point the procrustes implementation performs: + > torch.view(-1, 3, n_joints) + this is dangerous as the shape might be the same but the order is not. + This tests shows that this is the case. 
+ """ + p_gt = torch.Tensor( + [ + [ + [1, 1, 1], + [2, 2, 2], + [3, 3, 3], + [4, 4, 4], + ], + [ + [5, 5, 5], + [6, 6, 6], + [7, 7, 7], + [8, 8, 8], + ], + ] + ) + + p_gt1 = p_gt.view(-1, 3, p_gt.shape[1]) + p_gt2 = p_gt.permute(0, 2, 1) + + x_range = torch.arange(p_gt.shape[0]) + y_range = torch.arange(p_gt.shape[1]) + z_range = torch.arange(p_gt.shape[2]) + + x_grid, y_grid, z_grid = torch.meshgrid(x_range, y_range, z_range) + + index = torch.stack([x_grid, y_grid, z_grid], -1).view(-1, 3) + + claim_a = [] + claim_b = [] + for i in index: + a = p_gt[i[0], i[1], i[2]] + b = p_gt2[i[0], i[2], i[1]] + c = p_gt1[i[0], i[2], i[1]] + + claim_a.append(a == b) + claim_b.append(a == c) + + self.assertTrue(all(claim_a)) + self.assertFalse(all(claim_b)) + + +# Code for testing the above functions +# Original code from: +# https://github.com/twehrbein/Probabilistic-Monocular-3D-Human-Pose-Estimation-with-Normalizing-Flows/ +# + + +def procrustes_torch_parallel(p_gt, p_pred): + # p_gt and p_pred need to be of shape (-1, 3, #joints) + # care: run on cpu! way faster than on gpu + + mu_gt = p_gt.mean(dim=2) + mu_pred = p_pred.mean(dim=2) + + X0 = p_gt - mu_gt[:, :, None] + Y0 = p_pred - mu_pred[:, :, None] + + ssX = (X0**2.0).sum(dim=(1, 2)) + ssY = (Y0**2.0).sum(dim=(1, 2)) + + # centred Frobenius norm + normX = torch.sqrt(ssX) + normY = torch.sqrt(ssY) + + # scale to equal (unit) norm + X0 /= normX[:, None, None] + Y0 /= normY[:, None, None] + + # optimum rotation matrix of Y + A = torch.bmm(X0, Y0.transpose(1, 2)) + + try: + U, s, V = torch.svd(A, some=True) + except: + print("ERROR IN SVD, could not converge") + print("SVD INPUT IS:") + print(A) + print(A.shape) + exit() + + T = torch.bmm(V, U.transpose(1, 2)) + + # Make sure we have a rotation + detT = torch.det(T) + sign = torch.sign(detT) + V[:, :, -1] *= sign[:, None] + s[:, -1] *= sign + T = torch.bmm(V, U.transpose(1, 2)) + + traceTA = s.sum(dim=1) + + # optimum scaling of Y + b = traceTA * normX / normY + + # standardised distance between X and b*Y*T + c + d = 1 - traceTA**2 + + # transformed coords + scale = normX * traceTA + Z = ( + scale[:, None, None] * torch.bmm(Y0.transpose(1, 2), T) + mu_gt[:, None, :] + ).transpose(1, 2) + + # transformation matrix + c = mu_gt - b[:, None] * (torch.bmm(mu_pred[:, None, :], T)).squeeze() + + # transformation values + tform = {"rotation": T, "scale": b, "translation": c} + return d, Z, tform + + +def wehrbein_pampjpe(p_ref, p, return_sum=True, return_poses=False, joints=17): + p_ref, p = p_ref.view((-1, 3, joints)), p.view((-1, 3, joints)) + d, Z, tform = procrustes_torch_parallel(p_ref.clone(), p) + + if return_sum: + err = torch.sum( + torch.mean(torch.sqrt(torch.sum((p_ref - Z) ** 2, dim=1)), dim=1) + ).item() + else: + err = torch.mean(torch.sqrt(torch.sum((p_ref - Z) ** 2, dim=1)), dim=1) + if not return_poses: + return err + else: + return err, Z diff --git a/tests/evaluation/pck_test.py b/tests/evaluation/pck_test.py new file mode 100644 index 0000000..12d4c15 --- /dev/null +++ b/tests/evaluation/pck_test.py @@ -0,0 +1,105 @@ +import unittest + +import torch +import numpy.testing as npt + +from propose.evaluation.pck import pck +from propose.utils.reproducibility import set_random_seed + + +class PCKTests(unittest.TestCase): + def test_pck_computes_correctly(self): + set_random_seed(1) + p_pred = torch.randn((1, 17, 10, 3)) + p_true = torch.randn((1, 17, 10, 3)) + + error = pck(p_true, p_pred, threshold=4) + + npt.assert_almost_equal( + error.numpy(), + torch.Tensor( + [ + [ + 0.9412, + 0.8824, 
+ 0.9412, + 0.8824, + 0.9412, + 1.0000, + 1.0000, + 1.0000, + 1.0000, + 0.9412, + ] + ] + ).numpy(), + decimal=4, + ) + + def test_one_joint_wrong(self): + a = torch.Tensor( + [ + [ + [ + [1, 1, 1], + ], + [ + [1, 1, 1], + ], + [ + [1, 1, 1], + ], + [ + [1, 1, 1], + ], + ] + ] + ) + + b = torch.Tensor( + [ + [ + [ + [1, 1, 1], + [1, 1, 1], + ], + [ + [1, 10, 1], + [1, 1, 1], + ], + [ + [1, 1, 1], + [1, 1, 1], + ], + [ + [1, 1, 1], + [1, 1, 1], + ], + ] + ] + ) + + n_joints = 4 + n_batches = 1 + n_samples = 2 + n_dim = 3 + + self.assertEqual(a.dim(), 4) + self.assertEqual(b.dim(), 4) + + self.assertEqual(a.shape, (n_batches, n_joints, 1, n_dim)) + self.assertEqual(b.shape, (n_batches, n_joints, n_samples, n_dim)) + + error = pck(a, b, threshold=4) + + self.assertEqual(error.shape, (n_batches, n_samples)) + + npt.assert_almost_equal( + error.numpy(), + torch.Tensor( + [ + [0.7500, 1.0000], + ] + ).numpy(), + decimal=4, + ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29
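Taken together, the new propose.evaluation package replaces the deleted propose.utils.mpjpe module with three metrics: mpjpe, pa_mpjpe, and pck. A minimal sketch of how they compose on dummy tensors (random stand-ins for real flow samples; shapes follow the batch x joints x samples x 3 convention of the tests above, and the / 0.0036 rescaling used in the eval scripts is omitted):

import torch
from propose.evaluation.mpjpe import mpjpe, pa_mpjpe
from propose.evaluation.pck import pck, human36m_joints_to_use

true_pose = torch.randn(1, 17, 1, 3)       # one ground-truth pose
sample_poses = torch.randn(1, 17, 200, 3)  # 200 pose hypotheses

# MPJPE over the joint dimension gives one error per hypothesis; the eval
# scripts report both the best-of-200 minimum and the first ("single") sample.
m = mpjpe(true_pose, sample_poses, dim=1)   # shape (1, 200)
best = m.min(dim=-1).values

# pa_mpjpe expects (joints, samples, 3) inputs and Procrustes-aligns the
# prediction to the ground truth before computing the per-hypothesis error.
pa_m = pa_mpjpe(true_pose[0], sample_poses[0], dim=0)  # shape (200,)

# pck scores the fraction of joints within the threshold (150 by default),
# restricted here to the 14 joints in human36m_joints_to_use.
scores = pck(
    true_pose[:, human36m_joints_to_use],
    sample_poses[:, human36m_joints_to_use],
)  # shape (1, 200)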