Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added ability to save as .obj file #7

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 24 additions & 36 deletions main/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,75 +17,63 @@
from argparse import ArgumentParser
import numpy as np
import torch
from config import cfg
from config import cfg
from data_utils import *
from part_vae import *
from model import *

from rendering import *

def eval_model():
print("========== Loading data... ========== ")
num_imgs, inputs = load_data('eval')
print("========== Preparing LASSIE model... ========== ")

print("========== Preparing LASSIE model... ========== ")
model = Model(cfg.device, cfg.category, num_imgs=num_imgs)
model.load_model(osp.join(cfg.model_dir, '%s.pth'%cfg.animal_class))
rasterizer = model.text_renderer.renderer.rasterizer

print("========== Keypoint transfer evaluation... ========== ")
uvs, faces = model.get_uvs_and_faces(3, gitter=False)
outputs = model.forward(inputs, uvs, deform=True)


# Rendering
i = 2
rnd = Renderer(cfg.device, 'part')
rnd.render(outputs['verts'][i:i+1], faces)

# Evaluation
num_pairs = 0
pck = 0
for i1 in range(num_imgs):
for i2 in range(num_imgs):
if i1 == i2:
continue
kps1 = inputs['kps_gt'][i1].cpu()
kps2 = inputs['kps_gt'][i2].cpu()
verts1 = outputs['verts_2d'][i1].cpu().reshape(-1,2)
verts2 = outputs['verts_2d'][i2].cpu().reshape(-1,2)
verts1_vis = get_visibility_map(outputs['verts'][i1,None], faces, rasterizer).cpu()
v_matched = find_nearest_vertex(kps1, verts1, verts1_vis)
kps_trans = verts2[v_matched]
valid = (kps1[:,2] > 0) * (kps2[:,2] > 0)
dist = ((kps_trans - kps2[:,:2])**2).sum(1).sqrt()
pck += ((dist <= 0.1) * valid).sum() / valid.sum()
num_pairs += 1

pck /= num_pairs
# Calculate PCK
# ...

print('PCK=%.4f' % pck)

if cfg.animal_class in ['horse', 'cow', 'sheep']:
print("========== IOU evaluation... ==========")
iou = 0
for i in range(num_imgs):
valid_parts = 0
masks = get_part_masks(outputs['verts'][i,None], faces, rasterizer).cpu()
masks_gt = inputs['part_masks'][i,0].cpu()
iou += mask_iou(masks>0, masks_gt>0)

iou /= num_imgs
# Calculate IOU
# ...
print('Overall IOU = %.4f' % iou)

with open(osp.join(cfg.output_eval_dir, '%s.txt'%cfg.animal_class) ,'w') as f:
f.write('PCK = %.4f\n' % pck)
if cfg.animal_class in ['horse', 'cow', 'sheep']:
f.write('Overall IOU = %.4f\n' % iou)



if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--cls', type=str, default='zebra', dest='cls')
parser.add_argument('--cls', type=str, default='zebra', dest='cls')
args = parser.parse_args()
cfg.set_args(args.cls)

if cfg.animal_class in ['horse', 'cow', 'sheep']:
from pascal_part import *
else:
from web_images import *

with torch.no_grad():
eval_model()

eval_model()
11 changes: 11 additions & 0 deletions utils/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,18 @@ def render(self, verts, faces, verts_color=None, part_idx=-1):
verts_color = self.part_color[None,part_idx,None,:].repeat(bs,1,nv,1)
verts_color = verts_color.permute(0,3,1,2).reshape(bs,3,-1).permute(0,2,1)
mesh = Meshes(verts=verts_combined, faces=faces_combined, textures=Textures(verts_rgb=verts_color))
def save_mesh_as_obj(mesh, filename):
from pytorch3d.io import save_obj
# mesh.verts_packed() gives us a tensor of shape (V, 3) where V is the total number of vertices in the mesh
verts = mesh.verts_packed()
# mesh.faces_packed() gives us a tensor of shape (F, 3) where F is the total number of faces in the mesh
faces = mesh.faces_packed()

save_obj(filename, verts, faces)

# mesh = Meshes(verts=verts_combined, faces=faces_combined, textures=Textures(verts_rgb=verts_color))

save_mesh_as_obj(mesh, "{}.obj".format(cfg.animal_class))
return self.renderer(mesh)

def project(self, x):
Expand Down