Skip to content

Commit

Permalink
Update main_train_avatarposer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaxi-jiang authored Feb 17, 2023
1 parent 5d1d606 commit db45b29
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions main_train_avatarposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def main(json_path='options/train_avatarposer.json'):
if current_step % opt['train']['checkpoint_test'] == 0:


rot_error = []
pos_error = []
vel_error = []
pos_error_hands = []
Expand Down Expand Up @@ -229,30 +228,27 @@ def main(json_path='options/train_avatarposer.json'):
gt_angle = gt_angle.reshape(body_parms_gt['pose_body'].shape[0],-1,3)


rot_error_ = torch.mean(torch.absolute(gt_angle-predicted_angle))
pos_error_ = torch.mean(torch.sqrt(torch.sum(torch.square(gt_position-predicted_position),axis=-1)))
pos_error_hands_ = torch.mean(torch.sqrt(torch.sum(torch.square(gt_position-predicted_position),axis=-1))[...,[20,21]])

gt_velocity = (gt_position[1:,...] - gt_position[:-1,...])*60
predicted_velocity = (predicted_position[1:,...] - predicted_position[:-1,...])*60
vel_error_ = torch.mean(torch.sqrt(torch.sum(torch.square(gt_velocity-predicted_velocity),axis=-1)))

rot_error.append(rot_error_)
pos_error.append(pos_error_)
vel_error.append(vel_error_)

pos_error_hands.append(pos_error_hands_)



rot_error = sum(rot_error)/len(rot_error)
pos_error = sum(pos_error)/len(pos_error)
vel_error = sum(vel_error)/len(vel_error)
pos_error_hands = sum(pos_error_hands)/len(pos_error_hands)


# testing log
logger.info('<epoch:{:3d}, iter:{:8,d}, Average rotational error [degree]: {:<.5f}, Average positional error [cm]: {:<.5f}, Average velocity error [cm/s]: {:<.5f}, Average positional error at hand [cm]: {:<.5f}\n'.format(epoch, current_step,rot_error*57.2958, pos_error*100, vel_error*100, pos_error_hands*100))
logger.info('<epoch:{:3d}, iter:{:8,d}, Average positional error [cm]: {:<.5f}, Average velocity error [cm/s]: {:<.5f}, Average positional error at hand [cm]: {:<.5f}\n'.format(epoch, current_step,pos_error*100, vel_error*100, pos_error_hands*100))


logger.info('Saving the final model.')
Expand Down

0 comments on commit db45b29

Please sign in to comment.