From 0bd9565cc283d10d085e101ed7a20dcd77b2cdcc Mon Sep 17 00:00:00 2001 From: myhloli Date: Thu, 26 Dec 2024 11:18:45 +0800 Subject: [PATCH] refactor(datasets): remove unused video processing code and update dependencies - Remove load_video function and related imports - Update unimernet version to 0.2.3 in pyproject.toml - Remove eva-decord dependency --- pyproject.toml | 3 +-- unimernet/datasets/data_utils.py | 26 -------------------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aaae535..8b5e322 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "unimernet" -version = "0.2.2" +version = "0.2.3" description = 'UniMERNet: A Universal Network for Real-World Mathematical Expression Recognition' authors = ["Bin Wang "] readme = "README.md" @@ -30,7 +30,6 @@ fairscale = "^0.4.13" ftfy = {version = "^6.2.0", python = ">=3.10,<4.0"} albumentations = "^1.4.4" wand = "^0.6.13" -eva-decord = "^0.6.1" webdataset = "^0.2.86" rapidfuzz = "^3.8.1" termcolor = "^2.4.0" diff --git a/unimernet/datasets/data_utils.py b/unimernet/datasets/data_utils.py index 4e82c90..7a22b46 100644 --- a/unimernet/datasets/data_utils.py +++ b/unimernet/datasets/data_utils.py @@ -12,43 +12,17 @@ import tarfile import zipfile -import decord import webdataset as wds import numpy as np import torch from torch.utils.data.dataset import IterableDataset, ChainDataset -from decord import VideoReader from unimernet.common.registry import registry from unimernet.datasets.datasets.base_dataset import ConcatDataset from tqdm import tqdm -decord.bridge.set_bridge("torch") MAX_INT = registry.get("MAX_INT") -def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"): - vr = VideoReader(uri=video_path, height=height, width=width) - - vlen = len(vr) - start, end = 0, vlen - - n_frms = min(n_frms, vlen) - - if sampling == "uniform": - indices = np.arange(start, end, vlen / n_frms).astype(int) - elif sampling == "headtail": - indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2)) - indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2)) - indices = indices_h + indices_t - else: - raise NotImplementedError - - # get_batch -> T, H, W, C - frms = vr.get_batch(indices).permute(3, 0, 1, 2).float() # (C, T, H, W) - - return frms - - def apply_to_sample(f, sample): if len(sample) == 0: return {}