diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..ab672ed --- /dev/null +++ b/.dockerignore @@ -0,0 +1,152 @@ +log/ +output/ +.vscode/ +workspace/ +run*.sh + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +output/ +example_videos/ +torchlogs/ +*.png +*.csv +*.txt +*junk* +*.profile + +!requirements*.txt + +.DS_Store +*.jpg +*.zip +*.sh +Dockerfile \ No newline at end of file diff --git a/.gitignore b/.gitignore index bee1e68..096f495 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,16 @@ dmypy.json # Pyre type checker .pyre/ + +output/ +example_videos/ +torchlogs/ +*.png +*.csv +*.txt +*junk* +*.profile + +.DS_Store +*.jpg +*.zip \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..ed43ff8 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "XMem_utilities"] + path = XMem_utilities + url = git@github.com:max810/Xmem_utility_scripts.git diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..f34cfa3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +# Use the specified PyTorch base image with CUDA support +FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime AS xmem2-base-inference + +# Set the working directory in the container +WORKDIR /app + +RUN python -m pip install --no-cache-dir opencv-python-headless Pillow==9.2.0 + +# Install Python dependencies from requirements.txt +COPY requirements.txt /app/requirements.txt +RUN python -m pip install --no-cache-dir -r requirements.txt + +# Copy the application files into the container +COPY . 
/app + +# FOR GUI - only a few extra dependencies +FROM xmem2-base-inference AS xmem2-gui + +# Qt dependencies +RUN apt-get update && apt-get install -y build-essential libgl1 libglib2.0-0 libxkbcommon-x11-0 '^libxcb.*-dev' libx11-xcb-dev libglu1-mesa-dev libxrender-dev libxi-dev libxkbcommon-dev libxkbcommon-x11-dev libfontconfig libdbus-1-3 mesa-utils libgl1-mesa-glx +RUN /bin/bash -c 'gcc --version' + +RUN python -m pip install --no-cache-dir -r requirements_demo.txt +# To avoid error messages when launching PyQT +ENV LIBGL_ALWAYS_INDIRECT=1 \ No newline at end of file diff --git a/LICENSE_PUMaVOS b/LICENSE_PUMaVOS new file mode 100644 index 0000000..da6ab6c --- /dev/null +++ b/LICENSE_PUMaVOS @@ -0,0 +1,396 @@ +Attribution 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution 4.0 International Public License ("Public License"). To the +extent this Public License may be interpreted as a contract, You are +granted the Licensed Rights in consideration of Your acceptance of +these terms and conditions, and the Licensor grants You such rights in +consideration of benefits the Licensor receives from making the +Licensed Material available under these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + j. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part; and + + b. produce, reproduce, and Share Adapted Material. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. 
In all other cases the Licensor expressly + reserves any right to collect such royalties. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + +======================================================================= + +Creative Commons is not a party to its public +licenses. 
Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. + diff --git a/README.md b/README.md index ee6f064..2724bbf 100644 --- a/README.md +++ b/README.md @@ -1,172 +1,316 @@ -# XMem +# XMem++ -## Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model +## Production-level Video Segmentation From Few Annotated Frames -[Ho Kei Cheng](https://hkchengrex.github.io/), [Alexander Schwing](https://www.alexander-schwing.de/) +[Maksym Bekuzarov](https://www.linkedin.com/in/maksym-bekuzarov-947490165/)$^\dagger$, [Ariana Michelle Bermudez Venegas](https://www.linkedin.com/in/ariana-bermudez/)$^\dagger$, [Joon-Young Lee](https://joonyoung-cv.github.io/), [Hao Li](https://www.hao-li.com/Hao_Li/Hao_Li_-_about_me.html) -University of Illinois Urbana-Champaign +[Metaverse Lab TODO LINK]() @ [MBZUAI](https://mbzuai.ac.ae/) (Mohamed bin Zayed University of Artificial Intelligence) -[[arXiv]](https://arxiv.org/abs/2207.07115) [[PDF]](https://arxiv.org/pdf/2207.07115.pdf) [[Project Page]](https://hkchengrex.github.io/XMem/) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RXK5QsUo2-CnOiy5AOSjoZggPVHOPh1m?usp=sharing) +[[arXiv]](https://arxiv.org/abs/2307.15958) [[PDF]](https://arxiv.org/pdf/2307.15958.pdf) [[Project Page]](https://xmem2.surge.sh) +$^\dagger$ These authors equally contributed to the work. 
+
+## Table of contents
+* [Performance demo](#demo)
+* [Overview](#overview)
+* [Getting started](#getting-started)
+* [Use the GUI](#use-the-gui)
+* [Use **XMem++** command-line and Python interface](#use-xmem-command-line-and-python-interface)
+* [Importing existing projects](#importing-existing-projects)
+* [Docker support](#docker-support)
+* [Data format](#data-format)
+* [Training](#training)
+* [Methodology](#methodology)
+* [Frame annotation candidate selector](#frame-annotation-candidate-selector)
+* [PUMaVOS Dataset](#pumavos-dataset)
+* [Citation](#citation)

## Demo

-Handling long-term occlusion:
+Inspired by movie industry use cases, **XMem++** is an Interactive Video Segmentation Tool that takes a few user-provided segmentation masks and segments very challenging use cases with minimal human supervision, such as
+
+- **parts** of objects (only 6 annotated frames provided):

-https://user-images.githubusercontent.com/7107196/177921527-7a1bd593-2162-4598-9adf-f2112763fccf.mp4
+https://github.com/max810/XMem2/assets/29955120/d700ccc4-194e-46d8-97b2-05b0587496f4

-Very-long video; masked layer insertion:
+- **fluid** objects like hair (only 5 annotated frames provided):

-https://user-images.githubusercontent.com/7107196/179089789-3d69adea-0405-4c83-ac28-45f59fe1e1c1.mp4
+https://github.com/max810/XMem2/assets/29955120/06d4d8ee-3092-4fe6-a0c2-da7e3fc4a01c

-Source: https://www.youtube.com/watch?v=q5Xr0F4a0iU
+- **deformable** objects like clothes (5 and 11 annotated frames used, respectively):

-Out-of-domain case:
+https://github.com/max810/XMem2/assets/29955120/a8e75648-b8cf-4312-8077-276597256289

-https://user-images.githubusercontent.com/7107196/177920383-161f1da1-33f9-48b3-b8b2-09e450432e2b.mp4
+https://github.com/max810/XMem2/assets/29955120/63e6704c-3292-4690-970e-818ab2950c56

-Source: かぐや様は告らせたい ~天才たちの恋愛頭脳戦~ Ep.3; A-1 Pictures
+### [[LIMITATIONS]](docs/LIMITATIONS.md)

-### [[Failure Cases]](docs/FAILURE_CASES.md)
-## Features
+## Overview

+| ![Demo GUI](docs/resources/gui_demo.jpg) |
+|:--:|
+| _XMem++ updated GUI_ |

-* Handle very long videos with limited GPU memory usage.
-* Quite fast. Expect ~20 FPS even with long videos (hardware dependent).
+**XMem++** is built on top of [XMem](https://github.com/hkchengrex/XMem) by [Ho Kei Cheng](https://hkchengrex.github.io/) and [Alexander Schwing](https://www.alexander-schwing.de/) and improves upon it by adding the following:
+1. A [permanent memory module](#methodology) that greatly improves the model's accuracy with just a few manual annotations provided (see results).
+2. An [annotation candidate selection algorithm](#frame-annotation-candidate-selector) that selects the next $k$ best frames for the user to annotate.
+3. We used **XMem++** to collect and annotate **PUMaVOS** - a 23-video dataset with unusual and challenging annotation scenarios at 480p, 30 FPS. See [Dataset](#pumavos-dataset).
+
+In addition, it comes with the following features:
+* Improved GUI - a References tab to see/edit which frames are stored in the permanent memory, a Candidates tab that shows the frames suggested for annotation by the algorithm, and more.
+* Negligible speed and memory usage overhead compared to XMem (when using only a few manually provided annotations).
+* An [easy-to-use Python interface](docs/PYTHON_API.md) - you can use **XMem++** both as a GUI application and as a Python library.
+* 30+ FPS on 480p footage on an RTX 3090.
 * Comes with a GUI (modified from [MiVOS](https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN)).
-### Table of Contents
+## Getting started
+### Environment setup
+First, install the required Python packages:

-1. [Introduction](#introduction)
-2. [Results](docs/RESULTS.md)
-3. [Interactive GUI demo](docs/DEMO.md)
-4. [Training/inference](#traininginference)
-5. [Citation](#citation)
+* Python 3.8+
+* PyTorch 1.11+ (See [PyTorch](https://pytorch.org/) for installation instructions)
+* `torchvision` corresponding to the PyTorch version
+* OpenCV (try `pip install opencv-python`)
+* Others: `pip install -r requirements.txt`
+* To use the GUI: `pip install -r requirements_demo.txt`

-### Introduction
+### Download weights

-![framework](https://imgur.com/ToE2frx.jpg)
+Download the pretrained models either using `./scripts/download_models.sh`, or manually, and put them in `./saves` (create the folder if it doesn't exist). You can download them from [[XMem GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[XMem Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing). For inference you only need `XMem.pth`; for the GUI, also download `fbrs.pth` and `s2m.pth`.

-We frame Video Object Segmentation (VOS), first and foremost, as a *memory* problem.
-Prior works mostly use a single type of feature memory. This can be in the form of network weights (i.e., online learning), last frame segmentation (e.g., MaskTrack), spatial hidden representation (e.g., Conv-RNN-based methods), spatial-attentional features (e.g., STM, STCN, AOT), or some sort of long-term compact features (e.g., AFB-URR).
+## Use the GUI
+To run the GUI on a new video:
+```Bash
+python interactive_demo.py --video example_videos/chair/chair.mp4
+```

-Methods with a short memory span are not robust to changes, while those with a large memory bank are subject to a catastrophic increase in computation and GPU memory usage. Attempts at long-term attentional VOS like AFB-URR compress features eagerly as soon as they are generated, leading to a loss of feature resolution.
+To run on a list of images:
+```Bash
+python interactive_demo.py --images example_videos/chair/JPEGImages
+```

-Our method is inspired by the Atkinson-Shiffrin human memory model, which has a *sensory memory*, a *working memory*, and a *long-term memory*. These memory stores have different temporal scales and complement each other in our memory reading mechanism. It performs well in both short-term and long-term video datasets, handling videos with more than 10,000 frames with ease.
+Both of these commands will create a folder for the current video in the workspace folder (default is `.workspace`) and save all the masks and predictions there.

-### Training/inference
-First, install the required python packages and datasets following [GETTING_STARTED.md](docs/GETTING_STARTED.md).
+To keep editing an existing project in a workspace, run the following command:
+```Bash
+python interactive_demo.py --workspace ./workspace/<project_name>
+```

-For training, see [TRAINING.md](docs/TRAINING.md).
+If you have more than 1 object, make sure to add `--num-objects <num_objects>` to the commands above the **first time you create a project**. It will be saved in the project file after that for your convenience =)

-For inference, see [INFERENCE.md](docs/INFERENCE.md).
+Like this:
+```Bash
+python interactive_demo.py --images example_videos/caps/JPEGImages --num-objects 2
+```

-### Citation
+For more information, visit [DEMO.md](docs/DEMO.md).

-Please cite our paper if you find this repo useful!
+
+## Use **XMem++** command-line and Python interface
+We provide a simple command-line interface in [process_video.py](process_video.py), which you can use like this:

-```bibtex
-@inproceedings{cheng2022xmem,
-  title={{XMem}: Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model},
-  author={Cheng, Ho Kei and Alexander G. Schwing},
-  booktitle={ECCV},
-  year={2022}
-}
+```Bash
+python process_video.py \
+  --video <path to video> \
+  --masks <path to directory with existing .png masks> \
+  --output <path to output directory>
+```
+The script takes an existing video and ground-truth masks (all masks in the given directory are used) and runs segmentation once.

-Related projects that this paper is developed upon:
-```bibtex
-@inproceedings{cheng2021stcn,
-  title={Rethinking Space-Time Networks with Improved Memory Coverage for Efficient Video Object Segmentation},
-  author={Cheng, Ho Kei and Tai, Yu-Wing and Tang, Chi-Keung},
-  booktitle={NeurIPS},
-  year={2021}
-}
+Short-form arguments `-v -m -o` are also supported.

-@inproceedings{cheng2021mivos,
-  title={Modular Interactive Video Object Segmentation: Interaction-to-Mask, Propagation and Difference-Aware Fusion},
-  author={Cheng, Ho Kei and Tai, Yu-Wing and Tang, Chi-Keung},
-  booktitle={CVPR},
-  year={2021}
-}
-```
+See [Python API](docs/PYTHON_API.md) or [main.py](main.py) for more complex use cases and explanations.
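If you just need to call this from a script, a minimal option is to drive the same command-line interface via `subprocess` (a sketch with placeholder paths; the dedicated Python API in [docs/PYTHON_API.md](docs/PYTHON_API.md) is the more flexible route):

```Python
import subprocess

# Placeholder paths -- replace with your own video/masks/output locations.
subprocess.run(
    [
        "python", "process_video.py",
        "--video", "example_videos/caps/JPEGImages",   # the video (or frames) to segment
        "--masks", "example_videos/caps/Annotations",  # directory with ground-truth .png masks
        "--output", "output/caps",                     # where the predicted masks are written
    ],
    check=True,
)
```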
-We use f-BRS in the interactive demo: https://github.com/saic-vul/fbrs_interactive_segmentation

-And if you want to cite the datasets:

+## Importing existing projects
+
+If you already have existing frames and/or masks from other tools, you can import them into the workspace with the following command:
+
+```Bash
+python import_existing.py --name <project_name> [--images <path to directory with frames>] [--masks <path to directory with masks>]
+```

-bibtex

+_One of `--images`, `--masks` (or both) should be specified._

-```bibtex
-@inproceedings{shi2015hierarchicalECSSD,
-  title={Hierarchical image saliency detection on extended CSSD},
-  author={Shi, Jianping and Yan, Qiong and Xu, Li and Jia, Jiaya},
-  booktitle={TPAMI},
-  year={2015},
-}
+_You can also specify `--size <size>` to resize the frames on the fly (the smaller side is scaled to this value, preserving the aspect ratio)._

-@inproceedings{wang2017DUTS,
-  title={Learning to Detect Salient Objects with Image-level Supervision},
-  author={Wang, Lijun and Lu, Huchuan and Wang, Yifan and Feng, Mengyang and Wang, Dong and Yin, Baocai and Ruan, Xiang},
-  booktitle={CVPR},
-  year={2017}
-}
+This will do the following:
+1. Create a project directory inside your workspace with the name given by the `--name` argument.
+2. Copy your given images/masks inside.
+3. Convert RGB masks to the necessary color palette (XMem++ uses the [DAVIS color palette](util/palette.py), where each new object = a new color).
+4. Resize the frames if the `--size` argument is specified.

-@inproceedings{FSS1000,
-  title = {FSS-1000: A 1000-Class Dataset for Few-Shot Segmentation},
-  author = {Li, Xiang and Wei, Tianhan and Chen, Yau Pun and Tai, Yu-Wing and Tang, Chi-Keung},
-  booktitle={CVPR},
-  year={2020}
-}
+## Docker support
+We provide 2 images on [DockerHub](https://hub.docker.com/repository/docker/max810/xmem2/general):
+- `max810/xmem2:base-inference` - smaller and lighter - for running inference from the command line as in the [Command line section](#use-xmem-command-line-and-python-interface).
+- `max810/xmem2:gui` - for running the graphical interface interactively.

-@inproceedings{zeng2019towardsHRSOD,
-  title = {Towards High-Resolution Salient Object Detection},
-  author = {Zeng, Yi and Zhang, Pingping and Zhang, Jianming and Lin, Zhe and Lu, Huchuan},
-  booktitle = {ICCV},
-  year = {2019}
-}
+To use them, just run `./run_inference_in_docker.sh` or `./run_gui_in_docker.sh` with the corresponding cmd/GUI arguments (see the respective sections: [[Inference]](#use-xmem-command-line-and-python-interface) [[GUI]](#use-the-gui)). _They supply the proper arguments to the `docker run` command and create the corresponding volumes for the input/output directories automatically._

-@inproceedings{cheng2020cascadepsp,
-  title={{CascadePSP}: Toward Class-Agnostic and Very High-Resolution Segmentation via Global and Local Refinement},
-  author={Cheng, Ho Kei and Chung, Jihoon and Tai, Yu-Wing and Tang, Chi-Keung},
-  booktitle={CVPR},
-  year={2020}
-}
+Examples:
+```Bash
+# Inference
+./run_inference_in_docker.sh -v example_videos/caps/JPEGImages -m example_videos/caps/Annotations -o directory/that/does/not/exist/yet
+
+# Interactive GUI
+./run_gui_in_docker.sh --video example_videos/chair/chair.mp4 --num_objects 2
+```
+For the GUI you can change the variables `$LOCAL_WORKSPACE_DIR` and `$DISPLAY_TO_USE` in [run_gui_in_docker.sh](run_gui_in_docker.sh#L54) if necessary.
-@inproceedings{xu2018youtubeVOS, - title={Youtube-vos: A large-scale video object segmentation benchmark}, - author={Xu, Ning and Yang, Linjie and Fan, Yuchen and Yue, Dingcheng and Liang, Yuchen and Yang, Jianchao and Huang, Thomas}, - booktitle = {ECCV}, - year={2018} -} +_Be wary that the interactive import buttons will not work (they will open paths within the container filesystem, not the host one)._ +### Building your own images +For command-line inference: -@inproceedings{perazzi2016benchmark, - title={A benchmark dataset and evaluation methodology for video object segmentation}, - author={Perazzi, Federico and Pont-Tuset, Jordi and McWilliams, Brian and Van Gool, Luc and Gross, Markus and Sorkine-Hornung, Alexander}, - booktitle={CVPR}, - year={2016} -} +```Bash +docker build . -t --target xmem2-base-inference +``` -@inproceedings{denninger2019blenderproc, - title={BlenderProc}, - author={Denninger, Maximilian and Sundermeyer, Martin and Winkelbauer, Dominik and Zidan, Youssef and Olefir, Dmitry and Elbadrawy, Mohamad and Lodhi, Ahsan and Katam, Harinandan}, - booktitle={arXiv:1911.01911}, - year={2019} -} +For GUI: +```Bash +docker build . -t --target xmem2-gui +``` +## Data format +- Images are expected to use .jpg format. +- Masks are RGB .png files that use the [DAVIS color palette](util/palette.py), saved as a palette image (`Image.convert('P')` in [Pillow Image Module](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.convert))). If your masks don't follow this color palette, just use run `python import_existing.py` to automatically convert them (see [Importing existing projects](#importing-existing-projects)). +- When using `run_on_video.py` with a video_file, masks should be named `frame_%06d.` starting at `0`: `frame_000000.jpg, frame_0000001.jpg, ...` **This is preferred filename for any use case**. + +More information and convenience commands are provided in [Data format help](docs/DATA_FORMAT_HELP.md) + +## Training +For training, refer to the [original XMem repo](https://github.com/hkchengrex/XMem/blob/main/docs/TRAINING.md). + +_We use the original weights provided by XMem, the model has not been retrained or fine-tuned in any way._ + +_Feel free to fine-tune XMem and replace the weights in this project._ + +## Methodology + +| ![XMem++ architecture overview with comments](docs/resources/architecture_explanations.jpg) | +|:--:| +| *XMem++ architecture overview with comments* | + +XMem++ is a **memory-based** interactive segmentation model - this means it uses a set of reference frames/feature maps and their corresponding masks, either predicted or given as ground truth if available, to predict masks for new frames based on **how similar they are to already processed frames** with known segmentation. + +Just like XMem, we use the two types of memory inspired by the Atkinson-Shiffrin human memory model - working memory and long-term memory. The first one stores recent convolutional feature maps with rich details, and the other - heavily compressed features for long-term dependencies across frames that are far apart in the video. + +However, existing segmentation methods ([XMem](https://arxiv.org/abs/2207.07115), [TBD](https://arxiv.org/abs/2207.06953), [AoT](https://arxiv.org/abs/2106.02638), [DeAOT](https://arxiv.org/abs/2210.09782), [STCN](https://arxiv.org/abs/2106.05210), etc.) 
that use memory mechanisms to predict the segmentation mask for the current frame typically process frames one by one, and thus suffer from a common issue: "jumps" in visual quality whenever a new ground-truth annotation is encountered in the video.
+
+| ![Why permanent memory helps](docs/resources/why_permanent_memory_helps.jpg) |
+|:--:|
+| *Why permanent memory helps - having multiple annotations from different parts of the video in permanent memory allows the model to __smoothly interpolate__ between different scenes/appearances of the target object* |
+
+To solve this, we propose a new **permanent memory module** - identical in implementation to XMem's working memory. We take all the annotations the user provides, process them, and put them in the permanent memory module. This way, **every** ground-truth annotation provided by the user can influence **any** frame in the video, regardless of where it is located. This increases overall segmentation accuracy and allows the model to smoothly interpolate between different appearances of the object (see the figure above).
+
+For more details, refer to our arXiv page [[Arxiv]](https://arxiv.org/abs/2307.15958v1) [[PDF]](https://arxiv.org/pdf/2307.15958v1), Section 3.2.
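To make the idea concrete, here is a toy sketch of the two memory banks and a readout over both of them. This is an illustration of the concept only, not XMem++'s actual code: the real model stores convolutional key/value feature maps and uses XMem's memory-reading mechanism and decoder, while here each frame is reduced to a single feature vector and the readout mixes masks directly.

```Python
# Toy illustration of the permanent-memory idea (NOT the actual XMem++ implementation).
import numpy as np

permanent_memory = []  # (features, mask) pairs from user-verified annotations; never evicted
working_memory = []    # (features, mask) pairs from recently processed frames

def add_reference(features, mask):
    """Every annotation the user provides goes into the permanent memory."""
    permanent_memory.append((features, mask))

def read_memory(query_features, temperature=1.0):
    """Soft-attention readout over BOTH banks, so any user annotation can
    influence any frame, no matter where it sits in the video.
    Assumes at least one reference has been stored."""
    refs = permanent_memory + working_memory
    keys = np.stack([f for f, _ in refs])         # (N, D)
    values = np.stack([m for _, m in refs])       # (N, H, W)
    scores = keys @ query_features / temperature  # (N,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Weighted combination of the reference masks; the real model aggregates
    # value *features* and decodes them into a mask instead.
    return np.tensordot(weights, values, axes=1)
```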
+## Frame annotation candidate selector
+
+We propose a simple algorithm to select the frames the user should annotate next, to maximize performance and save time. It is based on the idea of **diversity** - select the frames that capture the most varied appearances of the target object - to **maximize the information** the network gets from annotating them.
+
+It has the following properties:
+- **Target-specific**: The choice of frames depends on which object you are trying to segment.
+- **Model-generic**: It operates on convolutional feature maps and a pixel-similarity metric (negative $\mathcal{L}_{2}$ distance), so it is not specific to **XMem++**.
+- **No restrictions on segmentation targets**: Some methods try to automatically estimate the visual quality of the segmentation, which relies on an implicit assumption **that a good-quality segmentation follows low-level image cues (edges, corners, etc.)**. However, this does not hold when segmenting parts of objects.
+- **Deterministic and simple**: It orders the remaining frames by a **diversity measure**, and the user just picks the top $k$ most diverse candidates.
+
+For more details, refer to our arXiv page [[Arxiv]](https://arxiv.org/abs/2307.15958v1) [[PDF]](https://arxiv.org/pdf/2307.15958v1), Section 3.3 and Appendix D.
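To illustrate the diversity idea (and only as an illustration: the actual XMem++ algorithm from Section 3.3 works on convolutional feature maps with the negative $\mathcal{L}_{2}$ pixel similarity), here is a greedy farthest-point selection over per-frame feature vectors:

```Python
# Illustrative sketch of diversity-driven candidate selection, NOT the exact
# XMem++ algorithm: greedily pick the frames that look least similar to the
# references you already have.
import numpy as np

def select_candidates(frame_features, annotated_ids, k):
    """frame_features: (N, D) array, one feature vector per frame.
    annotated_ids: indices of frames that already have user annotations (at least one).
    Returns k frame indices to annotate next, most diverse first."""
    chosen = list(annotated_ids)
    candidates = []
    for _ in range(k):
        # Distance of every frame to its closest already-chosen frame.
        dists = np.min(
            np.linalg.norm(
                frame_features[:, None, :] - frame_features[chosen][None, :, :], axis=-1
            ),
            axis=1,
        )
        dists[chosen] = -np.inf          # never re-pick frames we already have
        best = int(np.argmax(dists))     # the least-covered (most diverse) frame
        chosen.append(best)
        candidates.append(best)
    return candidates
```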
+## PUMaVOS Dataset
+
+We used XMem++ to collect and annotate a dataset of challenging and practical use cases inspired by the movie production industry.
+
+| Sequence | Challenge |
+|:--:|:--:|
+| Billie Shoes ("billie_shoes" video) | Shoes |
+| Short Chair ("chair" video) | Reflections |
+| Dog Tail ("dog_tail" video) | Body parts |
+| Workout Pants ("pants_workout" video) | Deformable objects |
+| SKZ ("skz" video) | Similar objects, occlusion |
+| Tattoo ("tattoo" video) | Tattoos/patterns |
+| Ice Cream ("ice_cream" video) | Quick motion |
+| Vlog ("vlog" video) | Multi-object parts |
+ + +**Partial and Unusual Masks for Video Object Segmentation (PUMaVOS)** dataset has the following properties: +- **23** videos, **19770** densely-annotated frames; +- Covers complex practical use cases such as object parts, frequent occlusions, fast motion, deformable objects and more; +- Average length of the video is **659 frames** or **28s**, with the longer ones spanning **1min**; +- Fully densely annotated at 30FPS; +- Benchmark-oriented: no separation into training/test, designed to be as diverse as possible to test your models; +- 100% open and free to download. + +### Download +Separate sequences and masks are available here: [[TODO Google Drive]](TODO) + +PUMaVOS `.zip` download link: [[TODO Google Drive]](TODO) + +### LICENSE + +PUMaVOS is released under [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/), - you can use it for any purpose (including commercial), you only need to credit the authors (us) whenever you do and indicate if you've made any modifications. See the full license text in [LICENSE_PUMaVOS](LICENSE_PUMaVOS) + +## Citation + +If you are using this code or PUMaVOS dataset in your work, please cite us: -@inproceedings{shapenet2015, - title = {{ShapeNet: An Information-Rich 3D Model Repository}}, - author = {Chang, Angel Xuan and Funkhouser, Thomas and Guibas, Leonidas and Hanrahan, Pat and Huang, Qixing and Li, Zimo and Savarese, Silvio and Savva, Manolis and Song, Shuran and Su, Hao and Xiao, Jianxiong and Yi, Li and Yu, Fisher}, - booktitle = {arXiv:1512.03012}, - year = {2015} +``` +@misc{bekuzarov2023xmem, + title={XMem++: Production-level Video Segmentation From Few Annotated Frames}, + author={Maksym Bekuzarov and Ariana Bermudez and Joon-Young Lee and Hao Li}, + year={2023}, + eprint={2307.15958}, + archivePrefix={arXiv}, + primaryClass={cs.CV} } ``` -
-
-Contact:
+Contact: , , ,
diff --git a/XMem_utilities b/XMem_utilities
new file mode 160000
index 0000000..46df79e
--- /dev/null
+++ b/XMem_utilities
@@ -0,0 +1 @@
+Subproject commit 46df79ecde808089cd9244ba43d8f28863c9fa71
diff --git a/docs/DATA_FORMAT_HELP.md b/docs/DATA_FORMAT_HELP.md
new file mode 100644
index 0000000..fb9c156
--- /dev/null
+++ b/docs/DATA_FORMAT_HELP.md
@@ -0,0 +1,119 @@
+# Data format for processing
+## GUI
+For the GUI demo the following constraints must be satisfied:
+- If using the `--video` argument, a video file can be anything that OpenCV can read.
+- If using the `--images` argument, they should be named `frame_000000.jpg`, `frame_000001.jpg`, etc. General format: `frame_%06d.jpg`
+- If using the `--workspace` argument, the frames are already saved inside the workspace, so you don't need to do anything.
+
+## Importing existing projects
+
+The name constraints for frames are the same as for the GUI.
+
+When importing masks, they should be called `frame_000000.png`, `frame_000001.png`, etc. General format: `frame_%06d.png`.
+
+## Command line or Python API
+
+The command line and Python API will work fine as long as:
+- The frame and mask files have the **same name** (e.g. `001.jpg`, `002.jpg` <-> `001.png`, `002.png`)
+- Default alphabetic sorting sorts in the correct increasing order (so `['1.jpg', '10.jpg', '2.jpg']` does not happen). The simplest way to achieve this is to make sure all the numbers in the filenames are **prepended with `0`** to the same length.
+- Numbers don't have to start with 0 as long as they are ordered correctly.
+
+✅ Valid format
+
+```
+image_099.jpg, image_100.jpg, image_101.jpg
+```
+
+❌ Invalid format
+
+```
+image_99.jpg, image_100.jpg, image_101.jpg
+```
+
+_When in doubt, just rename everything to `frame_000000.jpg`, `frame_000001.jpg`, i.e. `frame_%06d.jpg`; this will 100% work._
+
+## Convenience tips
+
+If you want to rename your existing frames/masks (preserves the original file extension), here's a convenient script in Python:
+```Python
+import re
+import shutil
+from pathlib import Path
+
+p_in = Path('/path/to/your/frames')
+p_out = Path('/path/where/to/save/renamed/frames')
+p_out.mkdir(exist_ok=True, parents=True)
+
+pattern = re.compile(r'\d+')
+
+for p_file in sorted(p for p in p_in.iterdir() if p.is_file()):
+    idx = int(re.search(pattern, p_file.stem).group())
+    new_name = f'frame_{idx:06d}' + p_file.suffix
+    shutil.copyfile(p_file, p_out / new_name)
+```
+
+Both the cmd/Python API and the GUI app can extract frames from a video file for you. However, you can still do it yourself if you need to, using the following `ffmpeg` command:
+```Bash
+# Optional resizing via `-vf` (for images; for masks use `flags=neighbor`); `-qscale:v` controls JPEG quality (lower is better, 2 is high quality)
+ffmpeg -i path/to/video.mp4 -vf 'scale=-1:480:flags=lanczos' -qscale:v 2 existing_output_dir/frame_%06d.jpg  # use .png for masks
+```
+
+To concatenate masks/overlays back into a video use:
+
+```Bash
+# `-r` sets the FPS; use workspace/<project_name>/masks/frame_%06d.png for masks; `-crf` is the compression quality (0-51, lower is better, 21 is good quality)
+ffmpeg -r 30 -i workspace/<project_name>/overlays/frame_%06d.jpg -crf 21 <project_name>_overlay.mp4
+```
+
+## Color scheme (palette)
+
+> Some image formats, such as GIF or PNG, can use a palette, which is a table of (usually) 256 colors to allow for better compression. Basically, instead of representing each pixel with its full color triplet, which takes 24bits (plus eventual 8 more for transparency), they use a 8 bit index that represent the position inside the palette, and thus the color.
+-- https://docs.geoserver.org/2.22.x/en/user/tutorials/palettedimage/palettedimage.html
+
+The XMem++ app uses colored masks to indicate different objects in the video. Unique color = unique object.
+
+However, storing a mask with 5-10 unique colors as an RGB image is expensive (you have `H x W x 3` elements), so instead we use the color palette from the DAVIS dataset, which maps object indices to colors like this:
+```
+DAVIS palette
+Object index -> RGB color
+#####################
+0  -> (0,   0,   0  )
+1  -> (128, 0,   0  )
+2  -> (0,   128, 0  )
+3  -> (128, 128, 0  )
+4  -> (0,   0,   128)
+5  -> (128, 0,   128)
+6  -> (0,   128, 128)
+7  -> (128, 128, 128)
+8  -> (64,  0,   0  )
+9  -> (192, 0,   0  )
+10 -> (64,  128, 0  )
+11 -> (192, 128, 0  )
+...
+#####################
+```
+This way, to save space and to tell the model where each object is in a mask, we store _object indices instead of RGB values_ in it, making it only `H x W` in size. 0 is always the background and maps to black.
+
+So those mask files that look like color images are single-channel `uint8` arrays under the hood. When `PIL` reads them, it (correctly) gives you a two-dimensional array (`opencv` does not work AFAIK). If what you get instead is a three-dimensional `H*W*3` array, then your mask is not actually a paletted mask, but just a colored image. Reading and saving a paletted mask through `opencv` or MS Paint would destroy the palette.
+
+Our code, when asked to generate multi-object segmentation (e.g., DAVIS 2017/YouTubeVOS), always reads and writes single-channel masks. If there is a palette in the input, we will use it in the output. The code does not care whether a palette is actually used -- we can read grayscale images just fine.
+
+Importantly, we use `np.unique` to determine the number of objects in the mask. This would fail if:
+
+1. Colored images are used instead of paletted masks.
+2. The masks have "smooth" edges, produced by feathering/downsizing/compression. **For example, when you draw the mask in a painting software, make sure you set the brush hardness to maximum.**
+
+Generally speaking, you don't need to worry about this unless you are changing the internals of XMem++.
+
+To avoid the issue shown below, make sure to **only use nearest-neighbour interpolation** for masks whenever resizing is necessary, and only use **lossless** image formats like `.png` for storing them.
+| ![Image 1](resources/good_mask_256.png) | ![Image 2](resources/bad_mask_256.png) |
+|:----------------------:|:----------------------:|
+| ✅ Valid mask - 1 unique color/object | ❌ Invalid mask - 66 unique colors/objects! |
+
+### Colors in the output
+
+`import_existing` and the GUI app will **automatically convert** your regular RGB masks to the correct palette format, so the colors you see in the app and in the output may differ from the originals.
+
+If you use the command-line/Python API, it **preserves** your original colors in place. E.g., if you are segmenting one object and the mask you provide is pink, it will be pink in the output.
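If you want to sanity-check a mask file before importing it, a quick way to do it (a sketch; it assumes Pillow and NumPy are installed, and the path below is just an example) is:

```Python
import numpy as np
from PIL import Image

# Example path -- point this at one of your own mask files.
mask = Image.open("workspace/my_project/masks/frame_000000.png")
print(mask.mode)   # 'P' means a proper paletted mask; 'RGB' means a plain colored image

arr = np.array(mask)
print(arr.shape)        # (H, W) for a paletted mask, (H, W, 3) for an RGB image
print(np.unique(arr))   # object indices present in the mask; 0 is the background
print(max(len(np.unique(arr)) - 1, 0), "object(s)")  # assumes the background (0) is present
```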
diff --git a/docs/DEMO.md b/docs/DEMO.md
index cd1d283..221b04d 100644
--- a/docs/DEMO.md
+++ b/docs/DEMO.md
@@ -1,58 +1,114 @@
-# Interactive GUI for Demo
+# Interactive GUI

-First, set up the required packages following [GETTING STARTED.md](./GETTING_STARTED.md). You can ignore the dataset part as you wouldn't be needing them for this demo. Download the pretrained models following [INFERENCE.md](./INFERENCE.md).
+First, set up the required packages following the [Installation steps in the README](../README.md#getting-started). Don't forget to download the pretrained models from there as well.

 You will need some additional packages and pretrained models for the GUI.

-```bash
-pip install -r requirements_demo.txt
+The interactive GUI is modified from [XMem](https://github.com/hkchengrex/Xmem). We keep the same modules for interactions ([f-BRS](https://github.com/saic-vul/fbrs_interactive_segmentation) and [S2M](https://github.com/hkchengrex/Scribble-to-Mask)). You will need their pretrained models. Use `./scripts/download_models_demo.sh` or download them manually into `./saves` from [GitHub](https://github.com/hkchengrex/XMem/releases/tag/v1.0).
+
+| ![Demo GUI](resources/gui_demo.jpg) |
+|:--:|
+| _XMem++ updated GUI with its References tab, which shows the current annotations used as references in the permanent memory_ |
+## Try it for yourself
+
+XMem++ includes 2 short sample videos in the [example_videos directory](../example_videos/); you can use them to test the demo.
+
+The entry point is `interactive_demo.py`. To run the GUI on a new video:
+```Bash
+python interactive_demo.py --video example_videos/chair/chair.mp4
```
-The interactive GUI is modified from [MiVOS](https://github.com/hkchengrex/MiVOS). Specifically, we keep the "interaction-to-mask" module and the propagation module is replaced with XMem. The fusion module is discarded because I don't want to train it.
-For interactions, we use [f-BRS](https://github.com/saic-vul/fbrs_interactive_segmentation) and [S2M](https://github.com/hkchengrex/Scribble-to-Mask). You will need their pretrained models. Use `./scripts/download_models_demo.sh` or download them manually into `./saves`.
+To run on a list of images:
+```Bash
+python interactive_demo.py --images example_videos/chair/JPEGImages
+```

-The entry point is `interactive_demo.py`. The command line arguments should be self-explanatory.
+Both of these commands will create a folder for the current video in the workspace folder (default is `.workspace`) and save all the masks and predictions there.

-![gui](https://imgur.com/uAImD80.jpg)
## Try it for yourself

-https://user-images.githubusercontent.com/7107196/177661140-f690156b-1775-4cd7-acd7-1738a5c92f30.mp4
+To keep editing an existing project in a workspace, run the following command:
+```Bash
+python interactive_demo.py --workspace ./workspace/<project_name>
+```

-Right-click download this video (source: https://www.youtube.com/watch?v=FTcjzaqL0pE). Then run
+If you have more than 1 object, make sure to add `--num-objects <num_objects>` to the commands above the **first time you create a project**. It will be saved in the project file after that for your convenience =)

-```bash
-python interactive_demo.py --video [path to the video] --num_objects 4
+Like this:
+```Bash
+python interactive_demo.py --images example_videos/caps/JPEGImages --num-objects 2
+```
-```
+### Run in Docker
+To run the GUI in Docker, simply run the following command:
+```Bash
+./run_gui_in_docker.sh --video example_videos/chair/chair.mp4 --num_objects 2
+```
+For the GUI you can change the variables `$LOCAL_WORKSPACE_DIR` and `$DISPLAY_TO_USE` in [run_gui_in_docker.sh](run_gui_in_docker.sh#L54) if necessary.
+
+_Be wary that the interactive import buttons will not work (they will open paths within the container filesystem, not the host one)._
## Features
+* Has an **object selector** with color indication to edit masks for different objects. Either click them or press '1', '2', etc. on the keyboard.
+* **References tab** - shows which frames you have already saved into the permanent memory as references and allows you to remove/edit them.
+* Also shows which frames were last recommended for annotation by the frame annotation candidate selection algorithm.
+* Frames previously chosen as references and saved in the permanent memory will be re-loaded if the run was interrupted/closed.
+* Hover-on tooltips are available for most of the controls in the app:
+
+![Tooltips](resources/tooltips.jpg)
+
+**And from the original XMem:**
* Low CPU memory cost. Unlike the implementation in MiVOS, we do not load all the images as the program starts up. We load them on-the-fly with an LRU buffer.
* Low GPU memory cost. This is provided by XMem. See the paper.
* Faster than MiVOS-STCN, especially for long videos. ^
* You can continue from interrupted runs. We save the resultant masks on-the-fly in the workspace directory from which annotation can be resumed. The memory bank is not saved and cannot be resumed.
+
+## Workflow
+| ![Demo GUI](resources/workflow.jpg) |
+|:--:|
+| _XMem++ recommended workflow_ |
+
+1. Load the demo.
+2. Provide at least one mask (if none exist yet) - draw it or import an existing one.
+3. Run FULL propagation once.
+4. [Optional] Click "Compute annotation candidates" to get a list of the next $k$ frames suggested for annotation.
+5. Annotate/fix masks for one or more frames and save them to the references.
+6. Repeat steps 3-5.
+
+## Image editing
+Fun fact - you can use XMem++ just for editing the masks! If there are some minor inaccuracies in a few predicted masks, you can simply edit them directly and close the app - the changes will be saved. _You don't have to run the propagation for this_, and it's easier than loading each mask into an image editor, fixing it, and saving it separately.
+
+## Controls
* Use the slider to change the current frame. "Play Video" automatically progresses the video.
* Select interaction type: "scribble", "click", or "free". Both scribble and "free" (free-hand drawing) modify an existing mask. Using "click" on an existing object mask (i.e., a mask from propagation or other interaction methods) will reset the mask. This is because f-BRS does not take an existing mask as input.
-* Select the target object using the number keys. "1" corresponds to the first object, etc. You need to specify the maximum number of objects when you start the program through the command line.
+* Select the target object using the number keys or by clicking on the corresponding color/object number on the left. "1" corresponds to the first object, etc. You need to specify the maximum number of objects when you first create a project through the command line.
-* Use propagate forward/backward to let XMem do the job. Pause when correction is needed. It will only automatically stops when it hits the end of the video.
+* Use propagate forward/backward to let XMem do the job. **Use FULL propagation after you change any masks in the permanent memory**. Pause when correction is needed. It only stops automatically when it hits the end/start of the video.
* Make sure all objects are correctly labeled before propagating. The program doesn't care which object you have interacted with -- it treats everything as user-provided inputs. Not labelling an object implicitly means that it is part of the background.
-* The memory bank might be "polluted" by bad memory frames. Feel free to hit clear memory to erase that. Propagation runs faster with a small memory bank.
-* All output masks are automatically saved in the workspace directory, which is printed when the program starts.
+* The memory bank might be "polluted" by bad memory frames. Feel free to hit clear memory to erase that.
+* All output masks and overlays are automatically saved in the workspace directory, which is printed when the program starts.
* You can load an external mask for the current frame using "Import mask".
-* For "layered insertion" (e.g., the breakdance demo), use the "layered" overlay mode. You can load a custom layer using "Import layer". The layer should be an RGBA png file. RGB image files are also accepted -- the alpha channel will be filled with ones.
+* For "layered insertion" (e.g., the breakdance demo), use the "layered" overlay mode. You can load a custom layer using "Import layer". The layer should be an RGBA png file. RGB image files are also accepted -- the alpha channel will be filled with `1`s.
* The "save overlay during propagation" checkbox does exactly that. It does not save the overlay when the user is just scrubbing the timeline.
* For "popup" and "layered", the visualizations during propagation (and the saved overlays) have higher quality than when the user is scrubbing the timeline. This is because we have access to the soft probability mask during propagation.
* Both "popup" and "layered" need a binary mask. By default, the first object mask is used. You can change the target (or make the target a union of objects) using the middle mouse key.
## FAQ
-1. Why cannot I label object 2 after pressing the number '2'?
- - Make sure you specified `--num_objects`. We ignore object IDs that exceed `num_objects`.
-2. The GUI feels slow!
- - The GUI needs to read/write images and masks on-the-go. Ideally this can be implemented with multiple threads with look-ahead but I didn't. The overheads will be smaller if you place the `workspace` on a SSD. You can also use a ram disk. `eval.py` will almost certainly be faster.
+1. What are references?
+ - They are just frames with correct masks (verified by the user), which the model treats as examples for segmenting other frames.
+1. How does the frame annotation candidate selection work?
+ - It takes the previous references and selects $k$ new ones that would maximize the diversity of appearance of the target object.
+1. Do I have to annotate and use all the candidates suggested to me?
+ - Not at all! These are simply suggestions; you can do with them whatever you please.
+1. Some of the frame candidates that I was given already have good masks! I don't need to edit them.
+ - That's totally fine: they were selected because the target object in them looks different from the frames already in use; we do not try to estimate how good or bad the current segmentation is [for a good reason](../README.md#frame-annotation-candidate-selector).
+1. I only see 1 object!
+ - Make sure you specified `--num_objects` when you first created the project.
+1. The GUI feels slow!
+ - The GUI needs to read/write images and masks on-the-go. Ideally this could be implemented with multiple threads with look-ahead, but it hasn't been yet. The overheads will be smaller if you place the `workspace` on an SSD. You can also use a RAM disk.
 - It takes more time to process more objects. This depends on `num_objects`, but not the actual number of objects that the user has annotated. *This does not mean that running time is directly proportional to the number of objects. There is significant shared computation.*
-3. Can I run this on a remote server?
- - X11 forwarding should be possible. I have not tried this and would love to know if it works for you. +1. Can I run this on a remote server? + - X11 forwarding should be possible, but untested. +1. Will this work on Windows? + - So far only tested on Linux, but should be possible to run it on other platforms as well. It's just Python and PyQt. diff --git a/docs/FAILURE_CASES.md b/docs/FAILURE_CASES.md deleted file mode 100644 index 108566b..0000000 --- a/docs/FAILURE_CASES.md +++ /dev/null @@ -1,25 +0,0 @@ -# Failure Cases - -Like all methods, XMem can fail. Here, we try to show some illustrative and frankly consistent failure modes that we noticed. We slowed down all videos for visualization. - -## Fast motion, similar objects - -The first one is fast motion with similarly-looking objects that do not provide sufficient appearance clues for XMem to track. Below is an example from the YouTubeVOS validation set (0e8a6b63bb): - -https://user-images.githubusercontent.com/7107196/179459162-80b65a6c-439d-4239-819f-68804d9412e9.mp4 - -And the source video: - -https://user-images.githubusercontent.com/7107196/181700094-356284bc-e8a4-4757-ab84-1e9009fddd4b.mp4 - -Technically it can be solved by using more positional and motion clues. XMem is not sufficiently proficient at those. - -## Shot changes; saliency shift - -Ever wondered why I did not include the final scene of Chika Dance when the roach flies off? Because it failed there. - -XMem seems to be attracted to any new salient object in the scene when the (true) target object is missing. By new I mean an object that did not appear (or had a different appearance) earlier in the video -- as XMem could not have a memory representation for that object. This happens a lot if the camera shot changes. - -https://user-images.githubusercontent.com/7107196/179459190-d736937a-6925-4472-b46e-dcf94e1cafc0.mp4 - -Note that the first shot change is not as problematic. diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md deleted file mode 100644 index aaa4295..0000000 --- a/docs/GETTING_STARTED.md +++ /dev/null @@ -1,64 +0,0 @@ -# Getting Started - -Our code is tested on Ubuntu. I have briefly tested the GUI on Windows (with a PyQt5 fix in the heading of interactive_demo.py). - -## Requirements - -* Python 3.8+ -* PyTorch 1.11+ (See [PyTorch](https://pytorch.org/) for installation instructions) -* `torchvision` corresponding to the PyTorch version -* OpenCV (try `pip install opencv-python`) -* Others: `pip install -r requirements.txt` - -## Dataset - -I recommend either softlinking (`ln -s`) existing data or use the provided `scripts/download_datasets.py` to structure the datasets as our format. - -`python -m scripts.download_dataset` - -The structure is the same as the one in STCN -- you can place XMem in the same folder as STCN and it will work. -The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case. -It does not download BL30K because it is huge and we don't want to crash your harddisks. - -```bash -├── XMem -├── BL30K -├── DAVIS -│ ├── 2016 -│ │ ├── Annotations -│ │ └── ... -│ └── 2017 -│ ├── test-dev -│ │ ├── Annotations -│ │ └── ... -│ └── trainval -│ ├── Annotations -│ └── ... -├── static -│ ├── BIG_small -│ └── ... -├── long_video_set -│ ├── long_video -│ ├── long_video_x3 -│ ├── long_video_davis -│ └── ... 
-├── YouTube -│ ├── all_frames -│ │ └── valid_all_frames -│ ├── train -│ ├── train_480p -│ └── valid -└── YouTube2018 - ├── all_frames - │ └── valid_all_frames - └── valid -``` - -## Long-Time Video - -It comes from [AFB-URR](https://github.com/xmlyqing00/AFB-URR). Please following their license when using this data. We release our extended version (X3) and corresponding `_davis` versions such that the DAVIS evaluation can be used directly. They can be downloaded [[here]](TODO). The script above would also attempt to download it. - -### BL30K - -You can either use the automatic script `download_bl30k.py` or download it manually from [MiVOS](https://github.com/hkchengrex/MiVOS/#bl30k). Note that each segment is about 115GB in size -- 700GB in total. You are going to need ~1TB of free disk space to run the script (including extraction buffer). -The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case. diff --git a/docs/INFERENCE.md b/docs/INFERENCE.md deleted file mode 100644 index 480173d..0000000 --- a/docs/INFERENCE.md +++ /dev/null @@ -1,108 +0,0 @@ -# Inference - -1. Set up the datasets following [GETTING_STARTED.md](./GETTING_STARTED.md). -2. Download the pretrained models either using `./scripts/download_models.sh`, or manually and put them in `./saves` (create the folder if it doesn't exist). You can download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing). - -All command-line inference are accessed with `eval.py`. See [RESULTS.md](./RESULTS.md) for an explanation of FPS and the differences between different models. - -## Usage - -``` -python eval.py --model [path to model file] --output [where to save the output] --dataset [which dataset to evaluate on] --split [val for validation or test for test-dev] -``` - -See the code for a complete list of available command-line arguments. - -Examples: -(``--model`` defaults to `./saves/XMem.pth`) - -DAVIS 2017 validation: - -``` -python eval.py --output ../output/d17 --dataset D17 -``` - -DAVIS 2016 validation: - -``` -python eval.py --output ../output/d16 --dataset D16 -``` - -DAVIS 2017 test-dev: - -``` -python eval.py --output ../output/d17-td --dataset D17 --split test -``` - -YouTubeVOS 2018 validation: - -``` -python eval.py --output ../output/y18 --dataset Y18 -``` - -Long-Time Video (3X) (note that `mem_every`, aka `r`, is set differently): - -``` -python eval.py --output ../output/lv3 --dataset LV3 --mem_every 10 -``` - -## Getting quantitative results - -We do not provide any tools for getting quantitative results here. 
We used the followings to get the results reported in the paper: - -- DAVIS 2017 validation: [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation) -- DAVIS 2016 validation: [davis2016-evaluation](https://github.com/hkchengrex/davis2016-evaluation) (Unofficial) -- DAVIS 2017 test-dev: [CodaLab](https://competitions.codalab.org/competitions/20516#participate) -- YouTubeVOS 2018 validation: [CodaLab](https://competitions.codalab.org/competitions/19544#results) -- YouTubeVOS 2019 validation: [CodaLab](https://competitions.codalab.org/competitions/20127#participate-submit_results) -- Long-Time Video: [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation) - -(For the Long-Time Video dataset, point `--davis_path` to either `long_video_davis` or `long_video_davis_x3`) - -## On custom data - -Structure your custom data like this: - -```bash -├── custom_data_root -│ ├── JPEGImages -│ │ ├── video1 -│ │ │ ├── 00001.jpg -│ │ │ ├── 00002.jpg -│ │ │ ├── ... -│ │ └── ... -│ ├── Annotations -│ │ ├── video1 -│ │ │ ├── 00001.png -│ │ │ ├── ... -│ │ └── ... -``` - -We use `sort` to determine frame order. The annotations do not have have to be complete (e.g., first-frame only is fine). We use PIL to read the annotations and `np.unique` to determine objects. PNG palette will be used automatically if exists. - -Then, point `--generic_path` to `custom_data_root` and specify `--dataset` as `G` (for generic). - -## Multi-scale evaluation - -Multi-scale evaluation is done in two steps. We first compute and save the object probabilities maps for different settings independently on hard-disks as `hkl` (hickle) files. Then, these maps are merged together with `merge_multi_score.py`. - -Example for DAVIS 2017 validation MS: - -Step 1 (can be done in parallel with multiple GPUs): - -```bash -python eval.py --output ../output/d17_ms/720p --mem_every 3 --dataset D17 --save_scores --size 720 -python eval.py --output ../output/d17_ms/720p_flip --mem_every 3 --dataset D17 --save_scores --size 720 --flip -``` - -Step 2: - -```bash -python merge_multi_scale.py --dataset D --list ../output/d17_ms/720p ../output/d17_ms/720p_flip --output ../output/d17_ms_merged -``` - -Instead of `--list`, you can also use `--pattern` to specify a glob pattern. It also depends on your shell (e.g., `zsh` or `bash`). - -## Advanced usage - -To develop your own evaluation interface, see `./inference/` -- most importantly, `inference_core.py`. diff --git a/docs/LIMITATIONS.md b/docs/LIMITATIONS.md new file mode 100644 index 0000000..4db76bb --- /dev/null +++ b/docs/LIMITATIONS.md @@ -0,0 +1,48 @@ +# Limitations + +Since XMem++ is built on top of [XMem](https://github.com/hkchengrex/XMem), it shares common failure cases with it. + +We have also identified a new limitation of memory-based methods which is as follows: + +## Negative masks are problematic + +If the model is prone to predict a lot of false-positive segmentations on a particular video, it might not be easy to get rid of them by just providing empty ground truth masks: + +
+![Negative masks failure case](resources/negative_masks_limitation.jpg)
+
+
+**However, this is easy to deal with**:
+
+- If there are **no objects** that must be segmented, the user can quickly remove the unwanted masks in the editor (just "Reset frame" -> "Next frame" -> repeat) or delete the unwanted mask files.
+- If there is at least a **small part of the target object** present, then providing a ground truth mask for it will likely remove the unwanted false-positive segmentations.
+
+We also reiterate the original limitations described by [Ho Kei (Rex) Cheng](https://hkchengrex.com/) in the original XMem repository's [FAILURE_CASES.md document](https://github.com/hkchengrex/XMem/blob/main/docs/FAILURE_CASES.md). The following examples are taken from there:
+
+## Original XMem limitations
+
+> Like all methods, XMem can fail. Here, we try to show some illustrative and frankly consistent failure modes that we noticed. We slowed down all videos for visualization.
+### Fast motion, similar objects
+
+> The first one is fast motion with similarly-looking objects that do not provide sufficient appearance clues for XMem to track. Below is an example from the YouTubeVOS validation set (0e8a6b63bb):
+
+https://user-images.githubusercontent.com/7107196/179459162-80b65a6c-439d-4239-819f-68804d9412e9.mp4
+
+> And the source video:
+
+https://user-images.githubusercontent.com/7107196/181700094-356284bc-e8a4-4757-ab84-1e9009fddd4b.mp4
+
+> Technically it can be solved by using more positional and motion clues. XMem is not sufficiently proficient at those.
+
+### Shot changes; saliency shift
+
+> Ever wondered why I did not include the final scene of Chika Dance when the roach flies off? Because it failed there.
+>
+> XMem seems to be attracted to any new salient object in the scene when the (true) target object is missing. By new I mean an object that did not appear (or had a different appearance) earlier in the video -- as XMem could not have a memory representation for that object. This happens a lot if the camera shot changes.
+
+https://user-images.githubusercontent.com/7107196/179459190-d736937a-6925-4472-b46e-dcf94e1cafc0.mp4
+
+> Note that the first shot change is not as problematic.
+
+_However, this problem can now be partially mitigated by using **more annotations** to correct the model's false-positive predictions._
+
diff --git a/docs/PYTHON_API.md b/docs/PYTHON_API.md
new file mode 100644
index 0000000..f976434
--- /dev/null
+++ b/docs/PYTHON_API.md
@@ -0,0 +1,57 @@
+# XMem++ Python API
+
+XMem++ exposes 2 main functions you can use:
+- `run_on_video` - run full inference on a video/images folder with the given annotations.
+- `select_k_next_best_annotation_candidates` - given a video/images folder, at least one ground truth annotation (to know which object we are even segmenting) and, optionally, existing predictions, select the $k$ next best frames to annotate.
+
+See also [main.py](../main.py).
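+
+Taken together, these two functions form a simple annotate-propagate loop: run inference with the annotations you already have, ask for the next frames to annotate, add their masks, and repeat. The sketch below is an illustration only (the example paths and the number of refinement rounds are placeholders, not part of the API); the individual calls are described in the sections that follow.
+```Python
+from inference.run_on_video import run_on_video, select_k_next_best_annotation_candidates
+
+imgs_path = 'example_videos/chair/JPEGImages'
+masks_path = 'example_videos/chair/Annotations'  # annotated frames are stored here
+output_path = 'output/example_video_chair'
+
+frames_with_masks = [0]  # start from a single annotated frame
+for _ in range(3):  # a few refinement rounds (placeholder value)
+    # 1. Propagate the current set of annotations through the whole video
+    run_on_video(imgs_path, masks_path, output_path, frames_with_masks)
+    # 2. Ask for the most informative frames to annotate next
+    next_candidates = select_k_next_best_annotation_candidates(
+        imgs_path, masks_path, output_path,
+        previously_chosen_candidates=frames_with_masks,
+        use_previously_predicted_masks=True)
+    # 3. Annotate the suggested frames (e.g. in the GUI), save their masks into
+    #    `masks_path`, then include them in the next round
+    frames_with_masks += next_candidates
+```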
+
+## Inference with preselected ground truth annotations
+### Using a list of video frames
+To run segmentation on a list of video frames (`.jpg`) with preselected annotations:
+```Python
+from inference.run_on_video import run_on_video
+imgs_path = 'example_videos/caps/JPEGImages'
+masks_path = 'example_videos/caps/Annotations' # Should contain annotation masks for frames in `frames_with_masks`
+output_path = 'output/example_video_caps'
+frames_with_masks = [0, 14, 33, 43, 66] # indices of frames for which there is an annotation mask
+run_on_video(imgs_path, masks_path, output_path, frames_with_masks)
+```
+### Using a video file
+To run segmentation on a video file (like `.mp4`) with preselected annotations:
+```Python
+from inference.run_on_video import run_on_video
+video_path = 'example_videos/chair/chair.mp4'
+masks_path = 'example_videos/chair/Annotations' # Should contain annotation masks for frames in `frames_with_masks`
+output_path = 'output/example_video_chair_from_mp4'
+frames_with_masks = [5, 10, 15] # indices of frames for which there is an annotation mask
+run_on_video(video_path, masks_path, output_path, frames_with_masks)
+```
+## Getting the next best frames to annotate
+If you then want to get proposals for which frames to annotate next, add the following lines:
+```Python
+from inference.run_on_video import select_k_next_best_annotation_candidates
+# Get proposals for the next 3 best annotation candidates using previously predicted masks
+next_candidates = select_k_next_best_annotation_candidates(imgs_path, masks_path, output_path, previously_chosen_candidates=frames_with_masks, use_previously_predicted_masks=True)
+print("Next candidates for annotations are: ")
+for idx in next_candidates:
+    print(f"\tFrame {idx}")
+```
+If you don't have previous predictions, just set `use_previously_predicted_masks=False`; the algorithm will run a new inference pass internally.
+
+## Evaluating on a video with all ground truth masks available
+If you have a fully-labeled video and want to run **XMem++** and compute the IoU, run the following code:
+```Python
+# Run inference on a video with all annotations provided, compute IoU
+import os
+import random
+from inference.run_on_video import run_on_video
+imgs_path = 'example_videos/chair/JPEGImages'
+masks_path = 'example_videos/chair/Annotations'
+output_path = 'output/example_video_chair'
+num_frames = len(os.listdir(imgs_path))
+frames_with_masks = random.sample(range(0, num_frames), 3) # Give 3 random masks as GT annotations
+stats = run_on_video(imgs_path, masks_path, output_path, frames_with_masks, compute_iou=True) # stats: pandas DataFrame
+mean_iou = stats[stats['iou'] != -1]['iou'].mean() # -1 is for GT annotations, we just skip them
+print(f"Average IoU: {mean_iou}") # Should be 90%+ as a sanity check
+```
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..6fc637f
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,10 @@
+# Extra documentation
+This directory houses a collection of additional docs to help users.
+
+## Table of contents
+* [DEMO.md](DEMO.md) - how to start and use the GUI app, and the list of its features and use cases.
+* [PYTHON_API.md](PYTHON_API.md) - how to use XMem++ from Python code.
+* [DATA_FORMAT_HELP.md](DATA_FORMAT_HELP.md) - what data format XMem++ expects; some helper commands for video.
+* [LIMITATIONS.md](LIMITATIONS.md) - limitations and failure cases.
+* [Main README.md](../README.md) - Main documentation.
+* [README.md](README.md) - This file.
\ No newline at end of file diff --git a/docs/RESULTS.md b/docs/RESULTS.md deleted file mode 100644 index abef0c5..0000000 --- a/docs/RESULTS.md +++ /dev/null @@ -1,104 +0,0 @@ -# Results - -## Preamble - -Our code, by default, uses automatic mixed precision (AMP). Its effect on the output is negligible. -All speeds reported in the paper are recorded with AMP turned off (`--benchmark`). -Due to refactoring, there might be slight differences between the outputs produced by this code base with the precomputed results/results reported in the paper. This difference rarely leads to a change of the least significant figure (i.e., 0.1). - -**For most complete results, please see the paper (and the appendix)!** - -All available precomputed results can be found [[here]](https://drive.google.com/drive/folders/1UxHPXJbQLHjF5zYVn3XZCXfi_NYL81Bf?usp=sharing). - -## Pretrained models - -We provide four pretrained models for download: - -1. XMem.pth (Default) -2. XMem-s012.pth (Trained with BL30K) -3. XMem-s2.pth (No pretraining on static images) -4. XMem-no-sensory (No sensory memory) - -The model without pretraining is for reference. The model without sensory memory might be more suitable for tasks without spatial continuity, like mask tracking in a multi-camera 3D reconstruction setting, though I would encourage you to try the base model as well. - -Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing). - -## Long-Time Video - -[[Precomputed Results]](https://drive.google.com/drive/folders/1NADcetigH6d83mUvyb2rH4VVjwFA76Lh?usp=sharing) - -### Long-Time Video (1X) - -| Model | J&F | J | F | -| --- | :--:|:--:|:---:| -| XMem | 89.8±0.2 | 88.0±0.2 | 91.6±0.2 | - -### Long-Time Video (3X) - -| Model | J&F | J | F | -| --- | :--:|:--:|:---:| -| XMem | 90.0±0.4 | 88.2±0.3 | 91.8±0.4 | - -## DAVIS - -[[Precomputed Results]](https://drive.google.com/drive/folders/1XTOGevTedRSjHnFVsZyTdxJG-iHjO0Re?usp=sharing) - -### DAVIS 2016 - -| Model | J&F | J | F | FPS | FPS (AMP) | -| --- | :--:|:--:|:---:|:---:|:---:| -| XMem | 91.5 | 90.4 | 92.7 | 29.6 | 40.3 | -| XMem-s012 | 92.0 | 90.7 | 93.2 | 29.6 | 40.3 | -| XMem-s2 | 90.8 | 89.6 | 91.9 | 29.6 | 40.3 | - -### DAVIS 2017 validation - -| Model | J&F | J | F | FPS | FPS (AMP) | -| --- | :--:|:--:|:---:|:---:|:---:| -| XMem | 86.2 | 82.9 | 89.5 | 22.6 | 33.9 | -| XMem-s012 | 87.7 | 84.0 | 91.4 | 22.6 | 33.9 | -| XMem-s2 | 84.5 | 81.4 | 87.6 | 22.6 | 33.9 | -| XMem-no-sensory | 85.1 | - | - | 23.1 | - | - -### DAVIS 2017 test-dev - -| Model | J&F | J | F | -| --- | :--:|:--:|:---:| -| XMem | 81.0 | 77.4 | 84.5 | -| XMem-s012 | 81.2 | 77.6 | 84.7 | -| XMem-s2 | 79.8 | 61.4 | 68.1 | -| XMem-s012 (600p) | 82.5 | 79.1 | 85.8 | - -## YouTubeVOS - -We use all available frames in YouTubeVOS by default. -See [INFERENCE.md](./INFERENCE.md) if you want to evaluate with sparse frames for some reason. 
- -[[Precomputed Results]](https://drive.google.com/drive/folders/1P_BmOdcG6OP5mWGqWzCZrhQJ7AaLME4E?usp=sharing) - -[[Precomputed Results (sparse)]](https://drive.google.com/drive/folders/1IRV1fHepufUXM45EEbtl9D4pkoh9POSZ?usp=sharing) - -### YouTubeVOS 2018 validation - -| Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen | FPS | FPS (AMP) | -| --- | :--:|:--:|:---:|:---:|:---:|:---:|:---:| -| XMem | 85.7 | 84.6 | 89.3 | 80.2 | 88.7 | 22.6 | 31.7 | -| XMem-s012 | 86.1 | 85.1 | 89.8 | 80.3 | 89.2 | 22.6 | 31.7 | -| XMem-s2 | 84.3 | 83.9 | 88.8 | 77.7 | 86.7 | 22.6 | 31.7 | -| XMem-no-sensory | 84.4 | - | - | - | - | 23.1 | - | - -### YouTubeVOS 2019 validation - -| Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen | -| --- | :--:|:--:|:---:|:---:|:---:| -| XMem | 85.5 | 84.3 | 88.6 | 80.3 | 88.6 | -| XMem-s012 | 85.8 | 84.8 | 89.2 | 80.3 | 88.8 | -| XMem-s2 | 84.2 | 83.8 | 88.3 | 78.1 | 86.7 | - -## Multi-scale evaluation - -Please see the appendix for quantitative results. - -[[DAVIS-MS Precomputed Results]](https://drive.google.com/drive/folders/1H3VHKDO09izp6KR3sE-LzWbjyM-jpftn?usp=sharing) - -[[YouTubeVOS-MS Precomputed Results]](https://drive.google.com/drive/folders/1ww5HVRbMKXraLd2dy1rtk6kLjEawW9Kn?usp=sharing) diff --git a/docs/TRAINING.md b/docs/TRAINING.md deleted file mode 100644 index 3034946..0000000 --- a/docs/TRAINING.md +++ /dev/null @@ -1,49 +0,0 @@ -# Training - -First, set up the datasets following [GETTING STARTED.md](./GETTING_STARTED.md). - -The model is trained progressively with different stages (0: static images; 1: BL30K; 2: longer main training; 3: shorter main training). After each stage finishes, we start the next stage by loading the latest trained weight. -For example, the base model is pretrained with static images followed by the shorter main training (s03). - -To train the base model on two GPUs, you can use: - -```bash -python -m torch.distributed.run --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain --stage 03 -``` - -`master_port` needs to point to an unused port. -`nproc_per_node` refers to the number of GPUs to be used (specify `CUDA_VISIBLE_DEVICES` to select which GPUs to use). -`exp_id` is an identifier you give to this training job. - -See other available command line arguments in `util/configuration.py`. -**Unlike the training code of STCN, batch sizes are effective. You don't have to adjust the batch size when you use more/fewer GPUs.** - -We implemented automatic staging in this code base. You don't have to train different stages by yourself like in STCN (but that is still supported). -`stage` is a string that we split to determine the training stages. Examples include `0` (static images only), `03` (base training), `012` (with BL30K), `2` (main training only). - -You can use `tensorboard` to visualize the training process. - -## Outputs - -The model files and checkpoints will be saved in `./saves/[name containing datetime and exp_id]`. - -`.pth` files with `_checkpoint` store the network weights, optimizer states, etc. and can be used to resume training (with `--load_checkpoint`). - -Other `.pth` files store the network weights only and can be used for inference. We note that there are variations in performance across different training runs and across the last few saved models. For the base model, we most often note that main training at 107K iterations leads to the best result (full training is 110K). 
- -We measure the median and std scores across five training runs of the base model: - -| Dataset | median | std | -| --- | :--:|:--:| -| DAVIS J&F | 86.2 | 0.23 | -| YouTubeVOS 2018 G | 85.6 | 0.21 - -## Pretrained models - -You can start training from scratch, or use any of our pretrained models for fine-tuning. For example, you can load our stage 0 model to skip main training: - -```bash -python -m torch.distributed.launch --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain_stage3_only --stage 3 --load_network saves/XMem-s0.pth -``` - -Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing). diff --git a/docs/icon.png b/docs/icon.png index 45309b2..0220850 100644 Binary files a/docs/icon.png and b/docs/icon.png differ diff --git a/docs/resources/architecture_explanations.jpg b/docs/resources/architecture_explanations.jpg new file mode 100644 index 0000000..a9ccd7f Binary files /dev/null and b/docs/resources/architecture_explanations.jpg differ diff --git a/docs/resources/bad_mask_256.png b/docs/resources/bad_mask_256.png new file mode 100644 index 0000000..3115104 Binary files /dev/null and b/docs/resources/bad_mask_256.png differ diff --git a/docs/resources/billie_shoes_square.gif b/docs/resources/billie_shoes_square.gif new file mode 100644 index 0000000..10bc5f9 Binary files /dev/null and b/docs/resources/billie_shoes_square.gif differ diff --git a/docs/resources/chair_short_square.gif b/docs/resources/chair_short_square.gif new file mode 100644 index 0000000..860ef2b Binary files /dev/null and b/docs/resources/chair_short_square.gif differ diff --git a/docs/resources/dog_tail_square.gif b/docs/resources/dog_tail_square.gif new file mode 100644 index 0000000..764d05f Binary files /dev/null and b/docs/resources/dog_tail_square.gif differ diff --git a/docs/resources/frame_selector_showcase.jpg b/docs/resources/frame_selector_showcase.jpg new file mode 100644 index 0000000..662f428 Binary files /dev/null and b/docs/resources/frame_selector_showcase.jpg differ diff --git a/docs/resources/good_mask_256.png b/docs/resources/good_mask_256.png new file mode 100644 index 0000000..0b2ca7c Binary files /dev/null and b/docs/resources/good_mask_256.png differ diff --git a/docs/resources/gui_demo.jpg b/docs/resources/gui_demo.jpg new file mode 100644 index 0000000..7bdfd40 Binary files /dev/null and b/docs/resources/gui_demo.jpg differ diff --git a/docs/resources/ice_cream_square.gif b/docs/resources/ice_cream_square.gif new file mode 100644 index 0000000..34a481f Binary files /dev/null and b/docs/resources/ice_cream_square.gif differ diff --git a/docs/resources/negative_masks_limitation.jpg b/docs/resources/negative_masks_limitation.jpg new file mode 100644 index 0000000..407cab2 Binary files /dev/null and b/docs/resources/negative_masks_limitation.jpg differ diff --git a/docs/resources/pants_workout_square.gif b/docs/resources/pants_workout_square.gif new file mode 100644 index 0000000..b6348fd Binary files /dev/null and b/docs/resources/pants_workout_square.gif differ diff --git a/docs/resources/skz_square.gif b/docs/resources/skz_square.gif new file mode 100644 index 0000000..4dc70fc Binary files /dev/null and b/docs/resources/skz_square.gif differ diff --git a/docs/resources/tattoo_square.gif b/docs/resources/tattoo_square.gif new file mode 100644 index 0000000..7433aad Binary files /dev/null and b/docs/resources/tattoo_square.gif differ diff --git 
a/docs/resources/visual_quality_vs_assessed_quality.jpg b/docs/resources/visual_quality_vs_assessed_quality.jpg new file mode 100644 index 0000000..ee33640 Binary files /dev/null and b/docs/resources/visual_quality_vs_assessed_quality.jpg differ diff --git a/docs/resources/vlog_square.gif b/docs/resources/vlog_square.gif new file mode 100644 index 0000000..8e079aa Binary files /dev/null and b/docs/resources/vlog_square.gif differ diff --git a/docs/resources/why_permanent_memory_helps.jpg b/docs/resources/why_permanent_memory_helps.jpg new file mode 100644 index 0000000..f486647 Binary files /dev/null and b/docs/resources/why_permanent_memory_helps.jpg differ diff --git a/docs/resources/workflow.jpg b/docs/resources/workflow.jpg new file mode 100644 index 0000000..c2b245e Binary files /dev/null and b/docs/resources/workflow.jpg differ diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..bd478d8 --- /dev/null +++ b/environment.yml @@ -0,0 +1,201 @@ +name: XMemA +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - asttokens=2.0.5=pyhd3eb1b0_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - brotlipy=0.7.0=py39h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2022.12.7=py39h06a4308_0 + - cffi=1.15.1=py39h74dc2b5_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cryptography=37.0.1=py39h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - daal4py=2021.6.0=py39h79cecc1_1 + - dal=2021.6.0=hdb19cb5_916 + - decorator=5.1.1=pyhd3eb1b0_0 + - executing=0.8.3=pyhd3eb1b0_0 + - faiss-gpu=1.7.3=py3.9_h28a55e0_0_cuda11.3 + - ffmpeg=4.3=hf484d3e_0 + - fftw=3.3.10=nompi_h77c792f_102 + - flit-core=3.6.0=pyhd3eb1b0_0 + - freetype=2.11.0=h70c0345_0 + - giflib=5.2.1=h7b6447c_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - hdbscan=0.8.28=py39hce5d2b2_1 + - idna=3.4=py39h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipython=8.8.0=py39h06a4308_0 + - jedi=0.18.1=py39h06a4308_1 + - joblib=1.1.1=py39h06a4308_0 + - jpeg=9e=h7f8727e_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.8=h7f8727e_5 + - libfaiss=1.7.3=hfc2d529_0_cuda11.3 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=12.2.0=h69a702a_19 + - libgfortran5=12.2.0=h337968e_19 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libllvm11=11.1.0=hf817b99_2 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.4.0=hecacb30_0 + - libunistring=0.9.10=h27cfd23_0 + - libwebp=1.2.4=h11a3e52_0 + - libwebp-base=1.2.4=h5eee18b_0 + - llvmlite=0.39.1=py39he621ea3_0 + - lz4-c=1.9.3=h295c915_1 + - matplotlib-inline=0.1.6=py39h06a4308_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py39h7f8727e_0 + - mkl_fft=1.3.1=py39hd3c417c_0 + - mkl_random=1.2.2=py39h51133e4_0 + - mpi=1.0=mpich + - mpich=3.3.2=external_0 + - ncurses=6.3=h5eee18b_3 + - nettle=3.7.3=hbbd107a_1 + - numba=0.56.4=py39h417a72b_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1t=h7f8727e_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pip=22.2.2=py39h06a4308_0 + - prompt-toolkit=3.0.20=pyhd3eb1b0_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pure_eval=0.2.2=pyhd3eb1b0_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.11.2=pyhd3eb1b0_0 + - pynndescent=0.5.8=pyh1a96a4e_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - 
pysocks=1.7.1=py39h06a4308_0 + - python=3.9.13=haa1d7c7_2 + - python_abi=3.9=2_cp39 + - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 + - pytorch-mutex=1.0=cuda + - readline=8.1.2=h7f8727e_1 + - scikit-learn-intelex=2021.6.0=py39h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.39.3=h5082296_0 + - stack_data=0.2.0=pyhd3eb1b0_0 + - tbb=2021.6.0=hdb19cb5_0 + - threadpoolctl=3.1.0=pyh8a188c0_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.12.1=py39_cu113 + - tqdm=4.64.1=pyhd8ed1ab_0 + - traitlets=5.7.1=py39h06a4308_0 + - typing_extensions=4.3.0=py39h06a4308_0 + - tzdata=2022e=h04d1e81_0 + - umap-learn=0.5.3=py39hf3d152e_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - xz=5.2.6=h5eee18b_0 + - zlib=1.2.12=h5eee18b_3 + - zstd=1.5.2=ha4553b6_0 + - pip: + - absl-py==1.4.0 + - autopep8==2.0.0 + - baal==1.7.0 + - beautifulsoup4==4.11.1 + - blessed==1.19.1 + - cachetools==5.3.0 + - charset-normalizer==3.0.1 + - contourpy==1.0.6 + - cupy-cuda11x==11.5.0 + - cycler==0.11.0 + - cython==0.29.32 + - dill==0.3.6 + - fastrlock==0.8.1 + - filelock==3.8.0 + - fonttools==4.38.0 + - gdown==4.5.3 + - gitdb==4.0.9 + - gitpython==3.1.29 + - google-auth==2.16.0 + - google-auth-oauthlib==0.4.6 + - gprof2dot==2022.7.29 + - gpustat==1.0.0 + - grpcio==1.51.1 + - h5py==3.7.0 + - haishoku==1.1.8 + - hickle==5.0.2 + - imageio==2.26.0 + - importlib-metadata==6.0.0 + - kiwisolver==1.4.4 + - lazy-loader==0.1 + - markdown==3.4.1 + - markupsafe==2.1.2 + - matplotlib==3.6.2 + - multiprocess==0.70.14 + - networkx==3.0 + - numpy==1.23.5 + - nvidia-ml-py==11.495.46 + - oauthlib==3.2.2 + - opencv-python==4.6.0.66 + - p-tqdm==1.4.0 + - packaging==22.0 + - pandas==1.5.2 + - pathos==0.3.0 + - pillow==9.2.0 + - pox==0.3.2 + - ppft==1.7.6.6 + - profilehooks==1.12.0 + - progressbar2==4.1.1 + - protobuf==3.20.3 + - psutil==5.9.4 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pycodestyle==2.9.1 + - pyparsing==3.0.9 + - pyqt5==5.15.7 + - pyqt5-qt5==5.15.2 + - pyqt5-sip==12.11.0 + - python-dateutil==2.8.2 + # - python-graphviz==0.20.1 + - python-utils==3.3.3 + - pytz==2022.7 + - pywavelets==1.4.1 + - requests==2.28.2 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scikit-image==0.20.0 + - scikit-learn==1.2.0 + - scipy==1.9.1 + - seaborn==0.12.2 + - setuptools==66.1.1 + - smmap==5.0.0 + - snakeviz==2.1.1 + - soupsieve==2.3.2.post1 + - structlog==21.5.0 + - tensorboard==2.11.2 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - termcolor==2.2.0 + - thin-plate-spline==1.0.1 + # - thinplate==1.0.0 + - tifffile==2023.2.28 + - tomli==2.0.1 + - torch-tb-profiler==0.4.1 + - torchmetrics==0.9.3 + - torchvision==0.13.1 + - torchviz==0.0.2 + - tornado==6.2 + - trash-cli==0.23.2.13.2 + - typing==3.7.4.3 + - urllib3==1.26.14 + - werkzeug==2.2.2 + - wheel==0.38.4 + - zipp==3.12.0 \ No newline at end of file diff --git a/import_existing.py b/import_existing.py new file mode 100644 index 0000000..de51404 --- /dev/null +++ b/import_existing.py @@ -0,0 +1,87 @@ +import json +from pathlib import Path +import argparse + +import numpy as np +from PIL import Image +import progressbar +from tqdm import tqdm + +from util.image_loader import PaletteConverter + + +def resize_preserve(img, size, interpolation): + h, w = img.height, img.width + # Resize preserving aspect ratio + new_w = (w*size//min(w, h)) + new_h = (h*size//min(w, h)) + + resized_img = img.resize((new_w, new_h), resample=interpolation) + + return resized_img + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--name', type=str, help='The name of the project to 
use (name of the corresponding folder in the workspace). Will be created if doesn\'t exist ', required=True) + parser.add_argument('--size', type=str, help='The name of the project to use (name of the corresponding folder in the workspace). Will be created if doesn\'t exist ', default=480) + parser.add_argument('--images', type=str, help='Path to the folder with video frames', required=False) + parser.add_argument('--masks', type=str, help='Path to the folder with existing masks', required=False) + + args = parser.parse_args() + p_project = Path('workspace') / str(args.name) + if p_project.exists(): + print(f"Found the project {args.name} in the workspace.") + else: + print(f"Creating new project {args.name} in the workspace.") + + if args.images is not None: + p_imgs = Path(args.images) + p_imgs_out = p_project / 'images' + p_imgs_out.mkdir(parents=True, exist_ok=True) + + if any(p_imgs_out.iterdir()): + print(f"The project {args.name} already has images in the workspace. Delete them first.") + exit(0) + + img_files = sorted(p_imgs.iterdir()) + + for i in progressbar.progressbar(range(len(img_files)), prefix="Copying/resizing images..."): + p_img = img_files[i] + img = Image.open(p_img) + resized_img = resize_preserve(img, args.size, Image.Resampling.BILINEAR) + resized_img.save(p_imgs_out / f'frame_{i:06d}{p_img.suffix}') # keep the same image format + + if args.masks is not None: + p_masks = Path(args.masks) + p_masks_out = p_project / 'masks' + p_masks_out.mkdir(parents=True, exist_ok=True) + + if any(p_masks_out.iterdir()): + print(f"The project {args.name} already has masks in the workspace. Delete them first.") + exit(0) + + from util.palette import davis_palette + palette_converter = PaletteConverter(davis_palette) + + mask_files = sorted(p_masks.iterdir()) + + for i in progressbar.progressbar(range(len(mask_files)), prefix="Copying/resizing masks; converting to DAVIS color palette..."): + p_mask = mask_files[i] + mask = Image.open(p_mask) + resized_mask = resize_preserve(mask, args.size, Image.Resampling.NEAREST).convert('P') + + index_mask = palette_converter.image_to_index_mask(resized_mask) + + index_mask.save(p_masks_out / f'frame_{i:06d}{p_mask.suffix}') # keep the same image form + + try: + with open(p_project / 'info.json') as f: + data = json.load(f) + except Exception: + data = {} + + data['num_objects'] = palette_converter._num_objects + + with open(p_project / 'info.json', 'wt') as f_out: + json.dump(data, f_out, indent=4) + diff --git a/inference/data/mask_mapper.py b/inference/data/mask_mapper.py index 8e5b38d..378c090 100644 --- a/inference/data/mask_mapper.py +++ b/inference/data/mask_mapper.py @@ -30,7 +30,7 @@ def convert_mask(self, mask, exhaustive=False): new_labels = list(set(labels) - set(self.labels)) if not exhaustive: - assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' + assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' #a: it runs if you put exhaustive = True # add new remappings for i, l in enumerate(new_labels): diff --git a/inference/data/test_datasets.py b/inference/data/test_datasets.py index 3a4446e..1f2a1d5 100644 --- a/inference/data/test_datasets.py +++ b/inference/data/test_datasets.py @@ -89,7 +89,7 @@ def get_datasets(self): path.join(self.mask_dir, video), size=self.size, to_save=self.req_frame_list[video], - use_all_mask=True + use_all_masks=True ) def __len__(self): diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py index 28cc4c6..cad0d52 100644 --- 
a/inference/data/video_reader.py +++ b/inference/data/video_reader.py @@ -1,21 +1,39 @@ +from dataclasses import dataclass, replace import os from os import path +from tempfile import TemporaryDirectory +from typing import Optional +import cv2 +import progressbar +import torch from torch.utils.data.dataset import Dataset from torchvision import transforms from torchvision.transforms import InterpolationMode import torch.nn.functional as F +import torchvision.transforms.functional as FT from PIL import Image import numpy as np from dataset.range_transform import im_normalization +@dataclass +class Sample: + rgb: torch.Tensor + raw_image_pil: Image.Image + frame: str + save: bool + shape: tuple + need_resize: bool + mask: Optional[torch.Tensor] = None + + class VideoReader(Dataset): """ This class is used to read a video, one frame at a time """ - def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None): + def __init__(self, vid_name, video_path, mask_dir, size=-1, to_save=None, use_all_masks=False, size_dir=None): """ image_dir - points to a directory of jpg images mask_dir - points to a directory of png masks @@ -26,17 +44,12 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all Default false. Set to true for YouTubeVOS validation. """ self.vid_name = vid_name - self.image_dir = image_dir + self.video_path = video_path self.mask_dir = mask_dir self.to_save = to_save - self.use_all_mask = use_all_mask - if size_dir is None: - self.size_dir = self.image_dir - else: - self.size_dir = size_dir + self.use_all_masks = use_all_masks - self.frames = sorted(os.listdir(self.image_dir)) - self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette() + self.reference_mask = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).convert('P') self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0]) if size < 0: @@ -52,39 +65,86 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all ]) self.size = size + if os.path.isfile(self.video_path): + self.tmp_dir = TemporaryDirectory() + self.image_dir = self.tmp_dir.name + self._extract_frames() + else: + self.image_dir = video_path - def __getitem__(self, idx): - frame = self.frames[idx] - info = {} - data = {} - info['frame'] = frame - info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) + if size_dir is None: + self.size_dir = self.image_dir + else: + self.size_dir = size_dir + + self.frames = sorted(os.listdir(self.image_dir)) - im_path = path.join(self.image_dir, frame) + def __getitem__(self, idx) -> Sample: + data = {} + frame_name = self.frames[idx] + im_path = path.join(self.image_dir, frame_name) img = Image.open(im_path).convert('RGB') if self.image_dir == self.size_dir: shape = np.array(img).shape[:2] else: - size_path = path.join(self.size_dir, frame) + size_path = path.join(self.size_dir, frame_name) size_im = Image.open(size_path).convert('RGB') shape = np.array(size_im).shape[:2] - gt_path = path.join(self.mask_dir, frame[:-4]+'.png') + gt_path = path.join(self.mask_dir, frame_name[:-4]+'.png') + if not os.path.exists(gt_path): + gt_path = path.join(self.mask_dir, frame_name[:-4]+'.PNG') + + data['raw_image_pil'] = img img = self.im_transform(img) - load_mask = self.use_all_mask or (gt_path == self.first_gt_path) + load_mask = self.use_all_masks or (gt_path == self.first_gt_path) if load_mask and path.exists(gt_path): mask = 
Image.open(gt_path).convert('P') mask = np.array(mask, dtype=np.uint8) data['mask'] = mask + info = {} + info['save'] = (self.to_save is None) or (frame_name[:-4] in self.to_save) + info['frame'] = frame_name info['shape'] = shape info['need_resize'] = not (self.size < 0) + data['rgb'] = img - data['info'] = info + + data = Sample(**data, **info) return data + + def __len__(self): + return len(self.frames) + + def __del__(self): + if hasattr(self, 'tmp_dir'): + self.tmp_dir.cleanup() + + def _extract_frames(self): + cap = cv2.VideoCapture(self.video_path) + frame_index = 0 + print(f'Extracting frames from {self.video_path} into a temporary dir...') + bar = progressbar.ProgressBar(max_value=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) + while(cap.isOpened()): + _, frame = cap.read() + if frame is None: + break + if self.size > 0: + h, w = frame.shape[:2] + new_w = (w*self.size//min(w, h)) + new_h = (h*self.size//min(w, h)) + if new_w != w or new_h != h: + frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) + cv2.imwrite(path.join(self.image_dir, f'frame_{frame_index:06d}.jpg'), frame) + frame_index += 1 + bar.update(frame_index) + bar.finish() + print('Done!') + def resize_mask(self, mask): # mask transform is applied AFTER mapper, so we need to post-process it in eval.py @@ -93,8 +153,14 @@ def resize_mask(self, mask): return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), mode='nearest') - def get_palette(self): - return self.palette + def map_the_colors_back(self, pred_mask: Image.Image): + # https://stackoverflow.com/questions/29433243/convert-image-to-specific-palette-using-pil-without-dithering + # dither=Dither.NONE just in case + return pred_mask.quantize(palette=self.reference_mask, dither=Image.Dither.NONE).convert('RGB') - def __len__(self): - return len(self.frames) \ No newline at end of file + @staticmethod + def collate_fn_identity(x): + if x.mask is not None: + return replace(x, mask=torch.tensor(x.mask)) + else: + return x \ No newline at end of file diff --git a/inference/frame_selection/__init__.py b/inference/frame_selection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py new file mode 100644 index 0000000..90fe345 --- /dev/null +++ b/inference/frame_selection/frame_selection.py @@ -0,0 +1,245 @@ +from copy import copy +from pathlib import Path +import time +from typing import Any, Dict, List, Set, Tuple, Union + +import cv2 +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as FT +import numpy as np +from tqdm import tqdm +from inference.frame_selection.frame_selection_utils import extract_keys +from torchvision.transforms import Resize, InterpolationMode + +from model.memory_util import get_similarity + + +def first_frame_only(*args, **kwargs): + # baseline + return [0] + + +def uniformly_selected_frames(dataloader, *args, how_many_frames=10, **kwargs) -> List[int]: + # baseline + # TODO: debug and check if works + num_total_frames = len(dataloader) + return np.linspace(0, num_total_frames - 1, how_many_frames).astype(int).tolist() + + +def calculate_proposals_for_annotations_with_iterative_distance_cycle_MASKS(dataloader, processor, existing_masks_path: str, how_many_frames=10, print_progress=False, mult_instead=False, alpha=1.0, too_small_mask_threshold_px=9, **kwargs) -> List[int]: + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = 
extract_keys(dataloader, processor, print_progress) + + h, w = frame_keys[0].squeeze().shape[1:3] # removing batch dimension + p_masks_dir = Path(existing_masks_path) + mask_sizes_px = [] + for i, p_img in enumerate(p_masks_dir.iterdir()): + img = cv2.imread(str(p_img)) + img = cv2.resize(img, (w, h)) / 255. + img_tensor = FT.to_tensor(img) + mask_size_px = (img_tensor > 0).sum() + mask_sizes_px.append(mask_size_px) + + if not mult_instead: + composite_key = torch.cat([frame_keys[i].cpu().squeeze(), img_tensor], dim=0) # along channels + else: + composite_key = frame_keys[i].cpu().squeeze() * img_tensor.max(dim=0, keepdim=True).values # all objects -> 1., background -> 0.. Keep 1 channel only + composite_key = composite_key * alpha + frame_keys[i].cpu().squeeze() * (1 - alpha) + frame_keys[i] = composite_key + + chosen_frames = [0] + chosen_frames_mem_keys = [frame_keys[0].to(device)] + + for i in tqdm(range(how_many_frames - 1), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): + dissimilarities = [] + # how to run a loop for lower memory usage + for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): + qk = frame_keys[j].to(device) + + if mask_sizes_px[j] < too_small_mask_threshold_px: + dissimilarity_min_across_all = 0 + else: + dissimilarities_across_mem_keys = [] + for mem_key in chosen_frames_mem_keys: + mem_key = mem_key.to(device) + + similarity_per_pixel = get_similarity( + mem_key, ms=None, qk=qk, qe=None) + reverse_similarity_per_pixel = get_similarity( + qk, ms=None, qk=mem_key, qe=None) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = ( + similarity_per_pixel - reverse_similarity_per_pixel) + + cycle_dissimilarity_score = F.relu(cycle_dissimilarity_per_pixel).sum() / \ + cycle_dissimilarity_per_pixel.numel() + dissimilarities_across_mem_keys.append( + cycle_dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor( + dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append( + frame_keys[chosen_new_frame].to(device)) + # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=0.5, min_mask_presence_percent=0.25, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True, epsilon=0.5): + assert len(keys) == len(masks) + assert len(keys) > 0 + # assert keys[0].shape[-2:] == masks[0].shape[-2:] + assert num_next_candidates > 0 + assert len(previously_chosen_candidates) > 0 + assert 0.0 <= alpha <= 1.0 + assert min_mask_presence_percent >= 0 + assert len(previously_chosen_candidates) < len(keys) + + + """ + Select candidate frames for annotation based on dissimilarity and cycle consistency. 
+ + Parameters + ---------- + keys : torch.Tensor + A list of "key" feature maps for all frames of the video (from XMem key encoder) + shrinkages : [Type] + A list of "shrinkage" feature maps for all frames of the video (from XMem key encoder). Used for similarity computation. + selections : [Type] + A list of "sellection" feature maps for all frames of the video (from XMem key encoder). Used for similarity computation. + masks : List[torch.Tensor] + A list of masks for each frame (predicted or user-provided). + num_next_candidates : int + The number of candidate frames to select. + previously_chosen_candidates : List[int], optional + A list of previously chosen candidate indices. Default is (0,). + print_progress : bool, optional + Whether to print progress information. Default is False. + alpha : float, optional + The weight for the masks in the candidate selection process, [0..1]. If 0 - masks will be ignored, the same frames will be chosen for the same video. If 1.0 - ONLY regions of the frames containing the mask will be compared. Default is 0.5. + If you trust your masks and want object-specific selections, set higher. If your predictions are really bad, set lower + min_mask_presence_percent : float, optional + The minimum percentage of pixels for a valid mask. Default is 0.25. Used to ignore frames with a tiny mask (when heavily occluded or just some random wrong prediction) + device : torch.device, optional + The device to run the computation on. Default is 'cuda:0'. + progress_callback : callable, optional + A callback function for progress updates. Used in GUI for a progress bar. Default is None. + only_new_candidates : bool, optional + Whether to return only the newly selected candidates or include previous as well. Default is True. + epsilon : float, optional + Threshold for foreground/background [0..1]. Default is 0.5 + + Returns + ------- + List[int] + A list of indices of the selected candidate frames. + + Notes + ----- + This function uses a dissimilarity measure and cycle consistency to select candidate frames for the user to annotate. + The dissimilarity measure ensures that the selected frames are as diverse as possible, while the cycle consistency + ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise, on feature map level - so both semantically and spatially). + """ + + with torch.no_grad(): + composite_keys = [] + keys = keys.squeeze() + N = len(keys) + h, w = keys[0].shape[1:3] # removing batch dimension + resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) + masks_validity = np.full(N, True) + + invalid = 0 + for i, mask in enumerate(masks): + mask_3ch = mask if mask.ndim == 3 else mask.unsqueeze(0) + mask_bin = mask_3ch.max(dim=0).values # for multiple objects -> use them as one large mask (simplest solution) + mask_size_px = (mask_bin > epsilon).sum() + + ratio = mask_size_px / mask_bin.numel() * 100 + if ratio < min_mask_presence_percent: # percentages to ratio + if i not in previously_chosen_candidates: + # if it's previously chosen, it's okay, we don't test for their validity + # e.g. 
we select frame #J, because we predicted something for it + # but in reality it's actually empty, so gt=0 + # so next iteration will break + masks_validity[i] = False + composite_keys.append(None) + invalid += 1 + + continue + + # if it's previously chosen, it's okay + # if i in previously_chosen_candidates: + # print(f"{i} previous candidate would be invalid (ratio perc={ratio})") + # raise ValueError(f"Given min_mask_presence_percent={min_mask_presence_percent}, even the previous candidates will be ignored. Reduce the value to avoid the error.") + + mask = resize(mask) + composite_key = keys[i] * mask.max(dim=0, keepdim=True).values # any object -> 1., background -> 0.. Keep 1 channel only + composite_key = composite_key * alpha + keys[i] * (1 - alpha) + + composite_keys.append(composite_key.to(dtype=keys[i].dtype, device=device)) + + print(f"Frames with invalid (empty or too small) masks: {invalid} / {len(masks)}") + chosen_candidates = list(previously_chosen_candidates) + chosen_candidate_keys = [composite_keys[i] for i in chosen_candidates] + + for i in tqdm(range(num_next_candidates), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): + candidate_dissimilarities = [] + for j in tqdm(range(N), desc='Computing similarity to chosen frames', disable=not print_progress): + + if not masks_validity[j]: + # ignore this potential candidate + dissimilarity_min_across_all = 0 + else: + qk = composite_keys[j].to(device).unsqueeze(0) + q_shrinkage = shrinkages[j].to(device).unsqueeze(0) + q_selection = selections[j].to(device).unsqueeze(0) + + dissimilarities_across_mem_keys = [] + for mem_idx, mem_key in zip(chosen_candidates, chosen_candidate_keys): + mem_key = mem_key.unsqueeze(0) + mem_shrinkage = shrinkages[mem_idx].to(device).unsqueeze(0) + mem_selection = selections[mem_idx].to(device).unsqueeze(0) + + similarity_per_pixel = get_similarity(mem_key, ms=mem_shrinkage, qk=qk, qe=q_selection) + reverse_similarity_per_pixel = get_similarity(qk, ms=q_shrinkage, qk=mem_key, qe=mem_selection) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel).to(dtype=torch.float32) + + # Take non-negative mappings, normalize by tensor size + cycle_dissimilarity_score = F.relu(cycle_dissimilarity_per_pixel).sum() / cycle_dissimilarity_per_pixel.numel() + dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) + + # filtering out existing or very similar frames + # if the key has already been used or is very similar to at least one of the chosen candidates + # dissimilarity_min_across_all -> 0 (or close to) + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + + candidate_dissimilarities.append(dissimilarity_min_across_all) + + index = torch.argmax(torch.tensor(candidate_dissimilarities)) + chosen_new_frame = int(index) + + chosen_candidates.append(chosen_new_frame) + chosen_candidate_keys.append(composite_keys[chosen_new_frame]) + + if progress_callback is not None: + progress_callback.emit(i + 1) + + if only_new_candidates: + chosen_candidates = chosen_candidates[len(previously_chosen_candidates):] + return chosen_candidates diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py new file mode 100644 index 0000000..2510fae --- /dev/null +++ b/inference/frame_selection/frame_selection_utils.py @@ -0,0 +1,217 @@ +from 
functools import partial + +import torch +import torchvision.transforms.functional as FT +from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine +from tqdm import tqdm + +from inference.data.video_reader import Sample + + +def extract_keys(dataloder, processor, print_progress=False, flatten=True, **kwargs): + frame_keys = [] + shrinkages = [] + selections = [] + device = None + with torch.no_grad(): # just in case + key_sum = None + + for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): + data: Sample = data + rgb = data.rgb.cuda() + key, shrinkage, selection = processor.encode_frame_key(rgb) + + if key_sum is None: + device = key.device + # to avoid possible overflow + key_sum = torch.zeros_like( + key, device=device, dtype=torch.float64) + + key_sum += key.type(torch.float64) + + if flatten: + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + selection = selection.flatten(start_dim=2) + + frame_keys.append(key.cpu()) + shrinkages.append(shrinkage.cpu()) + selections.append(selection.cpu()) + + num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. + + return frame_keys, shrinkages, selections, device, num_frames, key_sum + + +WhichAugToPick = -1 + + +def get_determenistic_augmentations(img_size=None, mask=None, subset: str = None): + assert subset in {'best_3', 'best_3_with_symmetrical', + 'best_all', 'original_only', 'all'} + + bright = ColorJitter(brightness=(1.5, 1.5)) + dark = ColorJitter(brightness=(0.5, 0.5)) + gray = Grayscale(num_output_channels=3) + reduce_bits = RandomPosterize(bits=3, p=1) + sharp = RandomAdjustSharpness(sharpness_factor=16, p=1) + rotate_right = RandomAffine(degrees=(30, 30)) + blur = partial(FT.gaussian_blur, kernel_size=7) + + if img_size is not None: + h, w = img_size[-2:] + translate_distance = w // 5 + else: + translate_distance = 200 + + translate_right = partial(FT.affine, angle=0, translate=( + translate_distance, 0), scale=1, shear=0) + + zoom_out = partial(FT.affine, angle=0, + translate=(0, 0), scale=0.5, shear=0) + zoom_in = partial(FT.affine, angle=0, translate=(0, 0), scale=1.5, shear=0) + shear_right = partial(FT.affine, angle=0, + translate=(0, 0), scale=1, shear=20) + + identity = torch.nn.Identity() + identity.name = 'identity' + + # if mask is not None: + # if mask.any(): + # min_y, min_x, max_y, max_x = get_bbox_from_mask(mask) + # h, w = mask.shape[-2:] + # crop_mask = partial(FT.resized_crop, top=min_y - 10, left=min_x - 10, + # height=max_y - min_y + 10, width=max_x - min_x + 10, size=(w, h)) + # crop_mask.name = 'crop_mask' + # else: + # crop_mask = identity # if the mask is empty + # else: + crop_mask = None + + bright.name = 'bright' + dark.name = 'dark' + gray.name = 'gray' + reduce_bits.name = 'reduce_bits' + sharp.name = 'sharp' + rotate_right.name = 'rotate_right' + translate_right.name = 'translate_right' + zoom_out.name = 'zoom_out' + zoom_in.name = 'zoom_in' + shear_right.name = 'shear_right' + blur.name = 'blur' + + rotate_left = RandomAffine(degrees=(-30, -30)) + rotate_left.name = 'rotate_left' + + shear_left = partial(FT.affine, angle=0, + translate=(0, 0), scale=1, shear=-20) + shear_left.name = 'shear_left' + + if WhichAugToPick != -1: + return [img_mask_augs_pairs[WhichAugToPick]] + + if subset == 'best_3': + img_mask_augs_pairs = [ + # augs only applied to the image + # (bright, identity), + # (dark, identity), + # (gray, identity), + # (reduce_bits, identity), + # 
(sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (rotate_left, rotate_left), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + # (shear_left, shear_left), + ] + + return img_mask_augs_pairs + elif subset == 'best_3_with_symmetrical': + img_mask_augs_pairs = [ + # augs only applied to the image + # (bright, identity), + # (dark, identity), + # (gray, identity), + # (reduce_bits, identity), + # (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (rotate_left, rotate_left), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + return img_mask_augs_pairs + elif subset == 'best_all': + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + # (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + (rotate_right, rotate_right), + (rotate_left, rotate_left), + # (translate_right, translate_right), + (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + return img_mask_augs_pairs + + elif subset == 'original_only': + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + # (zoom_in, zoom_in), + # (shear_right, shear_right), + ] + else: + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + (rotate_right, rotate_right), + (rotate_left, rotate_left), + (translate_right, translate_right), + (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + if crop_mask is not None: + img_mask_augs_pairs.append((crop_mask, crop_mask)) + + return img_mask_augs_pairs \ No newline at end of file diff --git a/inference/inference_core.py b/inference/inference_core.py index f5459df..ed9190e 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -1,3 +1,6 @@ +from time import perf_counter + +import torch from inference.memory_manager import MemoryManager from model.network import XMem from model.aggregate import aggregate @@ -19,12 +22,20 @@ def __init__(self, network:XMem, config): self.clear_memory() self.all_labels = None - def clear_memory(self): + # warmup + self.network.encode_key(torch.zeros((1, 3, 480, 854), device='cuda:0')) + + def clear_memory(self, keep_permanent=False): self.curr_ti = -1 self.last_mem_ti = 0 if not self.deep_update_sync: self.last_deep_update_ti = -self.deep_update_every - self.memory = MemoryManager(config=self.config) + if keep_permanent: + new_memory = self.memory.copy_perm_mem_only() + else: + new_memory = MemoryManager(config=self.config) + + self.memory = new_memory def update_config(self, config): self.mem_every = config['mem_every'] @@ -39,15 +50,36 @@ def set_all_labels(self, all_labels): # self.all_labels = [l.item() for l in all_labels] 
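`get_determenistic_augmentations` hands back `(image_aug, mask_aug)` pairs so geometric transforms stay synchronized between a frame and its mask, while photometric ones leave the mask untouched (the mask side gets `nn.Identity`). A hedged sketch of consuming those pairs; the tensor layout and value range are assumptions:

```python
import torch

from inference.frame_selection.frame_selection_utils import get_determenistic_augmentations


def augmented_views(image: torch.Tensor, mask: torch.Tensor, subset: str = 'best_3'):
    """Yield (augmented_image, augmented_mask) pairs for one frame.

    Assumes image is a float (3, H, W) tensor in [0, 1] and mask is a
    binary (1, H, W) tensor; geometric augmentations are applied to both,
    photometric ones only to the image.
    """
    pairs = get_determenistic_augmentations(img_size=image.shape, mask=mask, subset=subset)
    for img_aug, mask_aug in pairs:
        yield img_aug(image), mask_aug(mask)
```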
self.all_labels = all_labels - def step(self, image, mask=None, valid_labels=None, end=False): + def encode_frame_key(self, image): + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + + key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, + need_ek=True, + need_sk=True) + + return key, shrinkage, selection + def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False, return_key_and_stuff=False): + # For feedback: + # 1. We run the model as usual + # 2. We get feedback: 2 lists, one with good prediction indices, one with bad + # 3. We force the good frames (+ annotated frames) to stay in working memory forever + # 4. We force the bad frames to never even get added to the working memory + # 5. Rerun with these settings # image: 3*H*W # mask: num_objects*H*W or None self.curr_ti += 1 + image, self.pad = pad_divide_by(image, 16) image = image.unsqueeze(0) # add the batch dimension + if manually_curated_masks: + is_mem_frame = (mask is not None) and (not end) + else: + is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) - is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) - need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) + is_ignore = do_not_add_mask_to_memory # to avoid adding permanent memory frames twice, since they are alredy in the memory + + need_segment = (valid_labels is None) or (len(self.all_labels) != len(valid_labels)) is_deep_update = ( (self.deep_update_sync and is_mem_frame) or # synchronized (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync @@ -56,12 +88,18 @@ def step(self, image, mask=None, valid_labels=None, end=False): key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, need_ek=(self.enable_long_term or need_segment), - need_sk=is_mem_frame) + need_sk=True) multi_scale_features = (f16, f8, f4) + if disable_memory_updates: + is_normal_update = False + is_deep_update = False + is_mem_frame = False + self.curr_ti -= 1 # do not advance the iteration further + # segment the current frame is needed if need_segment: - memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) + memory_readout = self.memory.match_memory(key, selection, disable_usage_updates=disable_memory_updates).unsqueeze(0) hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout, self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False) # remove batch dim @@ -90,18 +128,59 @@ def step(self, image, mask=None, valid_labels=None, end=False): pred_prob_with_bg = aggregate(mask, dim=0) # also create new hidden states - self.memory.create_hidden_state(len(self.all_labels), key) + if not disable_memory_updates: + self.memory.create_hidden_state(len(self.all_labels), key) # save as memory if needed if is_mem_frame: value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update) self.memory.add_memory(key, shrinkage, value, self.all_labels, - selection=selection if self.enable_long_term else None) + selection=selection if self.enable_long_term else None, ignore=is_ignore) + self.last_mem_ti = self.curr_ti if is_deep_update: self.memory.set_hidden(hidden) self.last_deep_update_ti = 
self.curr_ti - - return unpad(pred_prob_with_bg, self.pad) + + res = unpad(pred_prob_with_bg, self.pad) + + if return_key_and_stuff: + return res, key, shrinkage, selection + else: + return res + + def put_to_permanent_memory(self, image, mask, ti=None): + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, + need_ek=True, + need_sk=True) + + mask, _ = pad_divide_by(mask, 16) + + pred_prob_with_bg = aggregate(mask, dim=0) + self.memory.create_hidden_state(len(self.all_labels), key) + + value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), + pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=False) + + is_update = self.memory.frame_already_saved(ti) + # print(ti, f"update={is_update}") + if self.memory.frame_already_saved(ti): + self.memory.update_permanent_memory(ti, key, shrinkage, value, selection=selection if self.enable_long_term else None) + else: + self.memory.add_memory(key, shrinkage, value, self.all_labels, + selection=selection if self.enable_long_term else None, permanent=True, ti=ti) + + # print(self.memory.permanent_work_mem.key.shape) + + return is_update + + def remove_from_permanent_memory(self, frame_idx): + self.memory.remove_from_permanent_memory(frame_idx) + + @property + def permanent_memory_frames(self): + return list(self.memory.frame_id_to_permanent_mem_idx.keys()) \ No newline at end of file diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 039a382..f858c4c 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -15,7 +15,12 @@ import functools import os +from pathlib import Path +import re +from time import perf_counter import cv2 + +from inference.frame_selection.frame_selection import select_next_candidates # fix conflicts between qt5 and cv2 os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH") @@ -24,10 +29,11 @@ from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox, QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog, - QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton) + QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, + QRadioButton, QTabWidget, QDialog, QErrorMessage, QMessageBox, QLineEdit) -from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon -from PyQt5.QtCore import Qt, QTimer +from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon, QRegExpValidator +from PyQt5.QtCore import Qt, QTimer, QThreadPool, QRegExp from model.network import XMem @@ -56,37 +62,57 @@ def __init__(self, net: XMem, self.processor = InferenceCore(net, config) self.processor.set_all_labels(list(range(1, self.num_objects+1))) self.res_man = resource_manager + self.threadpool = QThreadPool() + self.last_opened_directory = str(Path.home()) self.num_frames = len(self.res_man) self.height, self.width = self.res_man.h, self.res_man.w # set window - self.setWindowTitle('XMem Demo') + self.setWindowTitle('XMem++ Demo') self.setGeometry(100, 100, self.width, self.height+100) self.setWindowIcon(QIcon('docs/icon.png')) # some buttons self.play_button = QPushButton('Play Video') + self.play_button.setToolTip("Play/Pause the video") self.play_button.clicked.connect(self.on_play_video) self.commit_button = QPushButton('Commit') + self.commit_button.setToolTip("Finish current interaction with the mask") self.commit_button.clicked.connect(self.on_commit) + self.save_reference_button = 
QPushButton('Save reference') + self.save_reference_button.setToolTip("Save current mask in the permanent memory.\nUsed by the model as a reference ground truth.") + self.save_reference_button.clicked.connect(self.on_save_reference) + self.compute_candidates_button = QPushButton('Compute Annotation candidates') + self.compute_candidates_button.setToolTip("Get next k frames that you should annotate.") + self.compute_candidates_button.clicked.connect(self.on_compute_candidates) + + self.full_run_button = QPushButton('FULL Propagate') + self.full_run_button.setToolTip("Clear the temporary memory, scroll to beginning and predict new masks for all the frames.") + self.full_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='full')) self.forward_run_button = QPushButton('Forward Propagate') - self.forward_run_button.clicked.connect(self.on_forward_propagation) + self.forward_run_button.setToolTip("Predict new masks for all the frames starting with the current one.") + self.forward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='forward')) self.forward_run_button.setMinimumWidth(200) self.backward_run_button = QPushButton('Backward Propagate') - self.backward_run_button.clicked.connect(self.on_backward_propagation) + self.backward_run_button.setToolTip("Predict new masks for all the frames before with the current one.") + self.backward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='backward')) self.backward_run_button.setMinimumWidth(200) - self.reset_button = QPushButton('Reset Frame') + self.reset_button = QPushButton('Delete Mask') + self.reset_button.setToolTip("Delete the mask for the current frames. Cannot be undone!") self.reset_button.clicked.connect(self.on_reset_mask) + self.spacebar = QShortcut(QKeySequence(Qt.Key_Space), self) + self.spacebar.activated.connect(self.pause_propagation) + # LCD self.lcd = QTextEdit() self.lcd.setReadOnly(True) self.lcd.setMaximumHeight(28) - self.lcd.setMaximumWidth(120) + self.lcd.setFixedWidth(120) self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1)) # timeline slider @@ -122,23 +148,26 @@ def __init__(self, net: XMem, self.combo.currentTextChanged.connect(self.set_viz_mode) self.save_visualization_checkbox = QCheckBox(self) + self.save_visualization_checkbox.setChecked(True) self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle) - self.save_visualization_checkbox.setChecked(False) - self.save_visualization = False + self.save_visualization = True # Radio buttons for type of interactions - self.curr_interaction = 'Click' + self.curr_interaction = 'Free' self.interaction_group = QButtonGroup() self.radio_fbrs = QRadioButton('Click') + self.radio_fbrs.setToolTip("Clicks in/out of the current mask. Careful - will delete existing mask!") self.radio_s2m = QRadioButton('Scribble') + self.radio_s2m.setToolTip('Draw a line in/out of the current mask. 
Edits existing masks directly.') self.radio_free = QRadioButton('Free') + self.radio_free.setToolTip('Free drawing') self.interaction_group.addButton(self.radio_fbrs) self.interaction_group.addButton(self.radio_s2m) self.interaction_group.addButton(self.radio_free) self.radio_fbrs.toggled.connect(self.interaction_radio_clicked) self.radio_s2m.toggled.connect(self.interaction_radio_clicked) self.radio_free.toggled.connect(self.interaction_radio_clicked) - self.radio_fbrs.toggle() + self.radio_free.toggle() # Main canvas -> QLabel self.main_canvas = QLabel() @@ -166,10 +195,14 @@ def __init__(self, net: XMem, self.zoom_m_button.clicked.connect(self.on_zoom_minus) # Parameters setting - self.clear_mem_button = QPushButton('Clear memory') + self.clear_mem_button = QPushButton('Clear TEMP and LONG memory') + self.clear_mem_button.setToolTip("Temporary and long-term memory can have features from the previous model run.
" + "If you had errors in the predictions, they might influence new masks.
" + "So for a new model run either clean the memory or just use FULL propagate.") self.clear_mem_button.clicked.connect(self.on_clear_memory) self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size') + self.work_mem_gauge.setToolTip("Temporary and Permanent memory together.") self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size') self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)') self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)') @@ -196,7 +229,12 @@ def __init__(self, net: XMem, # import mask/layer self.import_mask_button = QPushButton('Import mask') + self.import_mask_button.setToolTip("Import an existing .png file with a mask for a current frame.\nReplace existing mask.") self.import_mask_button.clicked.connect(self.on_import_mask) + + self.import_all_masks_button = QPushButton('Import ALL masks') + self.import_all_masks_button.setToolTip("Import a list of mask for some or all frames in the video.\nIf more than 10 are imported, the invididual confirmations will not be shown.") + self.import_all_masks_button.clicked.connect(self.on_import_all_masks) self.import_layer_button = QPushButton('Import layer') self.import_layer_button.clicked.connect(self.on_import_layer) @@ -233,16 +271,44 @@ def __init__(self, net: XMem, navi.addWidget(QLabel('Save overlay during propagation')) navi.addWidget(self.save_visualization_checkbox) navi.addStretch(1) + + # self.test_btn = QPushButton('TEST') + # self.test_btn.clicked.connect(self.TEST) + # navi.addWidget(self.test_btn) + navi.addWidget(self.save_reference_button) + # navi.addWidget(self.compute_candidates_button) navi.addWidget(self.commit_button) + navi.addWidget(self.full_run_button) navi.addWidget(self.forward_run_button) navi.addWidget(self.backward_run_button) # Drawing area, main canvas and minimap + self.color_picker = ColorPicker(self.num_objects, davis_palette) + self.color_picker.clicked.connect(self.hit_number_key) draw_area = QHBoxLayout() + draw_area.addWidget(self.color_picker) draw_area.addWidget(self.main_canvas, 4) + self.tabs = QTabWidget() + self.tabs.setMinimumWidth(500) + self.map_tab = QWidget() + self.references_tab = QWidget() + + references_scroll = QScrollArea() + references_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + references_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + references_scroll.setWidgetResizable(True) + references_scroll.setWidget(self.references_tab) + + self.references_scroll = references_scroll + + self.tabs.addTab(self.map_tab,"Minimap && Stats") + self.tabs.addTab(self.references_scroll, "References && Candidates") + + tabs_layout = QVBoxLayout() + # Minimap area - minimap_area = QVBoxLayout() + minimap_area = QVBoxLayout(self.map_tab) minimap_area.setAlignment(Qt.AlignTop) mini_label = QLabel('Minimap') mini_label.setAlignment(Qt.AlignTop) @@ -272,13 +338,44 @@ def __init__(self, net: XMem, import_area = QHBoxLayout() import_area.setAlignment(Qt.AlignTop) import_area.addWidget(self.import_mask_button) + import_area.addWidget(self.import_all_masks_button) import_area.addWidget(self.import_layer_button) minimap_area.addLayout(import_area) - # console - minimap_area.addWidget(self.console) - - draw_area.addLayout(minimap_area, 1) + chosen_figures_area = QVBoxLayout(self.references_tab) + chosen_figures_area.addWidget(QLabel("SAVED REFERENCES IN PERMANENT MEMORY")) + self.references_collection = 
ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail, delete_image=self.on_remove_reference, name='Reference frames') + chosen_figures_area.addWidget(self.references_collection) + + self.candidates_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail, name='Candidate frames') + chosen_figures_area.addWidget(QLabel("ANNOTATION CANDIDATES")) + chosen_figures_area.addWidget(self.candidates_collection) + + tabs_layout.addWidget(self.tabs) + tabs_layout.addWidget(self.console) + draw_area.addLayout(tabs_layout, 1) + + candidates_area = QVBoxLayout() + self.candidates_min_mask_size_edit = QLineEdit() + self.candidates_min_mask_size_edit.setToolTip("Minimal size a mask should have to be considered, % of the total image size." + "\nIf it's smaller than the value specified, the frame will be ignored." + "\nUsed to filter out \"junk\" frames or frames with very heavy occlusions.") + float_validator = QRegExpValidator(QRegExp(r"^(100(\.0+)?|[1-9]?\d(\.\d+)?|0(\.\d+)?)$")) + self.candidates_min_mask_size_edit.setValidator(float_validator) + self.candidates_min_mask_size_edit.setText("0.25") + self.candidates_k_slider = NamedSlider("k", 1, 20, 1, default=5) + self.candidates_k_slider.setToolTip("How many annotation candidates to select.") + self.candidates_alpha_slider = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01, min_text='Frames', max_text='Masks') + self.candidates_alpha_slider.setToolTip("Target importance." + "
If 0 the candidates will be the same regardless of which object is being segmented." + "
If 1 the only part of the image considered will be the one occupied by the mask.") + candidates_area.addWidget(QLabel("Min mask size, % of the total image size, 0-100")) + candidates_area.addWidget(self.candidates_min_mask_size_edit) + candidates_area.addWidget(QLabel("Candidates calculation hyperparameters")) + candidates_area.addWidget(self.candidates_k_slider) + candidates_area.addWidget(self.candidates_alpha_slider) + candidates_area.addWidget(self.compute_candidates_button) + tabs_layout.addLayout(candidates_area) layout = QVBoxLayout() layout.addLayout(draw_area) @@ -312,6 +409,8 @@ def __init__(self, net: XMem, self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) self.cursur = 0 self.on_showing = None + self.reference_ids = set() + self.candidates_ids = [] # Zoom parameters self.zoom_pixels = 150 @@ -341,10 +440,12 @@ def __init__(self, net: XMem, self.vis_target_objects = [1] # try to load the default overlay self._try_load_layer('./docs/ECCV-logo.png') - + self.load_current_image_mask() self.show_current_frame() self.show() + self.style_new_reference() + self.load_existing_references() self.console_push_text('Initialized.') self.initialized = True @@ -446,6 +547,28 @@ def update_current_image_fast(self): qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format_RGB888) self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation))) + + def load_current_image_thumbnail(self, *args, size=128): + # all this instead of self.main_canvas.pixmap() because it contains the brush as well + viz = get_visualization(self.viz_mode, self.current_image, self.current_mask, + self.overlay_layer, self.vis_target_objects) + + height, width, channel = viz.shape + bytesPerLine = 3 * width + qImg = QImage(viz.data, width, height, bytesPerLine, QImage.Format_RGB888) + curr_pixmap = QPixmap(qImg.scaled(self.main_canvas.size(), + Qt.KeepAspectRatio, Qt.FastTransformation)) + + curr_size = curr_pixmap.size() + h = curr_size.height() + w = curr_size.width() + + if h < w: + thumbnail = curr_pixmap.scaledToHeight(size) + else: + thumbnail = curr_pixmap.scaledToWidth(size) + + return thumbnail def show_current_frame(self, fast=False): # Re-compute overlay and show the image @@ -459,6 +582,25 @@ def show_current_frame(self, fast=False): self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) self.tl_slider.setValue(self.cursur) + if self.cursur in self.reference_ids: + self.style_editing_reference() + else: + self.style_new_reference() + + def style_editing_reference(self): + self.save_reference_button.setText("Update reference") + self.save_reference_button.setStyleSheet('QPushButton {background-color: #E4A11B; font-weight: bold; }') + + def style_new_reference(self): + self.save_reference_button.setText("Save reference") + self.save_reference_button.setStyleSheet('QPushButton {background-color: #14A44D; font-weight: bold;}') + + def load_existing_references(self): + for i in self.res_man.references: + self.scroll_to(i) + self.on_save_reference() + self.scroll_to(0) + def pixel_pos_to_image_pos(self, x, y): # Un-scale and un-pad the label coordinates into image coordinates oh, ow = self.image_size.height(), self.image_size.width() @@ -531,6 +673,11 @@ def tl_slide(self): self.load_current_image_mask() self.show_current_frame() + def scroll_to(self, idx): + assert self.tl_slider.minimum() <= idx <= self.tl_slider.maximum() + self.tl_slider.setValue(idx) + self.tl_slide() + def brush_slide(self): 
self.brush_size = self.brush_slider.value() self.brush_label.setText('Brush size: %d' % self.brush_size) @@ -541,12 +688,40 @@ def brush_slide(self): # Initialization, forget about it pass + def confirm_ready_for_propagation(self): + if len(self.reference_ids) > 0: + return True + + qm = QErrorMessage(self) + qm.setWindowModality(Qt.WindowModality.WindowModal) + qm.showMessage("Save at least 1 reference!") + + return False + + def general_propagation_callback(self, propagation_type: str): + if not self.confirm_ready_for_propagation(): + return + + self.tabs.setCurrentIndex(0) + if propagation_type == 'full': + self.on_full_propagation() + elif propagation_type == 'forward': + self.on_forward_propagation() + elif propagation_type == 'backward': + self.on_backward_propagation() + + def on_full_propagation(self): + self.on_clear_memory() + self.scroll_to(0) + self.on_forward_propagation() + def on_forward_propagation(self): if self.propagating: # acts as a pause button self.propagating = False else: self.propagate_fn = self.on_next_frame + self.full_run_button.setEnabled(False) self.backward_run_button.setEnabled(False) self.forward_run_button.setText('Pause Propagation') self.on_propagation() @@ -557,12 +732,14 @@ def on_backward_propagation(self): self.propagating = False else: self.propagate_fn = self.on_prev_frame + self.full_run_button.setEnabled(False) self.forward_run_button.setEnabled(False) self.backward_run_button.setText('Pause Propagation') self.on_propagation() def on_pause(self): self.propagating = False + self.full_run_button.setEnabled(True) self.forward_run_button.setEnabled(True) self.backward_run_button.setEnabled(True) self.clear_mem_button.setEnabled(True) @@ -576,7 +753,13 @@ def on_propagation(self): self.show_current_frame(fast=True) self.console_push_text('Propagation started.') - self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:]) + is_mask = self.cursur in self.reference_ids + msk = self.current_prob[1:] if self.cursur in self.reference_ids else None + current_prob, key, shrinkage, selection = self.processor.step(self.current_image_torch, msk, return_key_and_stuff=True) + if not is_mask: + self.current_prob = current_prob + self.res_man.add_key_and_stuff_with_mask(self.cursur, key, shrinkage, selection, self.current_prob[1:]) + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) # clear self.interacted_prob = None @@ -590,11 +773,16 @@ def on_propagation(self): self.load_current_image_mask(no_mask=True) self.load_current_torch_image_mask(no_mask=True) - - self.current_prob = self.processor.step(self.current_image_torch) - self.current_mask = torch_prob_to_numpy_mask(self.current_prob) - - self.save_current_mask() + is_mask = self.cursur in self.reference_ids + msk = self.current_prob[1:] if self.cursur in self.reference_ids else None + current_prob, key, shrinkage, selection = self.processor.step(self.current_image_torch, msk, return_key_and_stuff=True) + self.res_man.add_key_and_stuff_with_mask(self.cursur, key, shrinkage, selection, self.current_prob[1:]) + + if not is_mask: + self.current_prob = current_prob + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) + self.save_current_mask() + self.show_current_frame(fast=True) self.update_memory_size() @@ -602,7 +790,7 @@ def on_propagation(self): if self.cursur == 0 or self.cursur == self.num_frames-1: break - + self.propagating = False self.curr_frame_dirty = False self.on_pause() @@ -616,6 +804,84 @@ def on_commit(self): self.complete_interaction() 
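The propagation loop above is what a headless run would reduce to: seed the permanent memory with the annotated references, then `step` through every frame, feeding the reference mask back in whenever the current frame is itself a reference. A sketch under the assumption that frames and reference masks are already prepared as float tensors on the GPU:

```python
import torch

from inference.inference_core import InferenceCore


def propagate(network, config, frames, reference_masks, num_objects):
    # frames: list of (3, H, W) float tensors on the GPU.
    # reference_masks: dict {frame_index: (num_objects, H, W) float tensor}.
    processor = InferenceCore(network, config)
    processor.set_all_labels(list(range(1, num_objects + 1)))

    # References go into permanent memory first; they survive
    # clear_memory(keep_permanent=True) between runs.
    for ti, mask in reference_masks.items():
        processor.put_to_permanent_memory(frames[ti], mask, ti)

    predictions = []
    for ti, frame in enumerate(frames):
        # Feed the ground-truth mask back in on reference frames so they are
        # never overwritten by a prediction.
        msk = reference_masks.get(ti)
        prob = processor.step(frame, msk)
        predictions.append(prob.argmax(dim=0).cpu())  # (H, W) label map
    return predictions
```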
self.update_interacted_mask() + def confirm_ready_for_candidates_selection(self): + if self.res_man.all_masks_present(): + return True + + qm = QErrorMessage(self) + qm.setWindowModality(Qt.WindowModality.WindowModal) + qm.showMessage("Run propagation on all frames first!") + + return False + + def on_compute_candidates(self): + def _update_candidates(candidates_ids): + print(candidates_ids) + for i in self.candidates_ids: + # removing any old candidates left + self.candidates_collection.remove_image(i) + self.candidates_ids = candidates_ids + + prev_pos = self.cursur + for i in self.candidates_ids: + self.scroll_to(i) + self.candidates_collection.add_image(i) + self.scroll_to(prev_pos) + self.tabs.setCurrentIndex(1) + + def _update_progress(i): + candidate_progress.setValue(i) + + if not self.confirm_ready_for_candidates_selection(): + return + + k = self.candidates_k_slider.value() + alpha = self.candidates_alpha_slider.value() + candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) + worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.shrinkages, self.res_man.selections, self.res_man.small_masks, k, self.reference_ids, + print_progress=False, min_mask_presence_percent=float(self.candidates_min_mask_size_edit.text()), alpha=alpha) # Any other args, kwargs are passed to the run function + worker.signals.result.connect(_update_candidates) + worker.signals.progress.connect(_update_progress) + + self.threadpool.start(worker) + + candidate_progress.open() + + def on_save_reference(self): + if self.interaction is not None: + self.on_commit() + current_image_torch, _ = image_to_torch(self.current_image) + current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda() + + msk = current_prob[1:] + a = perf_counter() + is_update = self.processor.put_to_permanent_memory(current_image_torch, msk, self.cursur) + b = perf_counter() + + self.console_push_text(f"Saving took {(b-a)*1000:.2f} ms.") + + if is_update: + self.reference_ids.remove(self.cursur) + self.references_collection.remove_image(self.cursur) + + self.reference_ids.add(self.cursur) + self.references_collection.add_image(self.cursur) + self.res_man.add_reference(self.cursur) + + if self.cursur in self.candidates_ids: + self.candidates_ids.remove(self.cursur) + self.candidates_collection.remove_image(self.cursur) + + self.show_current_frame() + self.tabs.setCurrentIndex(1) + + def on_remove_reference(self, img_idx): + self.processor.remove_from_permanent_memory(img_idx) + self.reference_ids.remove(img_idx) + self.res_man.remove_reference(img_idx) + + self.show_current_frame() + def on_prev_frame(self): # self.tl_slide will trigger on setValue self.cursur = max(0, self.cursur-1) @@ -678,6 +944,7 @@ def hit_number_key(self, number): self.vis_brush(self.last_ex, self.last_ey) self.update_interact_vis() self.show_current_frame() + self.color_picker.select(self.current_object) def clear_brush(self): self.brush_vis_map.fill(0) @@ -821,10 +1088,10 @@ def on_gpu_timer(self): def update_memory_size(self): try: - max_work_elements = self.processor.memory.max_work_elements + max_work_elements = self.processor.memory.max_work_elements + self.processor.memory.permanent_work_mem.size max_long_elements = self.processor.memory.max_long_elements - curr_work_elements = self.processor.memory.work_mem.size + curr_work_elements = self.processor.memory.temporary_work_mem.size + self.processor.memory.permanent_work_mem.size 
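`on_compute_candidates` offloads the selection onto a `QThreadPool` through the `Worker` helper defined later in `gui_utils.py`, which injects a `progress_callback` keyword and reports back through Qt signals. A stripped-down sketch of that pattern with a stand-in task function:

```python
from PyQt5.QtCore import QThreadPool

from inference.interact.gui_utils import Worker


def slow_task(n, progress_callback=None):
    # Stand-in for select_next_candidates: any function that accepts
    # a progress_callback keyword works with Worker.
    total = 0
    for i in range(n):
        total += i
        if progress_callback is not None:
            progress_callback.emit(i + 1)
    return total


def run_in_background(pool: QThreadPool, on_result, on_progress):
    worker = Worker(slow_task, 100)               # extra args/kwargs go to slow_task
    worker.signals.result.connect(on_result)      # fires with the return value
    worker.signals.progress.connect(on_progress)  # fires with each progress int
    pool.start(worker)
```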
curr_long_elements = self.processor.memory.long_mem.size self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}') @@ -860,22 +1127,77 @@ def update_config(self): self.processor.update_config(self.config) def on_clear_memory(self): - self.processor.clear_memory() + self.processor.clear_memory(keep_permanent=True) torch.cuda.empty_cache() self.update_gpu_usage() self.update_memory_size() def _open_file(self, prompt): options = QFileDialog.Options() - file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options) + file_name, _ = QFileDialog.getOpenFileName(self, prompt, self.last_opened_directory, "Image files (*)", options=options) + if file_name: + self.last_opened_directory = str(Path(file_name).parent) return file_name - def on_import_mask(self): - file_name = self._open_file('Mask') - if len(file_name) == 0: - return + def on_import_all_masks(self): + dir_path = QFileDialog.getExistingDirectory() + if dir_path: + self.last_opened_directory = dir_path + + all_correct = True + frame_ids = [] + incorrect_files = [] + pattern = re.compile(r'([0-9]+)') + files_paths = sorted(Path(dir_path).iterdir()) + for p_f in files_paths: + match = pattern.search(p_f.name) + if match: + frame_id = int(match.string[match.start():match.end()]) + frame_ids.append(frame_id) + else: + all_correct = False + incorrect_files.apend(p_f.name) + + + if not all_correct or frame_ids != sorted(frame_ids): + qm = QErrorMessage(self) + qm.setWindowModality(Qt.WindowModality.WindowModal) + broken_file_names = '\n'.join(incorrect_files) + qm.showMessage(f"Files with incorrect names: {broken_file_names}") - mask = self.res_man.read_external_image(file_name, size=(self.height, self.width)) + else: + if len(frame_ids) > 10: + qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") + question = f"There are more than 10 masks to import, so confirmations for each individual one would not be asked. Are you willing to continue?" + ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No) + if ret == qm.Yes: + progress_dialog = QProgressDialog("Importing masks", None, 0, len(frame_ids), self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) + progress_dialog.open() + a = perf_counter() + for i, p_f in zip(frame_ids, files_paths): + # Only showing progress bar to speed up + self.cursur = i + self.on_import_mask(str(p_f), ask_confirmation=False) + progress_dialog.setValue(i + 1) + QApplication.processEvents() + b = perf_counter() + self.console_push_text(f"Importing {len(frame_ids)} masks took {b-a:.2f} seconds ({len(frame_ids)/(b-a):.2f} FPS)") + self.cursur = 0 + + else: + for i, p_f in zip(frame_ids, files_paths): + self.scroll_to(i) + self.on_import_mask(str(p_f), ask_confirmation=True) + + def on_import_mask(self, mask_file_path=None, ask_confirmation=True): + if mask_file_path: + file_name = mask_file_path + else: + file_name = self._open_file('Mask') + if len(file_name) == 0: + return + + mask = self.res_man.read_external_image(file_name, size=(self.height, self.width), force_mask=True) shape_condition = ( (len(mask.shape) == 2) and @@ -892,11 +1214,29 @@ def on_import_mask(self): elif not object_condition: self.console_push_text(f'Expected {self.num_objects} objects. 
Got {mask.max()} objects instead.') else: - self.console_push_text(f'Mask file {file_name} loaded.') - self.current_image_torch = self.current_prob = None - self.current_mask = mask - self.show_current_frame() - self.save_current_mask() + if ask_confirmation: + qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") + question = f"Replace mask for current frame {self.cursur} with {Path(file_name).name}?" + ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No) + + if not ask_confirmation or ret == qm.Yes: + self.console_push_text(f'Mask file {file_name} loaded.') + self.current_image_torch = self.current_prob = None + self.current_mask = mask + + if ask_confirmation: + # for speedup purposes + self.curr_frame_dirty = False + self.reset_this_interaction() + self.show_current_frame() + + self.save_current_mask() + + if ask_confirmation: + # Only save references if it's an individual image or a few (< 10) + # If the user is importing 1000+ masks, the memory is going to explode + self.on_save_reference() + def on_import_layer(self): file_name = self._open_file('Layer') diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index daf852b..c99145c 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -1,5 +1,83 @@ -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar) +from functools import partial +from typing import Optional, Union +import time +import traceback, sys +from PyQt5 import QtCore +from PyQt5.QtGui import QPalette, QColor + +from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject, QPoint, QRect, QSize +from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, + QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, + QFrame, QPushButton, QSlider, QMessageBox, QGridLayout) + +class WorkerSignals(QObject): + ''' + Defines the signals available from a running worker thread. + + Supported signals are: + + finished + No data + + error + tuple (exctype, value, traceback.format_exc() ) + + result + object data returned from processing, anything + + progress + int indicating % progress + + ''' + finished = pyqtSignal() + error = pyqtSignal(tuple) + result = pyqtSignal(object) + progress = pyqtSignal(int) + + +class Worker(QRunnable): + ''' + Worker thread + + Inherits from QRunnable to handler worker thread setup, signals and wrap-up. + + :param callback: The function callback to run on this worker thread. Supplied args and + kwargs will be passed through to the runner. + :type callback: function + :param args: Arguments to pass to the callback function + :param kwargs: Keywords to pass to the callback function + + ''' + + def __init__(self, fn, *args, **kwargs): + super(Worker, self).__init__() + + # Store constructor arguments (re-used for processing) + self.fn = fn + self.args = args + self.kwargs = kwargs + self.signals = WorkerSignals() + + # Add the callback to our kwargs + self.kwargs['progress_callback'] = self.signals.progress + + @pyqtSlot() + def run(self): + ''' + Initialise the runner function with passed args, kwargs. 
+ ''' + + # Retrieve args/kwargs here; and fire processing using them + try: + result = self.fn(*self.args, **self.kwargs) + except: + traceback.print_exc() + exctype, value = sys.exc_info()[:2] + self.signals.error.emit((exctype, value, traceback.format_exc())) + else: + self.signals.result.emit(result) # Return the result of the processing + finally: + self.signals.finished.emit() # Done def create_parameter_box(min_val, max_val, text, step=1, callback=None): @@ -38,3 +116,363 @@ def create_gauge(text): layout.addWidget(gauge) return gauge, layout + + +class FlowLayout(QLayout): + def __init__(self, parent: QWidget=None, margin: int=-1, hSpacing: int=-1, vSpacing: int=-1): + super().__init__(parent) + + self.itemList = list() + self.m_hSpace = hSpacing + self.m_vSpace = vSpacing + + self.setContentsMargins(margin, margin, margin, margin) + + def __del__(self): + # copied for consistency, not sure this is needed or ever called + item = self.takeAt(0) + while item: + item = self.takeAt(0) + + def addItem(self, item: QLayoutItem): + self.itemList.append(item) + + def horizontalSpacing(self) -> int: + if self.m_hSpace >= 0: + return self.m_hSpace + else: + return self.smartSpacing(QStyle.PM_LayoutHorizontalSpacing) + + def verticalSpacing(self) -> int: + if self.m_vSpace >= 0: + return self.m_vSpace + else: + return self.smartSpacing(QStyle.PM_LayoutVerticalSpacing) + + def count(self) -> int: + return len(self.itemList) + + def itemAt(self, index: int) -> Union[QLayoutItem, None]: + if 0 <= index < len(self.itemList): + return self.itemList[index] + else: + return None + + def takeAt(self, index: int) -> Union[QLayoutItem, None]: + if 0 <= index < len(self.itemList): + return self.itemList.pop(index) + else: + return None + + def expandingDirections(self) -> Qt.Orientations: + return Qt.Orientations(Qt.Orientation(0)) + + def hasHeightForWidth(self) -> bool: + return True + + def heightForWidth(self, width: int) -> int: + height = self.doLayout(QRect(0, 0, width, 0), True) + return height + + def setGeometry(self, rect: QRect) -> None: + super().setGeometry(rect) + self.doLayout(rect, False) + + def sizeHint(self) -> QSize: + return self.minimumSize() + + def minimumSize(self) -> QSize: + size = QSize() + for item in self.itemList: + size = size.expandedTo(item.minimumSize()) + + margins = self.contentsMargins() + size += QSize(margins.left() + margins.right(), margins.top() + margins.bottom()) + return size + + def smartSpacing(self, pm: QStyle.PixelMetric) -> int: + parent = self.parent() + if not parent: + return -1 + elif parent.isWidgetType(): + return parent.style().pixelMetric(pm, None, parent) + else: + return parent.spacing() + + def doLayout(self, rect: QRect, testOnly: bool) -> int: + left, top, right, bottom = self.getContentsMargins() + effectiveRect = rect.adjusted(+left, +top, -right, -bottom) + x = effectiveRect.x() + y = effectiveRect.y() + lineHeight = 0 + + for item in self.itemList: + wid = item.widget() + spaceX = self.horizontalSpacing() + if spaceX == -1: + spaceX = wid.style().layoutSpacing(QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Horizontal) + spaceY = self.verticalSpacing() + if spaceY == -1: + spaceY = wid.style().layoutSpacing(QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Vertical) + + nextX = x + item.sizeHint().width() + spaceX + if nextX - spaceX > effectiveRect.right() and lineHeight > 0: + x = effectiveRect.x() + y = y + lineHeight + spaceY + nextX = x + item.sizeHint().width() + spaceX + lineHeight = 0 + + if not testOnly: + 
item.setGeometry(QRect(QPoint(x, y), item.sizeHint())) + + x = nextX + lineHeight = max(lineHeight, item.sizeHint().height()) + + return y + lineHeight - rect.y() + bottom + + +class JFlowLayout(FlowLayout): + # flow layout, similar to an HTML `
` + # this is our "wrapper" to the `FlowLayout` sample Qt code we have implemented + # we use it in place of where we used to use a `QHBoxLayout` + # in order to make few outside-world changes, and revert to `QHBoxLayout`if we ever want to, + # there are a couple of methods here which are available on a `QBoxLayout` but not on a `QLayout` + # for which we provide a "lite-equivalent" which will suffice for our purposes + + def addLayout(self, layout: QLayout, stretch: int=0): + # "equivalent" of `QBoxLayout.addLayout()` + # we want to add sub-layouts (e.g. a `QVBoxLayout` holding a label above a widget) + # there is some dispute as to how to do this/whether it is supported by `FlowLayout` + # see my https://forum.qt.io/topic/104653/how-to-do-a-no-break-qhboxlayout + # there is a suggestion that we should not add a sub-layout but rather enclose it in a `QWidget` + # but since it seems to be working as I've done it below I'm elaving it at that for now... + + # suprisingly to me, we do not need to add the layout via `addChildLayout()`, that seems to make no difference + # self.addChildLayout(layout) + # all that seems to be reuqired is to add it onto the list via `addItem()` + self.addItem(layout) + + def addStretch(self, stretch: int=0): + # "equivalent" of `QBoxLayout.addStretch()` + # we can't do stretches, we just arbitrarily put in a "spacer" to give a bit of a gap + w = stretch * 20 + spacerItem = QSpacerItem(w, 0, QSizePolicy.Expanding, QSizePolicy.Minimum) + self.addItem(spacerItem) + + +class NamedSlider(QWidget): + valueChanged = pyqtSignal(float) + + def __init__(self, name: str, min_: int, max_: int, step_size: int, default: int, multiplier=1, min_text=None, max_text=None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.name = name + self.multiplier = multiplier + self.min_text = min_text + self.max_text = max_text + + layout = QHBoxLayout(self) + self.slider = QSlider(Qt.Horizontal) + self.slider.setMinimum(min_) + self.slider.setMaximum(max_) + self.slider.setValue(default) + self.slider.setTickPosition(QSlider.TicksBelow) + self.slider.setTickInterval(step_size) + + name_label = QLabel(name + " |") + self.value_label = QLabel() + + layout.addWidget(name_label) + layout.addWidget(self.value_label) + layout.addWidget(self.slider) + + self.update_name() + + self.slider.valueChanged.connect(self.on_slide) + + def value(self): + return self.slider.value() * self.multiplier + + def on_slide(self): + self.update_name() + self.valueChanged.emit(self.slider.value() * self.multiplier) + + def update_name(self): + value = self.value() + value_str = None + if self.multiplier != 1: + if isinstance(self.multiplier, float): + min_str = f'{self.slider.minimum() * self.multiplier:.2f}' + value_str = f'{value:.2f}' + max_str = f'{self.slider.maximum() * self.multiplier:.2f}' + + if value_str is None: + min_str = f'{self.slider.minimum() * self.multiplier:d}' + value_str = f'{value:d}' + max_str = f'{self.slider.maximum() * self.multiplier:d}' + + if self.min_text is not None: + min_str += f' ({self.min_text})' + if self.max_text is not None: + max_str += f' ({self.max_text})' + + final_str = f'{min_str} <= {value_str} <= {max_str}' + + self.value_label.setText(final_str) + +class ClickableLabel(QLabel): + clicked = pyqtSignal() + def mouseReleaseEvent(self, event): + super(ClickableLabel, self).mousePressEvent(event) + if event.button() == Qt.LeftButton and event.pos() in self.rect(): + self.clicked.emit() + + +class ImageWithCaption(QWidget): + def __init__(self, img: 
QLabel, caption: str, on_close: callable = None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.layout = QVBoxLayout(self) + self.text_label = QLabel(caption) + self.close_btn = QPushButton("x") + self.close_btn.setMaximumSize(35, 35) + self.close_btn.setMinimumSize(35, 35) + self.close_btn.setStyleSheet('QPushButton {background-color: #DC4C64; font-weight: bold; }') + if on_close is not None: + self.close_btn.clicked.connect(on_close) + + self.top_tab_layout = QHBoxLayout() + self.top_tab_layout.addWidget(self.text_label) + self.top_tab_layout.addWidget(self.close_btn) + self.top_tab_layout.setAlignment(self.text_label, Qt.AlignmentFlag.AlignCenter) + self.top_tab_layout.setAlignment(self.close_btn, Qt.AlignmentFlag.AlignRight) + + self.layout.addLayout(self.top_tab_layout) + + self.layout.addWidget(img) + + self.layout.setAlignment(self.text_label, Qt.AlignmentFlag.AlignHCenter) + +class ImageLinkCollection(QWidget): + def __init__(self, on_click: callable, load_image: callable, delete_image: callable = None, name: str = None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.on_click = on_click + self.load_image = load_image + self.delete_image = delete_image + self.name = name + # scrollable_area = QScrollArea(self) + # frame = QFrame(scrollable_area) + + self.flow_layout = JFlowLayout(self) + + self.img_widgets_lookup = dict() + + + def add_image(self, img_idx): + image = self.load_image(img_idx) + + img_widget = ClickableLabel() + img_widget.setPixmap(image) + + img_widget.clicked.connect(partial(self.on_click, img_idx)) + + wrapper = ImageWithCaption(img_widget, f"Frame {img_idx:>6d}", on_close=partial(self.on_close_click, img_idx)) + # layout.addWidget(img_widget) + + self.img_widgets_lookup[img_idx] = wrapper + self.flow_layout.addWidget(wrapper) + + def remove_image(self, img_idx): + img_widget = self.img_widgets_lookup.pop(img_idx) + self.flow_layout.removeWidget(img_widget) + + def on_close_click(self, img_idx): + qm = QMessageBox(QMessageBox.Icon.Warning, "Confirm deletion", "") + question = f"Delete Frame {img_idx}" + if self.name is not None: + question += f' from {self.name}' + + question += '?' 
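`ImageLinkCollection` drives both the reference and candidate galleries and only needs three callbacks: navigate on click, produce a thumbnail `QPixmap`, and optionally delete. A hedged sketch of wiring one up; the placeholder thumbnail loader is an assumption, and a running `QApplication` is required:

```python
from PyQt5.QtGui import QPixmap, QColor

from inference.interact.gui_utils import ImageLinkCollection


def make_gallery(scroll_to_frame, delete_frame):
    def load_thumbnail(frame_idx, size=128):
        # Stand-in loader: a solid-colour pixmap instead of a rendered frame.
        pixmap = QPixmap(size, size)
        pixmap.fill(QColor('gray'))
        return pixmap

    gallery = ImageLinkCollection(
        on_click=scroll_to_frame,   # e.g. jump the timeline to this frame
        load_image=load_thumbnail,  # must return a QPixmap for the given index
        delete_image=delete_frame,  # called after the user confirms removal
        name='Reference frames',
    )
    gallery.add_image(0)            # show frame 0 in the gallery
    return gallery
```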
+ ret = qm.question(self, 'Confirm deletion', question, qm.Yes | qm.No) + + if ret == qm.Yes: + self.remove_image(img_idx) + if self.delete_image is not None: + self.delete_image(img_idx) + +class ColorPicker(QWidget): + clicked = pyqtSignal(int) + + def __init__(self, num_colors, color_palette: bytes, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_colors = num_colors + + self.outer_layout = QVBoxLayout(self) + self.outer_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + self.inner_layout = QGridLayout() # 2 x N/2 + # self.inner_layout_wrapper = QHBoxLayout() + self.inner_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.palette = color_palette + self.previously_selected = None + + for i in range(self.num_colors): + index = i + 1 + + color_widget = ClickableLabel(str(index)) + + color = self.palette[index * 3: index*3 + 3] + + color_widget.setStyleSheet(f"QLabel {{font-family: Monospace; color:white; font-weight: 900; background-color: rgb{tuple(color)}}} QLabel.selected {{border: 4px solid}}") + color_widget.setAlignment(Qt.AlignmentFlag.AlignCenter) + + color_widget.setFixedSize(40, 40) + self.inner_layout.addWidget(color_widget, int(i / 2), i % 2) + + color_widget.clicked.connect(partial(self._on_color_clicked, index)) + + color_picker_name = QLabel("Object selector") + color_picker_name.setAlignment(Qt.AlignmentFlag.AlignCenter) + color_picker_name.setStyleSheet("QLabel {font-family: Monospace; font-weight: 900}") + + num_objects_label = QLabel(f"({self.num_colors} objects)") + num_objects_label.setStyleSheet("QLabel {font-family: Monospace; font-weight: 900}") + num_objects_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + + + color_picker_instruction = QLabel("Click or use\nnumerical keys") + color_picker_instruction.setStyleSheet("QLabel {font-family: Monospace; font-style: italic}") + color_picker_instruction.setAlignment(Qt.AlignmentFlag.AlignCenter) + + text_wrapper_widget = QWidget() + text_wrapper_widget.setStyleSheet("QWidget {background-color: rgb(225, 225, 225);}") + text_layout = QVBoxLayout(text_wrapper_widget) + text_layout.addWidget(color_picker_name) + text_layout.addWidget(num_objects_label) + text_layout.addWidget(color_picker_instruction) + + self.outer_layout.addWidget(text_wrapper_widget) + self.outer_layout.addLayout(self.inner_layout) + + self.select(1) # First object selected by default + + def _on_color_clicked(self, index: int): + self.clicked.emit(index) + pass + + def select(self, index: int): # 1-based, not 0-based + widget = self.inner_layout.itemAt(index - 1).widget() + widget.setProperty("class", "selected") + widget.style().unpolish(widget) + widget.style().polish(widget) + widget.update() + + # print(widget.text()) + # print(widget.styleSheet()) + + if self.previously_selected is not None: + self.previously_selected.setProperty("class", "") + self.previously_selected.style().unpolish(self.previously_selected) + self.previously_selected.style().polish(self.previously_selected) + self.previously_selected.update() + + self.previously_selected = self.inner_layout.itemAt(index - 1).widget() diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index b0f28af..80680fc 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -1,10 +1,17 @@ +import json import os from os import path +from pathlib import Path import shutil import collections import cv2 from PIL import Image +import torch +from torchvision.transforms import Resize, 
InterpolationMode + +from util.image_loader import PaletteConverter + if not hasattr(Image, 'Resampling'): # Pillow<9.0 Image.Resampling = Image import numpy as np @@ -44,11 +51,17 @@ def __init__(self, config): self.workspace = config['workspace'] self.size = config['size'] self.palette = davis_palette + self.palette_converter = PaletteConverter(self.palette) # create temporary workspace if not specified if self.workspace is None: if images is not None: - basename = path.basename(images) + p_images = Path(images) + if p_images.name == 'JPEGImages' or (Path.cwd() / 'workspace') in p_images.parents: + # take the name instead of actual images dir (second case checks for videos already in ./workspace ) + basename = p_images.parent.name + else: + basename = p_images.name elif video is not None: basename = path.basename(video)[:-4] else: @@ -58,6 +71,16 @@ def __init__(self, config): self.workspace = path.join('./workspace', basename) print(f'Workspace is in: {self.workspace}') + self.workspace_info_file = path.join(self.workspace, 'info.json') + self.references = set() + self._num_objects = None + self._try_load_info() + + if config['num_objects'] is not None: # forced overwrite from user + self._num_objects = config['num_objects'] + elif self._num_objects is None: # both are None, single object first run use case + self._num_objects = config['num_objects_default_value'] + self._save_info() # determine the location of input images need_decoding = False @@ -100,6 +123,13 @@ def __init__(self, config): self.height, self.width = self.get_image(0).shape[:2] self.visualization_init = False + self._resize = None + self._masks = None + self._keys = None + self._keys_processed = np.zeros(self.length, dtype=bool) + self.key_h = None + self.key_w = None + def _extract_frames(self, video): cap = cv2.VideoCapture(video) frame_index = 0 @@ -115,7 +145,7 @@ def _extract_frames(self, video): new_h = (h*self.size//min(w, h)) if new_w != w or new_h != h: frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) - cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame) + cv2.imwrite(path.join(self.image_dir, f'frame_{frame_index:06d}.jpg'), frame) frame_index += 1 bar.update(frame_index) bar.finish() @@ -138,6 +168,61 @@ def _copy_resize_frames(self, images): cv2.imwrite(path.join(self.image_dir, image_name), frame) print('Done!') + def add_key_and_stuff_with_mask(self, ti, key, shrinkage, selection, mask): + if self._keys is None: + c, h, w = key.squeeze().shape + if self.key_h is None: + self.key_h = h + if self.key_w is None: + self.key_w = w + c_mask, h_mask, w_mask = mask.shape + self._keys = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device) + self._shrinkages = torch.empty((self.length, 1, h, w), dtype=key.dtype, device=key.device) + self._selections = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device) + self._masks = torch.empty((self.length, c_mask, h_mask, w_mask), dtype=mask.dtype, device=key.device) + # self._resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) + + if not self._keys_processed[ti]: + # keys don't change for the video, so we only save them once + self._keys[ti] = key + self._shrinkages[ti] = shrinkage + self._selections[ti] = selection + self._keys_processed[ti] = True + + self._masks[ti] = mask# self._resize(mask) + + def all_masks_present(self): + return self._keys_processed.sum() == self.length + + def add_reference(self, frame_id: int): + self.references.add(frame_id) + self._save_info() + + def 
remove_reference(self, frame_id: int): + print(self.references) + self.references.remove(frame_id) + self._save_info() + + def _save_info(self): + p_workspace_subdir = Path(self.workspace_info_file).parent + p_workspace_subdir.mkdir(parents=True, exist_ok=True) + with open(self.workspace_info_file, 'wt') as f: + data = {'references': sorted(self.references), 'num_objects': self._num_objects} + + json.dump(data, f, indent=4) + + def _try_load_info(self): + try: + with open(self.workspace_info_file) as f: + data = json.load(f) + self._num_objects = data['num_objects'] + + # We might have num_objects, but not references if imported the project + self.references = set(data['references']) + except Exception: + pass + + def save_mask(self, ti, mask): # mask should be uint8 H*W without channels assert 0 <= ti < self.length @@ -180,13 +265,36 @@ def _get_mask_unbuffered(self, ti): else: return None - def read_external_image(self, file_name, size=None): + def read_external_image(self, file_name, size=None, force_mask=False): image = Image.open(file_name) is_mask = image.mode in ['L', 'P'] + if size is not None: # PIL uses (width, height) image = image.resize((size[1], size[0]), - resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC) + resample=Image.Resampling.NEAREST if is_mask or force_mask else Image.Resampling.BICUBIC) + + if force_mask and image.mode != 'P': + image = self.palette_converter.image_to_index_mask(image) + # if image.mode in ['RGB', 'L'] and len(image.getcolors()) <= 2: + # image = np.array(image.convert('L')) + # # hardcoded for b&w images + # image = np.where(image, 1, 0) # 255 (or whatever) -> binarize + + # return image.astype('uint8') + # elif image.mode == 'RGB': + # image = image.convert('P', palette=self.palette) + # tmp_image = np.array(image) + # out_image = np.zeros_like(tmp_image) + # for i, c in enumerate(np.unique(tmp_image)): + # if i == 0: + # continue + # out_image[tmp_image == c] = i # palette indices into 0, 1, 2, ... + # self.palette = image.getpalette() + # return out_image + + # image = image.convert('P', palette=self.palette) # saved without DAVIS palette, just number objects 0, 1, ... 
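For reference, the workspace `info.json` written by `_save_info` above and restored by `_try_load_info` has the following shape. A minimal round-trip sketch — the field names come from the code, while the path and the values are illustrative examples only:

import json
from pathlib import Path

# Example workspace path; the real file lives at ./workspace/<video basename>/info.json
info_file = Path('./workspace/example_video/info.json')
info_file.parent.mkdir(parents=True, exist_ok=True)

# What _save_info persists: user-annotated reference frames and the object count
with open(info_file, 'wt') as f:
    json.dump({'references': [0, 14, 66], 'num_objects': 2}, f, indent=4)

# What _try_load_info restores on the next launch of the same workspace
with open(info_file) as f:
    data = json.load(f)
references, num_objects = set(data['references']), data['num_objects']
print(references, num_objects)  # {0, 14, 66} 2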
+ image = np.array(image) return image @@ -204,3 +312,24 @@ def h(self): @property def w(self): return self.width + + @property + def small_masks(self): + return self._masks + + @property + def keys(self): + return self._keys + + + @property + def shrinkages(self): + return self._shrinkages + + @property + def selections(self): + return self._selections + + @property + def num_objects(self): + return self._num_objects diff --git a/inference/kv_memory_store.py b/inference/kv_memory_store.py index 33a3326..31b62e9 100644 --- a/inference/kv_memory_store.py +++ b/inference/kv_memory_store.py @@ -89,6 +89,10 @@ def add(self, key, value, shrinkage, selection, objects: List[int]): else: self.v.append(gv) + pos = int((self.k.shape[-1] + 1e-9) // (key.shape[-1] + 1e-9)) - 1 # index of newly added frame + + return pos + def update_usage(self, usage): # increase all life count by 1 # increase use of indexed elements @@ -98,6 +102,26 @@ def update_usage(self, usage): self.use_count += usage.view_as(self.use_count) self.life_count += 1 + def replace_at(self, start_pos: int, key, value, shrinkage=None, selection=None): + start = start_pos * key.shape[-1] + end = (start_pos + 1) * key.shape[-1] + + self.k[:,:,start:end] = key + + for gi in range(self.num_groups): + self.v[gi][:, :, start:end] = value[gi] + + if self.s is not None and shrinkage is not None: + self.s[:, :, start:end] = shrinkage + + if self.e is not None and selection is not None: + self.e[:, :, start:end] = selection + + def remove_at(self, start: int, elem_size: int): + end = start + elem_size + + self.sieve_by_range(start, end, min_size=0) # remove the value irrespective of its size + def sieve_by_range(self, start: int, end: int, min_size: int): # keep only the elements *outside* of this range (with some boundary conditions) # i.e., concat (a[:start], a[end:]) @@ -105,6 +129,7 @@ def sieve_by_range(self, start: int, end: int, min_size: int): # (because they are not consolidated) if end == 0: + # just sieves till the `start` # negative 0 would not work as the end index! 
self.k = self.k[:,:,:start] if self.count_usage: diff --git a/inference/memory_manager.py b/inference/memory_manager.py index bce2c00..93f4563 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -9,15 +9,17 @@ class MemoryManager: """ Manages all three memory stores and the transition between working/long-term memory """ + def __init__(self, config): + self.config = config self.hidden_dim = config['hidden_dim'] self.top_k = config['top_k'] self.enable_long_term = config['enable_long_term'] self.enable_long_term_usage = config['enable_long_term_count_usage'] if self.enable_long_term: - self.max_mt_frames = config['max_mid_term_frames'] - self.min_mt_frames = config['min_mid_term_frames'] + self.max_mt_frames = config['max_mid_term_frames'] # maximum work memory size + self.min_mt_frames = config['min_mid_term_frames'] # minimum number of frames to keep in work memory when consolidating self.num_prototypes = config['num_prototypes'] self.max_long_elements = config['max_long_term_elements'] @@ -29,7 +31,9 @@ def __init__(self, config): # B x num_objects x CH x H x W self.hidden = None - self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) + self.temporary_work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) + self.permanent_work_mem = KeyValueMemoryStore(count_usage=False) + self.frame_id_to_permanent_mem_idx = dict() if self.enable_long_term: self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage) @@ -54,10 +58,14 @@ def _readout(self, affinity, v): # this function is for a single object group return v @ affinity - def match_memory(self, query_key, selection): + def match_memory(self, query_key, selection, disable_usage_updates=False): # query_key: B x C^k x H x W # selection: B x C^k x H x W - num_groups = self.work_mem.num_groups + # TODO: keep groups in both..? 
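The hunks above and below replace the single working memory with a `temporary_work_mem`/`permanent_work_mem` pair plus a `frame_id_to_permanent_mem_idx` lookup, so user-annotated frames are pinned and never evicted while readout sees both stores. A minimal, self-contained sketch of that idea — the class and names below are hypothetical stand-ins, the shapes are arbitrary examples, and this is not the project's API:

import torch

class TwoStoreWorkingMemory:
    # Illustration only: annotated frames are pinned in a permanent store,
    # regular frames go to a temporary store that may later be consolidated away.
    def __init__(self):
        self.temporary = []          # list of 1 x C x (H*W) key tensors
        self.permanent = []
        self.frame_id_to_slot = {}   # frame index -> slot in the permanent store

    def add(self, frame_id, key, permanent=False):
        if permanent:
            self.permanent.append(key)
            self.frame_id_to_slot[frame_id] = len(self.permanent) - 1
        else:
            self.temporary.append(key)

    def replace_permanent(self, frame_id, key):
        # a re-annotated frame overwrites its old features in place (cf. replace_at)
        self.permanent[self.frame_id_to_slot[frame_id]] = key

    def readout_keys(self):
        # queries always see both stores, concatenated along the flattened spatial dim
        return torch.cat(self.temporary + self.permanent, dim=-1)

mem = TwoStoreWorkingMemory()
mem.add(0, torch.randn(1, 64, 30 * 54), permanent=True)  # frame with a ground-truth mask
mem.add(1, torch.randn(1, 64, 30 * 54))                   # ordinary frame
print(mem.readout_keys().shape)                            # torch.Size([1, 64, 3240])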
+ # 1x64x30x54 + + # = permanent_work_mem.num_groups, since it's always >= temporary_work_mem.num_groups + num_groups = max(self.temporary_work_mem.num_groups, self.permanent_work_mem.num_groups) h, w = query_key.shape[-2:] query_key = query_key.flatten(start_dim=2) @@ -67,79 +75,112 @@ def match_memory(self, query_key, selection): Memory readout using keys """ + temp_work_mem_size = self.temporary_work_mem.size if self.enable_long_term and self.long_mem.engaged(): # Use long-term memory long_mem_size = self.long_mem.size - memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1) - shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1) + + memory_key = torch.cat([self.long_mem.key, self.temporary_work_mem.key, self.permanent_work_mem.key], -1) + shrinkage = torch.cat([self.long_mem.shrinkage, self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) similarity = get_similarity(memory_key, shrinkage, query_key, selection) - work_mem_similarity = similarity[:, long_mem_size:] + long_mem_similarity = similarity[:, :long_mem_size] + temp_work_mem_similarity = similarity[:, long_mem_size:long_mem_size+temp_work_mem_size] + perm_work_mem_similarity = similarity[:, long_mem_size+temp_work_mem_size:] # get the usage with the first group # the first group always have all the keys valid affinity, usage = do_softmax( - torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1), - top_k=self.top_k, inplace=True, return_usage=True) + torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], temp_work_mem_similarity, perm_work_mem_similarity], 1), + top_k=self.top_k, inplace=True, return_usage=True) affinity = [affinity] # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): + temp_group_v_size = self.temporary_work_mem.get_v_size(gi) + perm_group_v_size = self.permanent_work_mem.get_v_size(gi) + temp_sim_size = temp_work_mem_similarity.shape[1] + perm_sim_size = perm_work_mem_similarity.shape[1] + if gi < self.long_mem.num_groups: # merge working and lt similarities before softmax affinity_one_group = do_softmax( - torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], - work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1), + torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], + temp_work_mem_similarity[:, temp_sim_size-temp_group_v_size:], + perm_work_mem_similarity[:, perm_sim_size-perm_group_v_size:]], + dim=1), top_k=self.top_k, inplace=True) else: # no long-term memory for this group - affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):], - top_k=self.top_k, inplace=(gi==num_groups-1)) + affinity_one_group = do_softmax(torch.cat([ + temp_work_mem_similarity[:, temp_sim_size-temp_group_v_size:], + perm_work_mem_similarity[:, perm_sim_size-perm_group_v_size:]], + 1), + top_k=self.top_k, inplace=(gi == num_groups-1)) affinity.append(affinity_one_group) all_memory_value = [] - for gi, gv in enumerate(self.work_mem.value): + for gi in range(num_groups): # merge the working and lt values before readout if gi < self.long_mem.num_groups: - all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1)) + all_memory_value.append(torch.cat([self.long_mem.value[gi], self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1)) else: - all_memory_value.append(gv) + all_memory_value.append(torch.cat([self.temporary_work_mem.value[gi], 
self.permanent_work_mem.value[gi]], -1)) """ Record memory usage for working and long-term memory """ + if not disable_usage_updates: # ignore the index return for long-term memory - work_usage = usage[:, long_mem_size:] - self.work_mem.update_usage(work_usage.flatten()) + work_usage = usage[:, long_mem_size:long_mem_size+temp_work_mem_size] # no usage for permanent memory + self.temporary_work_mem.update_usage(work_usage.flatten()) - if self.enable_long_term_usage: - # ignore the index return for working memory - long_usage = usage[:, :long_mem_size] - self.long_mem.update_usage(long_usage.flatten()) + if self.enable_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_usage(long_usage.flatten()) else: + memory_key = torch.cat([self.temporary_work_mem.key, self.permanent_work_mem.key], -1) + shrinkage = torch.cat([self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) # No long-term memory - similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection) + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + temp_work_mem_similarity = similarity[:, :temp_work_mem_size] + perm_work_mem_similarity = similarity[:, temp_work_mem_size:] if self.enable_long_term: - affinity, usage = do_softmax(similarity, inplace=(num_groups==1), - top_k=self.top_k, return_usage=True) - - # Record memory usage for working memory - self.work_mem.update_usage(usage.flatten()) + affinity, usage = do_softmax(similarity, inplace=(num_groups == 1), + top_k=self.top_k, return_usage=True) + if not disable_usage_updates: + # Record memory usage for working memory + self.temporary_work_mem.update_usage(usage[:, :temp_work_mem_size].flatten()) else: - affinity = do_softmax(similarity, inplace=(num_groups==1), - top_k=self.top_k, return_usage=False) + affinity = do_softmax(similarity, inplace=(num_groups == 1), + top_k=self.top_k, return_usage=False) affinity = [affinity] # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): - affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):], - top_k=self.top_k, inplace=(gi==num_groups-1)) + temp_group_v_size = self.temporary_work_mem.get_v_size(gi) + perm_group_v_size = self.permanent_work_mem.get_v_size(gi) + temp_sim_size = temp_work_mem_similarity.shape[1] + perm_sim_size = perm_work_mem_similarity.shape[1] + + affinity_one_group = do_softmax( + torch.cat([ + # concats empty tensor if the group is also empty for temporary memory + temp_work_mem_similarity[:, temp_sim_size-temp_group_v_size:], + perm_work_mem_similarity[:, perm_sim_size-perm_group_v_size:], + ], dim=1), + top_k=self.top_k, inplace=(gi == num_groups-1) + ) affinity.append(affinity_one_group) - - all_memory_value = self.work_mem.value + + all_memory_value = [] + for gi in range(num_groups): + group_v_cat = torch.cat([self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1) + all_memory_value.append(group_v_cat) # Shared affinity within each group all_readout_mem = torch.cat([ @@ -149,7 +190,27 @@ def match_memory(self, query_key, selection): return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w) - def add_memory(self, key, shrinkage, value, objects, selection=None): + def update_permanent_memory(self, frame_idx, key, shrinkage, value, selection=None): + saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx] + + key = key.flatten(start_dim=2) + shrinkage = 
shrinkage.flatten(start_dim=2) + value = value[0].flatten(start_dim=2) + + if selection is not None: + selection = selection.flatten(start_dim=2) + + self.permanent_work_mem.replace_at(saved_pos, key, value, shrinkage, selection) + + def remove_from_permanent_memory(self, frame_idx): + elem_size = self.HW + saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx] + + self.permanent_work_mem.remove_at(saved_pos, elem_size) + + del self.frame_id_to_permanent_mem_idx[frame_idx] + + def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=False, ignore=False, ti=None): # key: 1*C*H*W # value: 1*num_objects*C*H*W # objects contain a list of object indices @@ -165,7 +226,7 @@ def add_memory(self, key, shrinkage, value, objects, selection=None): # key: 1*C*N # value: num_objects*C*N key = key.flatten(start_dim=2) - shrinkage = shrinkage.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) value = value[0].flatten(start_dim=2) self.CK = key.shape[1] @@ -176,18 +237,49 @@ def add_memory(self, key, shrinkage, value, objects, selection=None): warnings.warn('the selection factor is only needed in long-term mode', UserWarning) selection = selection.flatten(start_dim=2) - self.work_mem.add(key, value, shrinkage, selection, objects) - + if ignore: + pass # all permanent frames are pre-placed into permanent memory (when using our memory modification) + # also ignores the first frame (#0) when using original memory mechanism, since it's already in the permanent memory + elif permanent: + pos = self.permanent_work_mem.add(key, value, shrinkage, selection, objects) + if ti is not None: + self.frame_id_to_permanent_mem_idx[ti] = pos + else: + self.temporary_work_mem.add(key, value, shrinkage, selection, objects) + + + num_temp_groups = self.temporary_work_mem.num_groups + num_perm_groups = self.permanent_work_mem.num_groups + + if not self.temporary_work_mem.engaged() or (num_temp_groups != num_perm_groups): + # print(f"PERM_NUM_GROUPS={num_perm_groups} vs TEMP_NUM_GROUPS={num_temp_groups}", end=' ') + + # first frame or new group; we need to have both memories engaged to avoid crashes when concating + # so we just initialize the temporary one with an empty tensor + key0 = key[..., 0:0] + value0 = value[..., 0:0] + shrinkage0 = shrinkage[..., 0:0] + selection0 = selection[..., 0:0] + if num_perm_groups > num_temp_groups: + # for preloading into permanent memory + self.temporary_work_mem.add(key0, value0, shrinkage0, selection0, objects) + else: + # for original memory mechanism + self.permanent_work_mem.add(key0, value0, shrinkage0, selection0, objects) + + # print(f"AFTER->PERM_NUM_GROUPS={self.permanent_work_mem.num_groups} vs TEMP_NUM_GROUPS={self.temporary_work_mem.num_groups}") + # long-term memory cleanup if self.enable_long_term: # Do memory compressed if needed - if self.work_mem.size >= self.max_work_elements: + if self.temporary_work_mem.size >= self.max_work_elements: + # if we have more then N elements in the work memory # Remove obsolete features if needed if self.long_mem.size >= (self.max_long_elements-self.num_prototypes): self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes) - - self.compress_features() + # We NEVER remove anything from the working memory + self.compress_features() def create_hidden_state(self, n, sample_key): # n is the TOTAL number of objects @@ -196,46 +288,61 @@ def create_hidden_state(self, n, sample_key): self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device) elif 
self.hidden.shape[1] != n: self.hidden = torch.cat([ - self.hidden, + self.hidden, torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device) ], 1) - assert(self.hidden.shape[1] == n) + assert (self.hidden.shape[1] == n) def set_hidden(self, hidden): self.hidden = hidden def get_hidden(self): return self.hidden + + def frame_already_saved(self, ti): + return ti in self.frame_id_to_permanent_mem_idx + + # def slices_excluding_permanent(self, group_value, start, end): + # HW = self.HW + # group_value[:,:,HW:-self.min_work_elements+HW] + + # slices = [] + + # # this won't work because after just 1 consolidation all permanent frames are going to be god know where + # # and their indices would mean nothing + # # How about have 2 separate tensors and concatenate them just for memory reading? + # all_indices = torch.arange(self.temporary_work_mem.size // HW) # all frames indices from 0 to ... def compress_features(self): HW = self.HW candidate_value = [] - total_work_mem_size = self.work_mem.size - for gv in self.work_mem.value: + total_work_mem_size = self.temporary_work_mem.size + for gv in self.temporary_work_mem.value: # Some object groups might be added later in the video # So not all keys have values associated with all objects # We need to keep track of the key->value validity mem_size_in_this_group = gv.shape[-1] if mem_size_in_this_group == total_work_mem_size: # full LT - candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) + candidate_value.append(gv[:, :, :-self.min_work_elements]) else: # mem_size is smaller than total_work_mem_size, but at least HW assert HW <= mem_size_in_this_group < total_work_mem_size - if mem_size_in_this_group > self.min_work_elements+HW: + if mem_size_in_this_group > self.min_work_elements: # part of this object group still goes into LT - candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) + candidate_value.append(gv[:, :, :-self.min_work_elements]) else: # this object group cannot go to the LT at all candidate_value.append(None) # perform memory consolidation + # now starts at zero, because the 1st frame is going into permanent memory prototype_key, prototype_value, prototype_shrinkage = self.consolidation( - *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value) + *self.temporary_work_mem.get_all_sliced(0, -self.min_work_elements), candidate_value) # remove consolidated working memory - self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW) + self.temporary_work_mem.sieve_by_range(0, -self.min_work_elements, min_size=self.min_work_elements+HW) # add to long-term memory self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None) @@ -282,3 +389,39 @@ def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None return prototype_key, prototype_value, prototype_shrinkage + + def copy_perm_mem_only(self): + new_mem = MemoryManager(config=self.config) + + if self.permanent_work_mem.key is None or self.permanent_work_mem.key.size(-1) == 0: + return new_mem + + new_mem.permanent_work_mem = self.permanent_work_mem + new_mem.frame_id_to_permanent_mem_idx = self.frame_id_to_permanent_mem_idx + + key0 = self.permanent_work_mem.key[..., 0:0] + value0 = self.permanent_work_mem.value[0][..., 0:0] + shrinkage0 = self.permanent_work_mem.shrinkage[..., 0:0] if 
self.permanent_work_mem.shrinkage is not None else None + selection0 = self.permanent_work_mem.selection[..., 0:0] if self.permanent_work_mem.selection is not None else None + + new_mem.temporary_work_mem.add(key0, value0, shrinkage0, selection0, self.permanent_work_mem.all_objects) + + new_mem.CK = self.permanent_work_mem.key.shape[1] + new_mem.CV = self.permanent_work_mem.value[0].shape[1] + + key_shape = self.permanent_work_mem.key.shape + sample_key = self.permanent_work_mem.key[..., 0:self.HW].view(*key_shape[:-1], self.H, self.W) + new_mem.create_hidden_state(len(self.permanent_work_mem.all_objects), sample_key) + + new_mem.temporary_work_mem.obj_groups = self.temporary_work_mem.obj_groups + new_mem.temporary_work_mem.all_objects = self.temporary_work_mem.all_objects + + + new_mem.CK = self.CK + new_mem.CV = self.CV + new_mem.H = self.H + new_mem.W = self.W + new_mem.HW = self.HW + + return new_mem + diff --git a/inference/run_experiments.py b/inference/run_experiments.py new file mode 100644 index 0000000..c25961d --- /dev/null +++ b/inference/run_experiments.py @@ -0,0 +1,453 @@ +import os +import json +from pathlib import Path +from typing import Any, Dict, List, Set, Tuple, Union + +import cv2 +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm +from matplotlib import pyplot as plt +from PIL import Image +from inference.frame_selection.frame_selection import uniformly_selected_frames +from util.metrics import batched_f_measure, batched_jaccard +from p_tqdm import p_umap + +# from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS +from inference.run_on_video import predict_annotation_candidates, run_on_video + +# ---------------BEGIN Inference and visualization utils -------------------------- + +def make_non_uniform_grid(rows_of_image_paths: List[List[str]], output_path: str, grid_size=3, resize_to: Tuple[int, int]=(854, 480)): + assert len(rows_of_image_paths) == grid_size + for row in rows_of_image_paths: + assert len(row) <= grid_size + + p_out_dir = Path(output_path) + if not p_out_dir.exists(): + p_out_dir.mkdir(parents=True) + num_frames = None + + for row in rows_of_image_paths: + for img_path_dir in row: + num_frames_in_dir = len(os.listdir(img_path_dir)) + if num_frames is None: + num_frames = num_frames_in_dir + else: + assert num_frames == num_frames_in_dir + + rows_of_iterators = [] + for row_of_image_dir_paths in rows_of_image_paths: + row = [] + for image_dir_path in row_of_image_dir_paths: + p = Path(image_dir_path) + iterator = iter(sorted(p.iterdir())) + row.append(iterator) + rows_of_iterators.append(row) + + for i in tqdm(range(num_frames)): + rows_of_frames = [] + for row in rows_of_iterators: + frames = [] + global_h, global_w = None, None + for iterator in row: + frame_path = str(next(iterator)) + frame = cv2.imread(frame_path) + h, w = frame.shape[0:2] + + if resize_to is not None: + desired_w, desired_h = resize_to + if h != desired_w or w != desired_w: + frame = cv2.resize(frame, (desired_w, desired_h)) + h, w = frame.shape[0:2] + + frames.append(frame) + + if global_h is None: + global_h, global_w = h, w + + wide_frame = np.concatenate(frames, axis=1) + + if len(frames) < grid_size: + pad_size = global_w * (grid_size - len(frames)) // 2 + # center the frame + wide_frame = np.pad(wide_frame, [(0, 0), (pad_size, pad_size), (0, 0)], mode='constant', constant_values=0) + rows_of_frames.append(wide_frame) + + big_frame = np.concatenate(rows_of_frames, axis=0) + cv2.imwrite(str(p_out_dir / 
f'frame_{i:06d}.png'), big_frame) + + +def visualize_grid(video_names: List[str], labeled=True): + for video_name in video_names: + p_in_general = Path( + f'/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/{video_name}/Overlay') + if labeled: + p_in_general /= 'Labeled' + + cycle = p_in_general / 'INTERNAL_CYCLE_CONSISTENCY' + ddiff = p_in_general / 'INTERNAL_DOUBLE_DIFF' + umap = p_in_general / 'UMAP_EUCLIDEAN' + pca_euclidean = p_in_general / 'PCA_EUCLIDEAN' + pca_cosine = p_in_general / 'PCA_COSINE' + one_frame_only = p_in_general / 'ONLY_ONE_FRAME' + baseline_uniform = p_in_general / 'BASELINE_UNIFORM' + baseline_human = p_in_general / 'HUMAN_CHOSEN' + ULTIMATE = p_in_general / 'ULTIMATE_AUTO' + + grid = [ + [cycle, ddiff, umap], + [pca_euclidean, pca_cosine, baseline_uniform], + [baseline_human, one_frame_only, ULTIMATE] + ] + if labeled: + p_out = p_in_general.parent.parent / 'All_combined' + else: + p_out = p_in_general.parent / 'All_combined_unlabeled' + + make_non_uniform_grid(grid, p_out, grid_size=3) + + +def get_videos_info(): + return { + 'long_scene': { + 'num_annotation_candidates': 3, # 3, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/long_scene/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/long_scene/Annotations', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/long_scene' + }, + 'long_scene_scale': { + 'num_annotation_candidates': 3, # 3, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/long_scene_scale/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/long_scene_scale/Annotations', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/long_scene_scale' + }, + 'ariana_smile': { + 'num_annotation_candidates': 3, # 3, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/smile/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/smile/Annotations/Lips', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/ariana_smile' + }, + 'ariana_blog': { + 'num_annotation_candidates': 5, # 5, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/blog/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/blog/Annotations/Together', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/ariana_blog' + }, + } + + +def run_multiple_frame_selectors(videos_info: Dict[str, Dict], csv_output_path: str, predictors: Dict[str, callable] = None, load_existing_masks=False): + output = pd.DataFrame(columns=list(predictors)) + p_bar = tqdm(total=len(videos_info) * len(predictors)) + + exceptions = pd.DataFrame(columns=['video', 'method', 'error_message']) + + for video_name, info in videos_info.items(): + video_frames_path = info['video_frames_path'] + num_candidate_frames = info['num_annotation_candidates'] + if load_existing_masks: + masks_first_frame_only = Path(info['masks_out_path']) / 'ONLY_ONE_FRAME' + else: + masks_first_frame_only = None + + results = {} + for method_name, method_func in predictors.items(): + try: + chosen_annotation_frames = predict_annotation_candidates( + video_frames_path, + num_candidates=num_candidate_frames, + candidate_selection_function=method_func, + masks_first_frame_only=masks_first_frame_only, + masks_in_path=info['video_masks_path'], + 
masks_out_path=Path(info['masks_out_path']) / 'FIRST_FRAME_ONLY' / 'masks', # used by some target-aware algorithms + print_progress=False + ) + except Exception as e: + print(f"[!!!] ERROR ({video_name},{method_name})={e}") + print("Resulting to uniform baseline") + chosen_annotation_frames = predict_annotation_candidates( + video_frames_path, + num_candidates=num_candidate_frames, + candidate_selection_function=KNOWN_ANNOTATION_PREDICTORS['UNIFORM'], + masks_in_path=info['video_masks_path'], + print_progress=False + ) + exceptions.append([video_name, method_name, str(e)]) + + torch.cuda.empty_cache() + results[method_name] = json.dumps(chosen_annotation_frames) + p_bar.update() + + output.loc[video_name] = results + + # save updated after every video + output.index.name = 'video_name' + output.to_csv(csv_output_path) + + if min(exceptions.shape) > 0: + exceptions.to_csv('output/exceptions.csv') + + +def run_inference_with_pre_chosen_frames(chosen_frames_csv_path: str, videos_info: Dict[str, Dict], output_path: str, only_methods_subset: Set[str] = None, compute_iou=False, IoU_results_save_path=None, **kwargs): + df = pd.read_csv(chosen_frames_csv_path, index_col='video_name') + if only_methods_subset is not None: + num_runs = df.shape[0] * len(only_methods_subset) + else: + num_runs = np.prod(df.shape) + + if compute_iou: + assert IoU_results_save_path is not None + p_iou_dir = Path(IoU_results_save_path) + + i = 0 + p_bar = tqdm(desc='Running inference comparing multiple different AL approaches', total=num_runs) + + for video_name, info in videos_info.items(): + video_row = df.loc[video_name] + # ious = {} + for method in video_row.index: + if only_methods_subset is not None and method not in only_methods_subset: + continue + + chosen_frames_str = video_row.loc[method] + chosen_frames = json.loads(chosen_frames_str) + + video_frames_path = info['video_frames_path'] + video_masks_path = info['video_masks_path'] + + output_masks_path = Path(output_path) / video_name / method + + stats = run_on_video(video_frames_path, video_masks_path, output_masks_path, + frames_with_masks=chosen_frames, compute_iou=compute_iou, print_progress=False, **kwargs) + + if compute_iou: + p_out_curr_video_method = p_iou_dir / video_name + if not p_out_curr_video_method.exists(): + p_out_curr_video_method.mkdir(parents=True) + + stats.to_csv(p_out_curr_video_method / f'{method}.csv')#f'output/AL_comparison_all_methods/{video_name}_{method}.csv') + # print(f"Video={video_name},method={method},IoU={stats['iou'].mean():.4f}") + # ious[f'{video_name}_{method}'] = [float(iou) for iou in stats['iou']] + + p_bar.update() + i += 1 + + # with open(f'output/AL_comparison_all_methods/ious_{video_name}_all_methods.json', 'wt') as f_out: + # json.dump(ious, f_out) + +def run_inference_with_uniform_frames(videos_info: Dict[str, Dict], output_path: str, **kwargs): + num_runs = len(videos_info) + + i = 0 + p_bar = tqdm(desc='Running inference comparing multiple different AL approaches', total=num_runs) + + for video_name, info in videos_info.items(): + frames = os.listdir(info['video_frames_path']) + chosen_frames = uniformly_selected_frames(frames, how_many_frames=info['num_annotation_candidates']) + + video_frames_path = info['video_frames_path'] + video_masks_path = info['video_masks_path'] + + output_masks_path = Path(output_path) / video_name + try: + stats = run_on_video(video_frames_path, video_masks_path, output_masks_path, + frames_with_masks=chosen_frames, compute_iou=False, print_progress=False, **kwargs) + except 
ValueError as e: + print(f"[!!!] {e}") + p_bar.update() + i += 1 + + +def visualize_chosen_frames(video_name: str, num_total_frames: int, data: pd.Series, output_path: str): + def _sort_index(series): + ll = list(series.index) + sorted_ll = sorted(ll, key=lambda x: str( + min(json.loads(series.loc[x])))) + return sorted_ll + + sorted_index = _sort_index(data) + plt.figure(figsize=(16, 10)) + plt.title(f"Chosen frames for {video_name}") + plt.xlim(-10, num_total_frames + 10) + num_methods = len(data.index) + + plt.ylim(-0.25, num_methods + 0.25) + + plt.xlabel('Frame number') + plt.ylabel('AL method') + + plt.yticks([]) # disable yticks + + previous_plots = [] + + for i, method_name in enumerate(sorted_index): + chosen_frames = json.loads(data.loc[method_name]) + num_frames = len(chosen_frames) + + x = sorted(chosen_frames) + y = [i for _ in chosen_frames] + plt.axhline(y=i, zorder=1, xmin=0.01, xmax=0.99) + + plt.scatter(x=x, y=y, label=method_name, s=256, zorder=3, marker="v") + if len(previous_plots) != 0: + for i in range(num_frames): + curr_x, curr_y = x[i], y[i] + prev_x, prev_y = previous_plots[-1][0][i], previous_plots[-1][1][i] + + plt.plot([prev_x, curr_x], [prev_y, curr_y], + linewidth=1, color='gray', alpha=0.5) + + previous_plots.append((x, y)) + + # texts = map(str, range(num_frames)) + # for i, txt in enumerate(texts): + # plt.annotate(txt, (x[i] + 2, y[i] + 0.1), zorder=4, fontproperties={'weight': 'bold'}) + + plt.legend() + p_out = Path(f'{output_path}/chosen_frames_{video_name}.png') + if not p_out.parent.exists(): + p_out.parent.mkdir(parents=True) + + plt.savefig(p_out, bbox_inches='tight') + +# -------------------------END Inference and visualization utils -------------------------- +# ------------------------BEGIN metrics --------------------------------------------------- + +def _load_gt(p): + return np.stack([np.array(Image.open(p_gt).convert('P')) for p_gt in sorted(p.iterdir())]) + + +def _load_preds(p, palette: Image.Image, size: tuple): + return np.stack([Image.open(p_gt).convert('RGB').resize(size, resample=Image.Resampling.NEAREST).quantize(palette=palette, dither=Image.Dither.NONE) for p_gt in sorted(p.iterdir())]) + +def compute_metrics_al(p_source_masks, p_preds, looped=True): + def _proc(p_video: Path): + video_name = p_video.name + p_gts = p_source_masks / p_video.name + first_mask = Image.open(next(p_gts.iterdir())).convert('P') + w, h = first_mask.size + gts = _load_gt(p_gts) + + stats = { + 'video_name': video_name + } + + for p_method in p_video.iterdir(): + if not p_method.is_dir(): + continue + method_name = p_method.name + p_masks = p_method / 'masks' + preds = _load_preds(p_masks, palette=first_mask, size=(w, h)) + + assert preds.shape == gts.shape + + iou = batched_jaccard(gts, preds) + avg_iou = iou.mean(axis=0) + + f_score = batched_f_measure(gts, preds) + avg_f_score = f_score.mean(axis=0) + + stats[f'{method_name}-iou'] = float(avg_iou) + stats[f'{method_name}-f'] = float(avg_f_score) + + if looped: + n = iou.shape[0] + between = int(0.9 * n) + first_part_iou = iou[:between].mean() + second_part_iou = iou[between:].mean() + + first_part_f_score = f_score[:between].mean() + second_part_f_score = f_score[between:].mean() + + stats[f'{method_name}-iou-90'] = float(first_part_iou) + stats[f'{method_name}-iou-10'] = float(second_part_iou) + stats[f'{method_name}-f-90'] = float(first_part_f_score) + stats[f'{method_name}-f-10'] = float(second_part_f_score) + + return stats + + list_of_stats = p_umap(_proc, list(p_preds.iterdir()), num_cpus=3) + + 
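As a standalone illustration of the loading convention used by `_load_gt` and `_load_preds` above: the ground truth is read as a palettised ('P') image, and each prediction is resized with nearest-neighbour interpolation and quantised to that palette, so both arrays share the same object indices before `batched_jaccard`/`batched_f_measure` are applied. The file paths below are placeholders:

import numpy as np
from PIL import Image

gt_img = Image.open('Annotations/frame_000000.png').convert('P')   # carries the palette
pred_img = (
    Image.open('predictions/frame_000000.png')
    .convert('RGB')
    .resize(gt_img.size, resample=Image.Resampling.NEAREST)        # no interpolated colours
    .quantize(palette=gt_img, dither=Image.Dither.NONE)            # same index space as GT
)

gt, pred = np.array(gt_img), np.array(pred_img)
assert gt.shape == pred.shape   # ready for the batched IoU / F-measure helpers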
results = pd.DataFrame.from_records(list_of_stats).dropna(axis='columns').set_index('video_name') + return results + +def compute_metrics(p_source_masks, p_preds, pred_to_annot_names_lookup=None): + list_of_stats = [] + # for p_pred_video in list(p_preds.iterdir()): + def _proc(p_pred_video: Path): + video_name = p_pred_video.name + if pred_to_annot_names_lookup is not None: + video_name = pred_to_annot_names_lookup[video_name] + + # if 'XMem' in str(p_pred_video): + p_pred_video = Path(p_pred_video) / 'masks' + p_gts = p_source_masks / video_name + first_mask = Image.open(next(p_gts.iterdir())).convert('P') + w, h = first_mask.size + gts = _load_gt(p_gts) + + preds = _load_preds(p_pred_video, palette=first_mask, size=(w, h)) + + assert preds.shape == gts.shape + + avg_iou = batched_jaccard(gts, preds).mean(axis=0) + avg_f_score = batched_f_measure(gts, preds).mean(axis=0) + stats = { + 'video_name': video_name, + 'iou': float(avg_iou), + 'f': float(avg_f_score), + } + + return stats + # list_of_stats.append(stats) + # p_source_masks = Path('/home/maksym/RESEARCH/Datasets/MOSE/train/Annotations') + # p_preds = Path('/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/MOSE/AL_comparison') + + list_of_stats = p_umap(_proc, sorted(p_preds.iterdir(), key=lambda x: len(os.listdir(x)), reverse=True), num_cpus=4) + + results = pd.DataFrame.from_records(list_of_stats).dropna(axis='columns').set_index('video_name') + return results + +# -------------------------END metrics ------------------------------------------------------ + +def get_dataset_video_info(p_imgs_general, p_annotations_general, p_out_general, num_annotation_candidates=5): + videos_info = {} + + for p_video in sorted(p_imgs_general.iterdir(), key=lambda x: len(os.listdir(x)), reverse=True): # longest video first to avoid OOM in the future + video_name = p_video.name + p_masks = p_annotations_general / video_name + + videos_info[video_name] = dict( + num_annotation_candidates=num_annotation_candidates, + video_frames_path=p_video, + video_masks_path=p_masks, + masks_out_path=p_out_general / video_name + ) + + return videos_info + + + +if __name__ == "__main__": + pass + + # ## Usage examples + # ## Run from root-level directory, e.g. 
in `main.py` + + # ## Running multiple frame selectors, saving their predicted frame numbers to a .csv file + # run_multiple_frame_selectors(get_videos_info(), csv_output_path='output/al_videos_chosen_frames.csv') + + # ## Running and visualizing inference based on pre-calculated frames selected + # run_inference_with_pre_chosen_frames( + # chosen_frames_csv_path='output/al_videos_chosen_frames.csv', + # videos_info=get_videos_info(), + # output_path='/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/' + # ) + + # ## Concatenating multiple video results into a non-uniform grid + # visualize_grid( + # names=['long_scene', 'ariana_blog', 'ariana_smile', 'long_scene_scale'], + # labeled=True, + # ) diff --git a/inference/run_on_video.py b/inference/run_on_video.py new file mode 100644 index 0000000..0b4eb67 --- /dev/null +++ b/inference/run_on_video.py @@ -0,0 +1,367 @@ +from dataclasses import replace +from functools import partial +from multiprocessing import Process, Queue +from os import PathLike, path +from tempfile import TemporaryDirectory +from time import perf_counter +import time +from typing import Iterable, Optional, Union, List +from pathlib import Path +from warnings import warn + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torchvision.transforms import functional as FT, ToTensor +from torch.utils.data import DataLoader +from tqdm import tqdm +from PIL import Image + +from inference.frame_selection.frame_selection import select_next_candidates +from model.network import XMem +from util.configuration import VIDEO_INFERENCE_CONFIG +from util.image_saver import ParallelImageSaver, create_overlay, save_image +from util.tensor_util import compute_array_iou +from inference.inference_core import InferenceCore +from inference.data.video_reader import Sample, VideoReader +from inference.data.mask_mapper import MaskMapper +from inference.frame_selection.frame_selection_utils import extract_keys, get_determenistic_augmentations + +def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, + original_memory_mechanism=False, + compute_iou=False, + manually_curated_masks=False, + print_progress=True, + augment_images_with_masks=False, + overwrite_config: dict = None, + save_overlay=True, + object_color_if_single_object=(255, 255, 255), + print_fps=False, + image_saving_max_queue_size=200): + + torch.autograd.set_grad_enabled(False) + frames_with_masks = set(frames_with_masks) + + config = VIDEO_INFERENCE_CONFIG.copy() + overwrite_config = {} if overwrite_config is None else overwrite_config + overwrite_config['masks_out_path'] = masks_out_path + config.update(overwrite_config) + + mapper, processor, vid_reader, loader = _load_main_objects(imgs_in_path, masks_in_path, config) + vid_name = vid_reader.vid_name + vid_length = len(loader) + + at_least_one_mask_loaded = False + total_preloading_time = 0.0 + + if original_memory_mechanism: + # only the first frame goes into permanent memory originally + frames_to_put_in_permanent_memory = [0] + # the rest are going to be processed later + else: + # in our modification, all frames with provided masks go into permanent memory + frames_to_put_in_permanent_memory = frames_with_masks + at_least_one_mask_loaded, total_preloading_time = _preload_permanent_memory(frames_to_put_in_permanent_memory, vid_reader, mapper, processor, augment_images_with_masks=augment_images_with_masks) + + if not at_least_one_mask_loaded: + raise ValueError("No valid 
masks provided!") + + stats = [] + + total_processing_time = 0.0 + with ParallelImageSaver(config['masks_out_path'], vid_name=vid_name, overlay_color_if_b_and_w=object_color_if_single_object, max_queue_size=image_saving_max_queue_size) as im_saver: + for ti, data in enumerate(tqdm(loader, disable=not print_progress)): + with torch.cuda.amp.autocast(enabled=True): + data: Sample = data # Just for Intellisense + # No batch dimension here, just single samples + sample = replace(data, rgb=data.rgb.cuda()) + + if ti in frames_with_masks: + msk = sample.mask + else: + msk = None + + # Map possibly non-continuous labels to continuous ones + if msk is not None: + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + msk, labels = mapper.convert_mask( + msk.numpy(), exhaustive=True) + msk = torch.Tensor(msk).cuda() + if sample.need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + processor.set_all_labels(list(mapper.remappings.values())) + else: + labels = None + + if original_memory_mechanism: + # we only ignore the first mask, since it's already in the permanent memory + do_not_add_mask_to_memory = (ti == 0) + else: + # we ignore all frames with masks, since they are already preloaded in the permanent memory + do_not_add_mask_to_memory = msk is not None + # Run the model on this frame + # 2+ channels, classes+ and background + a = perf_counter() + prob = processor.step(sample.rgb, msk, labels, end=(ti == vid_length-1), + manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) + + # Upsample to original size if needed + out_mask = _post_process(sample, prob) + b = perf_counter() + total_processing_time += (b - a) + + curr_stat = {'frame': sample.frame, 'mask_provided': msk is not None} + if compute_iou: + gt = sample.mask # for IoU computations, original mask or None, NOT msk + if gt is not None and msk is None: # There exists a ground truth, but the model didn't see it + iou = float(compute_array_iou(out_mask, gt)) + else: + iou = -1 # skipping frames where the model saw the GT + curr_stat['iou'] = iou + stats.append(curr_stat) + + # Save the mask and the overlay (potentially) + + if config['save_masks']: + out_mask = mapper.remap_index_mask(out_mask) + out_img = Image.fromarray(out_mask) + out_img = vid_reader.map_the_colors_back(out_img) + + im_saver.save_mask(mask=out_img, frame_name=sample.frame) + + if save_overlay: + original_img = sample.raw_image_pil + im_saver.save_overlay(orig_img=original_img, mask=out_img, frame_name=sample.frame) + im_saver.wait_for_jobs_to_finish(verbose=True) + + if print_fps: + print(f"TOTAL PRELOADING TIME: {total_preloading_time:.4f}s") + print(f"TOTAL PROCESSING TIME: {total_processing_time:.4f}s") + print(f"TOTAL TIME (excluding image saving): {total_preloading_time + total_processing_time:.4f}s") + print(f"TOTAL PROCESSING FPS: {len(loader) / total_processing_time:.4f}") + print(f"TOTAL FPS (excluding image saving): {len(loader) / (total_preloading_time + total_processing_time):.4f}") + + return pd.DataFrame(stats) + +def _load_main_objects(imgs_in_path, masks_in_path, config): + model_path = config['model'] + network = XMem(config, model_path, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval() + if model_path is not None: + model_weights = torch.load(model_path) + network.load_weights(model_weights, init_as_zero_if_needed=True) + else: + warn('No model weights were loaded, as config["model"] was not specified.') + + mapper = MaskMapper() + processor = 
InferenceCore(network, config=config) + + vid_reader, loader = _create_dataloaders(imgs_in_path, masks_in_path, config) + return mapper,processor,vid_reader,loader + + +def _post_process(sample, prob): + if sample.need_resize: + prob = F.interpolate(prob.unsqueeze( + 1), sample.shape, mode='bilinear', align_corners=False)[:, 0] + + # Probability mask -> index mask + out_mask = torch.argmax(prob, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + return out_mask + + +def _create_dataloaders(imgs_in_path: Union[str, PathLike], masks_in_path: Union[str, PathLike], config: dict): + vid_reader = VideoReader( + "", + imgs_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + masks_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', + size=config['size'], + use_all_masks=True + ) + + # Just return the samples as they are; only using DataLoader for preloading frames from the disk + loader = DataLoader(vid_reader, batch_size=None, shuffle=False, num_workers=1, collate_fn=VideoReader.collate_fn_identity) + + vid_length = len(loader) + # no need to count usage for LT if the video is not that long anyway + config['enable_long_term_count_usage'] = ( + config['enable_long_term'] and + (vid_length + / (config['max_mid_term_frames']-config['min_mid_term_frames']) + * config['num_prototypes']) + >= config['max_long_term_elements'] + ) + + return vid_reader,loader + + +def _preload_permanent_memory(frames_to_put_in_permanent_memory: List[int], vid_reader: VideoReader, mapper: MaskMapper, processor: InferenceCore, augment_images_with_masks=False): + total_preloading_time = 0 + at_least_one_mask_loaded = False + for j in frames_to_put_in_permanent_memory: + sample: Sample = vid_reader[j] + sample = replace(sample, rgb=sample.rgb.cuda()) + + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + if sample.mask is None: + raise FileNotFoundError(f"Couldn't find mask {j}! Check that the filename is either the same as for frame {j} or follows the `frame_%06d.png` format if using a video file for input.") + msk, labels = mapper.convert_mask(sample.mask, exhaustive=True) + msk = torch.Tensor(msk).cuda() + + if min(msk.shape) == 0: # empty mask, e.g. 
[1, 0, 720, 1280] + warn(f"Skipping adding frame {j} to permanent memory, as the mask is empty") + continue # just don't add anything to the memory + if sample.need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + # sample = replace(sample, mask=msk) + + processor.set_all_labels(list(mapper.remappings.values())) + a = perf_counter() + processor.put_to_permanent_memory(sample.rgb, msk) + b = perf_counter() + total_preloading_time += (b - a) + + if not at_least_one_mask_loaded: + at_least_one_mask_loaded = True + + if augment_images_with_masks: + augs = get_determenistic_augmentations( + sample.rgb.shape, msk, subset='best_all') + rgb_raw = sample.raw_image_pil + + for img_aug, mask_aug in augs: + # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies + rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda() + + msk_aug = mask_aug(msk) + + processor.put_to_permanent_memory(rgb_aug, msk_aug) + + return at_least_one_mask_loaded, total_preloading_time + + +def run_on_video( + imgs_in_path: Union[str, PathLike], + masks_in_path: Union[str, PathLike], + masks_out_path: Union[str, PathLike], + frames_with_masks: Iterable[int] = (0, ), + compute_iou=False, + print_progress=True, + **kwargs +) -> pd.DataFrame: + """ + Args: + imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png`. .jpg works too. + + masks_in_path (Union[str, PathLike]): Path to the directory containing video frames' masks in the same format, with corresponding names between video frames. Each unique object should have unique color. + + masks_out_path (Union[str, PathLike]): Path to the output directory (will be created if doesn't exist) where the predicted masks will be stored in .png format. + + frames_with_masks (Iterable[int]): A list of integers representing the frames on which the masks should be applied (default: [0], only applied to the first frame). 0-based. + + compute_iou (bool): A flag to indicate whether to compute the IoU metric (default: False, requires ALL video frames to have a corresponding mask). + + print_progress (bool): A flag to indicate whether to print a progress bar (default: True). + + Returns: + stats (pd.Dataframe): a table containing every frame and the following information: IoU score with corresponding mask (if `compute_iou` is True) + """ + + return _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=frames_with_masks, + compute_iou=compute_iou, + print_progress=print_progress, + **kwargs + ) + + +def select_k_next_best_annotation_candidates( + imgs_in_path: Union[str, PathLike], + masks_in_path: Union[str, PathLike], # at least the 1st frame + masks_out_path: Optional[Union[str, PathLike]] = None, + k: int = 5, + print_progress=True, + previously_chosen_candidates=[0], + use_previously_predicted_masks=True, + # Candidate selection hyperparameters + alpha=0.5, + min_mask_presence_percent=0.25, + **kwargs +): + """ + Selects the next best annotation candidate frames based on the provided frames and mask paths. + + Parameters: + imgs_in_path (Union[str, PathLike]): The path to the directory containing input images. + masks_in_path (Union[str, PathLike]): The path to the directory containing the first frame masks. + masks_out_path (Optional[Union[str, PathLike]], optional): The path to save the generated masks. + If not provided, a temporary directory will be used. Defaults to None. 
+ k (int, optional): The number of next best annotation candidate frames to select. Defaults to 5. + print_progress (bool, optional): Whether to print progress during processing. Defaults to True. + previously_chosen_candidates (list, optional): List of indices of frames with previously chosen candidates. + Defaults to [0]. + use_previously_predicted_masks (bool, optional): Whether to use previously predicted masks. + If True, `masks_out_path` must be provided. Defaults to True. + alpha (float, optional): Hyperparameter controlling the candidate selection process. Defaults to 0.5. + min_mask_presence_percent (float, optional): Minimum mask presence percentage for candidate selection. + Defaults to 0.25. + **kwargs: Additional keyword arguments to pass to `run_on_video`. + + Returns: + list: A list of indices representing the selected next best annotation candidate frames. + """ + mapper, processor, vid_reader, loader = _load_main_objects(imgs_in_path, masks_in_path, VIDEO_INFERENCE_CONFIG) + + # Extracting "key" feature maps + # Could be combined with inference (like in GUI), but the code would be a mess + frame_keys, shrinkages, selections, *_ = extract_keys(loader, processor, print_progress=print_progress, flatten=False) + # extracting the keys and corresponding matrices + + to_tensor = ToTensor() + if masks_out_path is not None: + p_masks_out = Path(masks_out_path) + + if use_previously_predicted_masks: + print("Using existing predicted masks, no need to run inference.") + assert masks_out_path is not None, "When `use_existing_masks=True`, you need to put the path to previously predicted masks in `masks_out_path`" + try: + masks = [to_tensor(Image.open(p)) for p in sorted((p_masks_out / 'masks').iterdir())] + except Exception as e: + warn("Loading previously predicting masks failed for `select_k_next_best_annotation_candidates`.") + raise e + if len(masks) != len(frame_keys): + raise FileNotFoundError(f"Not enough masks ({len(masks)}) for {len(frame_keys)} frames provided when using `use_previously_predicted_masks=True`!") + else: + print("Existing predictions were not given, will run full inference and save masks in `masks_out_path` or a temporary directory if `masks_out_path` is not given.") + if masks_out_path is None: + d = TemporaryDirectory() + p_masks_out = Path(d) + + # running inference once to obtain masks + run_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, # Ignored + masks_out_path=p_masks_out, # Used for some frame selectors + frames_with_masks=previously_chosen_candidates, + compute_iou=False, + print_progress=print_progress, + **kwargs + ) + + masks = [to_tensor(Image.open(p)) for p in sorted((p_masks_out / 'masks').iterdir())] + + keys = torch.cat(frame_keys) + shrinkages = torch.cat(shrinkages) + selections = torch.cat(selections) + + new_selected_candidates = select_next_candidates(keys, shrinkages=shrinkages, selections=selections, masks=masks, num_next_candidates=k, previously_chosen_candidates=previously_chosen_candidates, print_progress=print_progress, alpha=alpha, only_new_candidates=True, min_mask_presence_percent=min_mask_presence_percent) + + if masks_out_path is None: + # Remove the temporary directory + d.cleanup() + + return new_selected_candidates \ No newline at end of file diff --git a/interactive_demo.py b/interactive_demo.py index fef2140..1fc490d 100644 --- a/interactive_demo.py +++ b/interactive_demo.py @@ -46,7 +46,7 @@ parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100) - 
parser.add_argument('--num_objects', type=int, default=1) + parser.add_argument('--num_objects', type=int, default=None) # Long-memory options # Defaults. Some can be changed in the GUI. @@ -71,7 +71,7 @@ with torch.cuda.amp.autocast(enabled=not args.no_amp): # Load our checkpoint - network = XMem(config, args.model).cuda().eval() + network = XMem(config, args.model, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval() # Loads the S2M model if args.s2m_model is not None: @@ -81,15 +81,18 @@ else: s2m_model = None - s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255) + # Manages most IO + config['num_objects_default_value'] = 1 + resource_manager = ResourceManager(config) + num_objects = resource_manager.num_objects + config['num_objects'] = num_objects + + s2m_controller = S2MController(s2m_model, num_objects, ignore_class=255) if args.fbrs_model is not None: fbrs_controller = FBRSController(args.fbrs_model) else: fbrs_controller = None - # Manages most IO - resource_manager = ResourceManager(config) - app = QApplication(sys.argv) ex = App(network, resource_manager, s2m_controller, fbrs_controller, config) sys.exit(app.exec_()) diff --git a/main.py b/main.py new file mode 100644 index 0000000..c225a7f --- /dev/null +++ b/main.py @@ -0,0 +1,43 @@ +import os +import random +from inference.run_on_video import run_on_video, select_k_next_best_annotation_candidates + +if __name__ == '__main__': + # If pytorch cannot download the weights due to an ssl error, uncomment the following lines + # import ssl + # ssl._create_default_https_context = ssl._create_unverified_context + + # Run inference on a video file with preselected annotated frames + video_path = 'example_videos/chair/chair.mp4' + masks_path = 'example_videos/chair/Annotations' + output_path = 'output/example_video_chair_from_mp4' + frames_with_masks = [5, 10, 15] + + run_on_video(video_path, masks_path, output_path, frames_with_masks) + + # Run inference on extracted .jpg frames with preselected annotations + imgs_path = 'example_videos/caps/JPEGImages' + masks_path = 'example_videos/caps/Annotations' + output_path = 'output/example_video_caps' + frames_with_masks = [0, 14, 33, 43, 66] + + run_on_video(imgs_path, masks_path, output_path, frames_with_masks) + + # Get proposals for the next 3 best annotation candidates using previously predicted masks + # If you don't have previous predictions, just put `use_previously_predicted_masks=False`, the algorithm will run inference internally + next_candidates = select_k_next_best_annotation_candidates(imgs_path, masks_path, output_path, previously_chosen_candidates=frames_with_masks, use_previously_predicted_masks=True) + print("Next candidates for annotations are: ") + for idx in next_candidates: + print(f"\tFrame {idx}") + + # Run inference on a video with all annotations provided, compute IoU + imgs_path = 'example_videos/chair/JPEGImages' + masks_path = 'example_videos/chair/Annotations' + output_path = 'output/example_video_chair' + + num_frames = len(os.listdir(imgs_path)) + frames_with_masks = random.sample(range(0, num_frames), 3) # Give 3 random masks as GT annotations + + stats = run_on_video(imgs_path, masks_path, output_path, frames_with_masks, compute_iou=True) # stats: pandas DataFrame + mean_iou = stats[stats['iou'] != -1]['iou'].mean() # -1 is for GT annotations, we just skip them + print(f"Average IoU: {mean_iou}") # Should be 90%+ as a sanity check \ No newline at end of file diff --git a/model/modules.py b/model/modules.py 
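With the changes to `interactive_demo.py` above, `--num_objects` now defaults to None and the effective value is resolved by the `ResourceManager`: an explicit command-line value overrides whatever was stored in the workspace's `info.json`, and the default of 1 (`num_objects_default_value`) is only used on a first run. A small sketch of that precedence — the helper below is a hypothetical stand-in, but the rule mirrors `ResourceManager.__init__`:

def resolve_num_objects(cli_value, stored_value, default_value=1):
    if cli_value is not None:       # forced overwrite from the user (--num_objects)
        return cli_value
    if stored_value is not None:    # restored from the workspace's info.json
        return stored_value
    return default_value            # both None: first run, single-object use case

assert resolve_num_objects(None, 3) == 3     # reopening an existing workspace
assert resolve_num_objects(2, 3) == 2        # explicit user override wins
assert resolve_num_objects(None, None) == 1  # fresh workspace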
index 9920799..652cf3a 100644 --- a/model/modules.py +++ b/model/modules.py @@ -100,11 +100,11 @@ def forward(self, g, h): class ValueEncoder(nn.Module): - def __init__(self, value_dim, hidden_dim, single_object=False): + def __init__(self, value_dim, hidden_dim, single_object=False, pretrained=True): super().__init__() self.single_object = single_object - network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2) + network = resnet.resnet18(pretrained=pretrained, extra_dim=1 if single_object else 2) self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu # 1/2, 64 @@ -124,13 +124,13 @@ def __init__(self, value_dim, hidden_dim, single_object=False): def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): # image_feat_f16 is the feature from the key encoder if not self.single_object: - g = torch.stack([masks, others], 2) + g_1 = torch.stack([masks, others], 2) else: - g = masks.unsqueeze(2) - g = self.distributor(image, g) + g_1 = masks.unsqueeze(2) + g_2 = self.distributor(image, g_1) - batch_size, num_objects = g.shape[:2] - g = g.flatten(start_dim=0, end_dim=1) + batch_size, num_objects = g_2.shape[:2] + g = g_2.flatten(start_dim=0, end_dim=1) g = self.conv1(g) g = self.bn1(g) # 1/2, 64 @@ -151,9 +151,9 @@ def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): class KeyEncoder(nn.Module): - def __init__(self): + def __init__(self, pretrained=True): super().__init__() - network = resnet.resnet50(pretrained=True) + network = resnet.resnet50(pretrained=pretrained) self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu # 1/2, 64 diff --git a/model/network.py b/model/network.py index c5f179d..124dd42 100644 --- a/model/network.py +++ b/model/network.py @@ -15,7 +15,7 @@ class XMem(nn.Module): - def __init__(self, config, model_path=None, map_location=None): + def __init__(self, config, model_path=None, map_location=None, pretrained_key_encoder=True, pretrained_value_encoder=True): """ model_path/map_location are used in evaluation only map_location is for converting models saved in cuda to cpu @@ -26,8 +26,8 @@ def __init__(self, config, model_path=None, map_location=None): self.single_object = config.get('single_object', False) print(f'Single object mode: {self.single_object}') - self.key_encoder = KeyEncoder() - self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object) + self.key_encoder = KeyEncoder(pretrained=pretrained_key_encoder) + self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object, pretrained=pretrained_value_encoder) # Projection from f16 feature space to key/value space self.key_proj = KeyProjection(1024, self.key_dim) diff --git a/process_video.py b/process_video.py new file mode 100644 index 0000000..cea3132 --- /dev/null +++ b/process_video.py @@ -0,0 +1,30 @@ +import argparse +import re +from pathlib import Path + +from inference.run_on_video import run_on_video + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Process video frames given a few (1+) existing annotation masks') + parser.add_argument('--video', type=str, help='Path to the video file or directory with .jpg video frames to process', required=True) + parser.add_argument('--masks', type=str, help='Path to the directory with individual .png masks for corresponding video frames, named `frame_000000.png`, `frame_000123.png`, ... or similarly (the script searches for the first integer value in the filename). 
' + 'Will use all masks int the directory.', required=True) + parser.add_argument('--output', type=str, help='Path to the output directory where to save the resulting segmentation masks and overlays. ' + 'Will be automatically created if does not exist', required=True) + + args = parser.parse_args() + + frames_with_masks = [] + for file_path in (p for p in Path(args.masks).iterdir() if p.is_file()): + frame_number_match = re.search(r'\d+', file_path.stem) + if frame_number_match is None: + print(f"ERROR: file {file_path} does not contain a frame number. Cannot load it as a mask.") + exit(1) + frames_with_masks.append(int(frame_number_match.group())) + + print("Using masks for frames: ", frames_with_masks) + + p_out = Path(args.output) + p_out.mkdir(parents=True, exist_ok=True) + run_on_video(args.video, args.masks, args.output) diff --git a/requirements.txt b/requirements.txt index 115c19a..2caec47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,4 @@ progressbar2 -gdown -gitpython -git+https://github.com/cheind/py-thin-plate-spline -hickle -tensorboard -numpy \ No newline at end of file +numpy +pandas +tqdm \ No newline at end of file diff --git a/run_gui_in_docker.sh b/run_gui_in_docker.sh new file mode 100644 index 0000000..839786b --- /dev/null +++ b/run_gui_in_docker.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +key='' +key_name='' +value='' + +# Parsing keyword arguments +while [ $# -gt 0 ]; do + if [ -z "${key}" ]; then + case "$1" in + --images|--video|--workspace) + key="other" + key_name="${1}" + ;; + --num_objects) + key="--num_objects" + ;; + *) + printf "***************************\n" + printf "* Error: Invalid argument ${1}\n" + printf "* Specify one of --images --video or --workspace with