diff --git a/opensora/models/causalvideovae/dataset/video_dataset.py b/opensora/models/causalvideovae/dataset/video_dataset.py
index a375d3a0..bdc406bf 100644
--- a/opensora/models/causalvideovae/dataset/video_dataset.py
+++ b/opensora/models/causalvideovae/dataset/video_dataset.py
@@ -246,5 +246,5 @@ def _load_video(self, video_path, sample_rate=None):
         frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
         video_data = decord_vr.get_batch(frame_id_list).asnumpy()
         video_data = torch.from_numpy(video_data)
-        video_data = video_data.permute(3, 0, 1, 2)
+        video_data = video_data.permute(0, 3, 1, 2)
         return video_data
diff --git a/opensora/models/causalvideovae/eval/cal_fvd.py b/opensora/models/causalvideovae/eval/cal_fvd.py
index 1f1a9806..f7c55239 100755
--- a/opensora/models/causalvideovae/eval/cal_fvd.py
+++ b/opensora/models/causalvideovae/eval/cal_fvd.py
@@ -13,7 +13,8 @@ def trans(x):
     return x
 
 def calculate_fvd(videos1, videos2, device, method='styleganv'):
-
+    videos1 = videos1.to(device)
+    videos2 = videos2.to(device)
     if method == 'styleganv':
         from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
     elif method == 'videogpt':
diff --git a/opensora/models/causalvideovae/eval/eval.py b/opensora/models/causalvideovae/eval/eval.py
index ec773200..46aa194f 100755
--- a/opensora/models/causalvideovae/eval/eval.py
+++ b/opensora/models/causalvideovae/eval/eval.py
@@ -84,25 +84,18 @@ def __getitem__(self, index):
 def calculate_common_metric(args, dataloader, device):
 
     score_list = []
-    for batch_data in tqdm(dataloader):
-        real_videos = batch_data["real"].to(device)
-        generated_videos = batch_data["generated"].to(device)
-
+    for batch_data in tqdm(dataloader): # {'real': real_video_tensor, 'generated':generated_video_tensor }
+        real_videos = batch_data['real']
+        generated_videos = batch_data['generated']
         assert real_videos.shape[2] == generated_videos.shape[2]
-        if args.metric == "fvd":
-            tmp_list = list(
-                calculate_fvd(
-                    real_videos, generated_videos, args.device, method=args.fvd_method
-                )["value"].values()
-            )
-        elif args.metric == "ssim":
-            tmp_list = list(
-                calculate_ssim(real_videos, generated_videos)["value"].values()
-            )
-        elif args.metric == "psnr":
-            tmp_list = [calculate_psnr(real_videos, generated_videos)]
+        if args.metric == 'fvd':
+            tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
+        elif args.metric == 'ssim':
+            tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
+        elif args.metric == 'psnr':
+            tmp_list = list(calculate_psnr(real_videos, generated_videos)['value'].values())
         else:
-            tmp_list = [calculate_lpips(real_videos, generated_videos, args.device)]
+            tmp_list = list(calculate_lpips(real_videos, generated_videos, args.device)['value'].values())
         score_list += tmp_list
 
     return np.mean(score_list)
diff --git a/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py b/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py
index 3043a2a4..68ef36fa 100755
--- a/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py
+++ b/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py
@@ -14,7 +14,7 @@ def load_i3d_pretrained(device=torch.device('cpu')):
         print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
         os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
     i3d = torch.jit.load(filepath).eval().to(device)
-    i3d = torch.nn.DataParallel(i3d)
+    #i3d = torch.nn.DataParallel(i3d)
     return i3d
 
 
@@ -87,4 +87,4 @@ def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
         fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
     else:
         fid = np.real(m)
-    return float(fid)
\ No newline at end of file
+    return float(fid)