diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 0000000000..b79aa2ff06
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,21 @@
+# Setting up CODEOWNERS for the UST-related codebase
+# Documentation for open-sourced models relevant to UST
+examples/speech_to_text @kahne @sravyapopuri388 @jmp84
+examples/speech_to_speech @an918tw @sravyapopuri388 @jmp84
+examples/speech_synthesis @kahne @jmp84
+examples/simultaneous_translation @kahne @jmp84
+examples/speech_text_joint_to_text @yuntang @jmp84
+
+# Speech related models relevant to UST
+fairseq/models/speech_to_speech @sravyapopuri388 @jmp84
+fairseq/models/speech_to_text @kahne @sravyapopuri388 @jmp84
+fairseq/models/text_to_speech @kahne @jmp84
+
+# CONFORMER IMPLEMENTATION
+fairseq/modules/conformer_layer.py @sravyapopuri388 @jmp84
+fairseq/modules/espnet_multihead_attention.py @sravyapopuri388 @jmp84
+fairseq/modules/rotary_positional_embedding.py @sravyapopuri388 @jmp84
+fairseq/modules/positional_encoding.py @sravyapopuri388 @jmp84
+
+# Machine Translation/NLLB
+fairseq/tasks/translation.py @gwenzek
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index a7f4f0a902..aa15123d8e 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -19,7 +19,7 @@ Steps to reproduce the behavior (**always include the command you ran**):
#### Code sample
-
### Expected behavior
@@ -28,7 +28,7 @@ Minimal means having the shortest code but still preserving the bug. -->
### Environment
- - fairseq Version (e.g., 1.0 or master):
+ - fairseq Version (e.g., 1.0 or main):
- PyTorch Version (e.g., 1.0)
- OS (e.g., Linux):
- How you installed fairseq (`pip`, source):
diff --git a/.github/ISSUE_TEMPLATE/how-to-question.md b/.github/ISSUE_TEMPLATE/how-to-question.md
index 4beb180dbf..04f3f15d3e 100644
--- a/.github/ISSUE_TEMPLATE/how-to-question.md
+++ b/.github/ISSUE_TEMPLATE/how-to-question.md
@@ -6,9 +6,9 @@ labels: 'question, needs triage'
## ❓ Questions and Help
-### Before asking:
-1. search the issues.
-2. search the docs.
+### Before asking:
+1. search the issues.
+2. search the docs.
@@ -16,13 +16,13 @@ labels: 'question, needs triage'
#### Code
-
+
#### What have you tried?
#### What's your environment?
- - fairseq Version (e.g., 1.0 or master):
+ - fairseq Version (e.g., 1.0 or main):
- PyTorch Version (e.g., 1.0)
- OS (e.g., Linux):
- How you installed fairseq (`pip`, source):
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index b28ff98e7b..d005e2df4f 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,15 +1,15 @@
# Before submitting
- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
-- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
-- [ ] Did you make sure to update the docs?
-- [ ] Did you write any new necessary tests?
+- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
+- [ ] Did you make sure to update the docs?
+- [ ] Did you write any new necessary tests?
## What does this PR do?
Fixes # (issue).
-## PR review
-Anyone in the community is free to review the PR once the tests have passed.
+## PR review
+Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
## Did you have fun?
diff --git a/.github/stale.yml b/.github/stale.yml
new file mode 100644
index 0000000000..b12867dab0
--- /dev/null
+++ b/.github/stale.yml
@@ -0,0 +1,30 @@
+# Configuration for probot-stale - https://github.com/probot/stale
+# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 90
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+ - bug
+# Label to use when marking an issue as stale
+staleLabel: stale
+issues:
+ # Comment to post when marking an issue as stale.
+ markComment: >
+ This issue has been automatically marked as stale.
+ **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
+ We are sorry that we haven't been able to prioritize it yet. If you have any new information, please include it with your comment!
+ # Comment to post when closing a stale issue.
+ closeComment: >
+ Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
+pulls:
+ # Comment to post when marking a pull request as stale.
+ markComment: >
+ This pull request has been automatically marked as stale.
+ **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
+ We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
+ # Comment to post when closing a stale pull request.
+ closeComment: >
+ Closing this pull request after a prolonged period of inactivity. If the underlying issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
+
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 6ae8093a8a..036233d8cf 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,10 +1,10 @@
name: build
on:
- # Trigger the workflow on push to master or any pull request
+ # Trigger the workflow on push to main or any pull request
push:
branches:
- - master
+ - main
pull_request:
jobs:
@@ -14,31 +14,68 @@ jobs:
max-parallel: 4
matrix:
platform: [ubuntu-latest, macos-latest]
- python-version: [3.6, 3.7]
+ python-version: [3.8, 3.9]
runs-on: ${{ matrix.platform }}
steps:
- - uses: actions/checkout@v1
+ - uses: actions/checkout@v2
+
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
+
- name: Conditionally install pytorch
if: matrix.platform == 'windows-latest'
run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
+
- name: Install locally
run: |
python -m pip install --upgrade pip
- python setup.py build_ext --inplace
- python -m pip install --editable .
+ git submodule update --init --recursive
+ python -m pip install .
+
+ - name: Check installation
+ working-directory: /tmp
+ run: python $GITHUB_WORKSPACE/scripts/check_installation.py
+
+ - name: Install optional test requirements
+ run: |
+ python -m pip install '.[dev,docs]'
+ python -m pip install iopath transformers pyarrow
+ python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
+ python -m pip install pygit2 pgzip
+
+ - name: Install xformers for macOS
+ if: matrix.platform == 'macos-latest'
+ run: |
+ brew install llvm libomp
+ CC=/usr/local/opt/llvm/bin/clang CXX=clang++ pip install git+https://github.com/facebookresearch/xformers.git@main
+
+ - name: Install xformers for non-macOS
+ if: matrix.platform != 'macos-latest'
+ run: |
+ python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main
+
+ - name: Lint with black
+ run: black --check --diff .
+
- name: Lint with flake8
run: |
- pip install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+
+ - name: Build doc
+ run: make singlehtml
+ working-directory: docs/
+
- name: Run tests
- run: |
- python setup.py test
+ # When installing in non-editable mode, the .so files will be generated in 'site-packages/fairseq'.
+ # But by default, pytest's import machinery will load the local fairseq and won't see the .so files.
+ # Use --import-mode=append to favor the 'site-packages/fairseq' installation.
+ # https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html
+ run: pytest --import-mode=append -vvv tests/
+
diff --git a/.github/workflows/depreview.yml b/.github/workflows/depreview.yml
new file mode 100644
index 0000000000..032eddef5f
--- /dev/null
+++ b/.github/workflows/depreview.yml
@@ -0,0 +1,14 @@
+name: 'Dependency Review'
+on: [pull_request]
+
+permissions:
+ contents: read
+
+jobs:
+ dependency-review:
+ runs-on: ubuntu-latest
+ steps:
+ - name: 'Checkout Repository'
+ uses: actions/checkout@v4
+ - name: Dependency Review
+ uses: actions/dependency-review-action@v4
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000000..241b74b32d
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,161 @@
+name: Fairseq Release
+
+on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'Release Type'
+ default: 'patch'
+ required: true
+
+jobs:
+
+ get_next_version:
+ runs-on: ubuntu-latest
+ steps:
+ - name: checkout-repo-content
+ uses: actions/checkout@v2
+
+ - name: setup-python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
+
+ - name: get next version and tag
+ id: get-next-version-and-tag
+ run: |
+ output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }})
+ echo $output
+ new_version=$(echo $output | awk '{print $1}')
+ new_tag=$(echo $output | awk '{print $2}')
+ echo "new version is $new_version"
+ echo "new tag is $new_tag"
+ echo ::set-output name=version::$new_version
+ echo ::set-output name=tag::$new_tag
+ echo ::set-output name=branch_name::$new_version-release
+ echo "NEW_TAG=$new_tag" >> $GITHUB_ENV
+ echo "NEW_BRANCH=$new_version-release" >> $GITHUB_ENV
+
+
+ # update the version number in version.txt
+ - name: update version
+ id: update-version
+ run: |
+ echo "current folder = $PWD"
+ echo "current branch = $(git branch --show-current)"
+ output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }} --update-version)
+
+ - name: add and commit
+ uses: EndBug/add-and-commit@v9
+ with:
+ author_name: ${{ secrets.AUTHOR_NAME }}
+ author_email: ${{ secrets.AUTHOR_EMAIL }}
+
+ # TODO: change this to main once shipit is disabled.
+ new_branch: '${{ env.NEW_BRANCH }}'
+ default_author: github_actor
+ message: '${{ env.NEW_TAG }} release'
+ pathspec_error_handling: exitAtEnd
+
+ # Arguments for the git pull command. Use NO-PULL to avoid the action pulling at all.
+ # pull: 'NO-PULL'
+ tag: '${{ env.NEW_TAG }}'
+
+ outputs:
+ new_version: ${{ steps.get-next-version-and-tag.outputs.version }}
+ new_tag: ${{ steps.get-next-version-and-tag.outputs.tag }}
+ branch_name: ${{ steps.get-next-version-and-tag.outputs.branch_name }}
+
+ create_sdist:
+ runs-on: ubuntu-latest
+ name: Create Source Distribution
+ needs: get_next_version
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ ref: ${{ needs.get_next_version.outputs.branch_name }}
+
+ - name: Install Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.8'
+
+ - name: Upgrade pip
+ run: |
+ python3 -m pip install --upgrade pip
+
+ - name: Create Source Distribution
+ run: |
+ python3 -m pip install setuptools wheel twine torch
+ python3 setup.py sdist
+
+ - uses: actions/upload-artifact@v2
+ with:
+ path: dist/*.tar.gz
+
+ build_wheels:
+ name: Build wheels on ${{ matrix.os }}
+ runs-on: ${{ matrix.os }}
+ needs: get_next_version
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest]
+
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ ref: ${{ needs.get_next_version.outputs.branch_name }}
+
+ - name: Install Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.8'
+
+ - name: Upgrade pip
+ run: |
+ python3 -m pip install --upgrade pip
+
+ - name: Install cibuildwheel
+ run: |
+ python3 -m pip install cibuildwheel
+
+ - name: Build wheels for CPython
+ run: |
+ python3 -m cibuildwheel --output-dir dist
+ env:
+ CIBW_BUILD: "cp38-*64"
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
+ CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
+ # Install system library
+ CIBW_BEFORE_BUILD_LINUX: (yum install -y libffi-devel || apt-get install -y libffi-devel || apk add --update --no-cache libffi-devel || true) && (yum install -y libc6 || apt-get install -y libc6 || apk add --update --no-cache libc6 || true)
+ CIBW_ENVIRONMENT: "PIP_ONLY_BINARY=numpy"
+ CIBW_SKIP: "*musllinux*"
+
+ - uses: actions/upload-artifact@v2
+ with:
+ path: dist
+
+ upload:
+ name: Upload to PyPI and create release
+ runs-on: ubuntu-latest
+ needs: [build_wheels, create_sdist, get_next_version]
+ steps:
+ - uses: actions/download-artifact@v2
+ with:
+ name: artifact
+ path: dist
+
+ # build the PyPI package and upload it
+ - name: upload
+ env:
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: |
+ pip install setuptools wheel twine
+ python3 -m twine upload --repository pypi dist/*
+
+ # create the release on github
+ - name: create release on github
+ uses: ncipollo/release-action@v1
+ with:
+ tag: '${{ needs.get_next_version.outputs.new_tag }}'
diff --git a/.gitignore b/.gitignore
index 9546cffd90..4be13638de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,3 +131,11 @@ data-bin/
# Experimental Folder
experimental/*
+
+# Weights and Biases logs
+wandb/
+
+# Hydra artifacts
+nohup.out
+multirun
+outputs
diff --git a/.gitmodules b/.gitmodules
index df0d3d3071..07a55d45d4 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,7 +1,3 @@
-[submodule "fairseq/models/huggingface/transformers"]
- path = fairseq/models/huggingface/transformers
- url = https://github.com/myleott/transformers.git
- branch = fairseq
[submodule "fairseq/model_parallel/megatron"]
path = fairseq/model_parallel/megatron
url = https://github.com/ngoyal2707/Megatron-LM
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..6b1d6aed8c
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,40 @@
+exclude: 'build|stubs'
+
+default_language_version:
+ python: python3
+
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.1.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-ast
+ - id: check-merge-conflict
+ - id: no-commit-to-branch
+ args: ['--branch=master']
+ - id: check-added-large-files
+ args: ['--maxkb=500']
+ - id: end-of-file-fixer
+
+- repo: https://github.com/ambv/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ language_version: python3.8
+
+- repo: https://gitlab.com/pycqa/flake8
+ rev: 3.9.2
+ hooks:
+ - id: flake8
+ args: [
+ # only error for syntax errors and undefined names
+ "--select=E9,F63,F7,F82",
+ ]
+
+- repo: https://github.com/pycqa/isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+ exclude: README.md
+ additional_dependencies: [toml]
+ args: ["--profile", "black"]
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4d7ca6a98e..60e9025887 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -5,7 +5,7 @@ possible.
## Pull Requests
We actively welcome your pull requests.
-1. Fork the repo and create your branch from `master`.
+1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
@@ -26,3 +26,57 @@ clear and has sufficient instructions to be able to reproduce the issue.
By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
you agree that your contributions will be licensed under the LICENSE file in
the root directory of this source tree.
+
+## Pre-commit hooks
+To ensure your code lints cleanly, pre-commit hooks are configured in the repository, which you can install.
+After installation, they will automatically run each time you commit.
+An abbreviated guide is given below; for more information, refer to [the official pre-commit documentation](https://pre-commit.com/).
+
+### Installation
+```
+pip install pre-commit
+pre-commit install
+```
+
+### Usage
+Just commit your changes:
+```
+git commit -m "My informative commit message"
+```
+
+If a check fails, you will get feedback like the following:
+```
+[INFO] Initializing environment for https://github.com/PyCQA/flake8.
+[INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks.
+[INFO] Once installed this environment will be reused.
+[INFO] This may take a few minutes...
+[INFO] Installing environment for https://github.com/PyCQA/flake8.
+[INFO] Once installed this environment will be reused.
+[INFO] This may take a few minutes...
+Trim Trailing Whitespace.................................................Failed
+- hook id: trailing-whitespace
+- exit code: 1
+- files were modified by this hook
+Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh
+Fix End of Files.........................................................Failed
+- hook id: end-of-file-fixer
+- exit code: 1
+- files were modified by this hook
+Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py
+flake8...................................................................Passed
+```
+
+Certain hooks modify your files so that they comply.
+To include these modifications, you will need to stage them (i.e. `git add ...`) and commit again.
+
+If all is well, you should see something like:
+```
+Trim Trailing Whitespace.................................................Passed
+Fix End of Files.........................................................Passed
+flake8...................................................................Passed
+[gshard-fix-ci 8698644e1] Fix lint, add pre-commit hooks
+ 10 files changed, 148 insertions(+), 110 deletions(-)
+ create mode 100644 .flake8
+ create mode 100644 .pre-commit-config.yaml
+ rename examples/nllb/modeling/wmt15_benchmark/{eval_langs2.py => eval_langs2.sh} (99%)
+ ```
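+
+The hooks can also be run manually over the entire repository, for example as a quick check before pushing. This is standard `pre-commit` usage rather than anything specific to this configuration; a minimal sketch:
+```
+# run every configured hook against all files, not just the staged changes
+pre-commit run --all-files
+
+# run a single hook by its id (hook ids are listed in .pre-commit-config.yaml)
+pre-commit run flake8 --all-files
+```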
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000..4f719da85c
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include fairseq/version.txt
diff --git a/README.md b/README.md
index 56ec16cdab..1150c66cbe 100644
--- a/README.md
+++ b/README.md
@@ -2,10 +2,12 @@
-
+
+
+
--------------------------------------------------------------------------------
@@ -13,129 +15,169 @@
Fairseq(-py) is a sequence modeling toolkit that allows researchers and
developers to train custom models for translation, summarization, language
modeling and other text generation tasks.
+
We provide reference implementations of various sequence modeling papers:
List of implemented papers
-- **Convolutional Neural Networks (CNN)**
- - [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
- - [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
- - [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- - [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
- - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
-- **LightConv and DynamicConv models**
- - [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
-- **Long Short-Term Memory (LSTM) networks**
- - Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
-- **Transformer (self-attention) networks**
- - Attention Is All You Need (Vaswani et al., 2017)
- - [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
- - [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
- - [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/transformer_lm/README.md)
- - [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
- - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
- - [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
- - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
- - [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
- - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
- - [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
- - [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
- - [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
- - [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
-- **Non-autoregressive Transformers**
- - Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
- - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
- - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
- - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
-- **Finetuning**
- - [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
+* **Convolutional Neural Networks (CNN)**
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* **LightConv and DynamicConv models**
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* **Long Short-Term Memory (LSTM) networks**
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+* **Transformer (self-attention) networks**
+ + Attention Is All You Need (Vaswani et al., 2017)
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+ + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+ + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+ + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027)
+ + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084)
+ + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
+ + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
+ + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+ + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
+* **Non-autoregressive Transformers**
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* **Finetuning**
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
### What's New:
+* May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages (Pratap, et al., 2023)](examples/mms/README.md)
+* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu, et al., 2022)](examples/wav2vec/unsupervised/README.md)
+* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
+* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
+* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
+* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
+* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
+* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
+* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
+* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
+* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
+* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
+* February 2021 [Added LASER training code](examples/laser/README.md)
+* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
+* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
+* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
+ * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
+* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
+* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
+* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
+* October 2020: [Added CRISS models and code](examples/criss/README.md)
-- October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
-- October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
-- October 2020: [Added CRISS models and code](examples/criss/README.md)
-- September 2020: [Added Linformer code](examples/linformer/README.md)
-- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
-- August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
-- August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
-- July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
-- May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
-- April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
-- April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
-- April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
Previous updates
-- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
-- February 2020: [mBART model and code released](examples/mbart/README.md)
-- February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german)
-- December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
-- November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
-- November 2019: [CamemBERT model and code released](examples/camembert/README.md)
-- November 2019: [BART model and code released](examples/bart/README.md)
-- November 2019: [XLM-R models and code released](examples/xlmr/README.md)
-- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
-- August 2019: [WMT'19 models released](examples/wmt19/README.md)
-- July 2019: fairseq relicensed under MIT license
-- July 2019: [RoBERTa models and code released](examples/roberta/README.md)
-- June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
+* September 2020: [Added Linformer code](examples/linformer/README.md)
+* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+* February 2020: [mBART model and code released](examples/mbart/README.md)
+* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+* November 2019: [BART model and code released](examples/bart/README.md)
+* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+* August 2019: [WMT'19 models released](examples/wmt19/README.md)
+* July 2019: fairseq relicensed under MIT license
+* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
+* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
### Features:
-- multi-GPU training on one machine or across multiple machines (data and model parallel)
-- fast generation on both CPU and GPU with multiple search algorithms implemented:
- - beam search
- - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
- - sampling (unconstrained, top-k and top-p/nucleus)
- - lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md))
-- large mini-batch training even on a single GPU via delayed updates
-- mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
-- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
+* multi-GPU training on one machine or across multiple machines (data and model parallel)
+* fast generation on both CPU and GPU with multiple search algorithms implemented:
+ + beam search
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+ + sampling (unconstrained, top-k and top-p/nucleus)
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
+* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
+* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
+* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
+* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
+* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
+* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
with a convenient `torch.hub` interface:
-```python
+
+``` python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'
```
+
See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
# Requirements and Installation
-* [PyTorch](http://pytorch.org/) version >= 1.4.0
-* Python version >= 3.6
+* [PyTorch](http://pytorch.org/) version >= 1.10.0
+* Python version >= 3.8
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:
-```bash
+
+``` bash
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
# on MacOS:
# CFLAGS="-stdlib=libc++" pip install --editable ./
+
+# to install the latest stable release (0.10.x)
+# pip install fairseq
```
+
* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
-```bash
+
+``` bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
--global-option="--deprecated_fused_adam" --global-option="--xentropy" \
--global-option="--fast_multihead_attn" ./
```
-* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
-* If you use Docker make sure to increase the shared memory size either with
-`--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`.
+* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
+* If you use Docker, make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
+ as command line options to `nvidia-docker run`.
# Getting Started
@@ -148,30 +190,32 @@ types and tasks.
We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
as well as example training and evaluation commands.
-- [Translation](examples/translation/README.md): convolutional and transformer models are available
-- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
+* [Translation](examples/translation/README.md): convolutional and transformer models are available
+* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
We also have more detailed READMEs to reproduce results from specific papers:
-- [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
-- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
-- [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
-- [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
-- [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
-- [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
-- [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
-- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
-- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
-- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
-- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
-- [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
-- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
-- [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
-- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
-- [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
-- [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
-- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
-- [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
-- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+
+* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
+* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
+* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
+* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
# Join the fairseq community
@@ -188,7 +232,7 @@ The license applies to the pre-trained models as well.
Please cite as:
-```bibtex
+``` bibtex
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
diff --git a/RELEASE.md b/RELEASE.md
new file mode 100644
index 0000000000..79480a11c5
--- /dev/null
+++ b/RELEASE.md
@@ -0,0 +1,13 @@
+# Creating a New Release
+
+To create a new release:
+
+1. Navigate to the [Fairseq Workflows](https://github.com/facebookresearch/fairseq/actions) and find the one named _Fairseq Release_.
+
+2. Under _Run Workflow_, choose the branch `main` and, for _Release Type_, enter either `major`, `minor`, or `patch`.
+
+3. A branch named `$new_version-release` will be created, in which the `version.txt` file is updated. Merge those changes into `main`.
+
+4. Make sure that a [new PyPI package](https://pypi.org/project/fairseq/) has been uploaded.
+
+5. Make sure that a [new GitHub release](https://github.com/facebookresearch/fairseq/releases) has been created.
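+
+The _Fairseq Release_ workflow can also be started from the command line with the [GitHub CLI](https://cli.github.com/), assuming `gh` is installed and authenticated. This is an optional alternative to the web UI steps above, sketched here rather than part of the documented process:
+```
+# trigger the release workflow on main with a patch-level version bump
+gh workflow run "Fairseq Release" --ref main -f name=patch
+```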
diff --git a/config/config.yaml b/config/config.yaml
deleted file mode 100644
index b9ee6c74ac..0000000000
--- a/config/config.yaml
+++ /dev/null
@@ -1,111 +0,0 @@
-# @package _group_
-common:
- no_progress_bar: false
- log_interval: 100
- log_format: null
- tensorboard_logdir: null
- seed: 1
- cpu: false
- tpu: false
- bf16: false
- fp16: false
- memory_efficient_fp16: false
- memory_efficient_bf16: false
- fp16_no_flatten_grads: false
- fp16_init_scale: 128
- fp16_scale_window: null
- fp16_scale_tolerance: 0.0
- min_loss_scale: 1.0e-4
- threshold_loss_scale: null
- user_dir: null
- empty_cache_freq: 0
- all_gather_list_size: 16384
- model_parallel_size: 1
- quantization_config_path: null
- profile: false
-distributed_training:
- distributed_rank: 0
- distributed_backend: "nccl"
- distributed_init_method: null
- distributed_port: -1
- device_id: 0
- local_rank: 0
- distributed_no_spawn: false
- ddp_backend: "c10d"
- bucket_cap_mb: 25
- fix_batches_to_gpus: false
- find_unused_parameters: false
- fast_stat_sync: false
- broadcast_buffers: false
- distributed_wrapper: "DDP"
- slowmo_momentum: null
- slowmo_algorithm: "LocalSGD"
- localsgd_frequency: 3
-dataset:
- num_workers: 1
- skip_invalid_size_inputs_valid_test: false
- max_tokens: null
- batch_size: null
- required_batch_size_multiple: 8
- dataset_impl: null
- data_buffer_size: 10
- train_subset: "train"
- valid_subset: "valid"
- validate_interval: 1
- fixed_validation_seed: null
- disable_validation: false
- curriculum: 0
- gen_subset: "test"
- num_shards: 1
- shard_id: 0
- max_tokens_valid: ${dataset.max_tokens}
- batch_size_valid: ${dataset.batch_size}
-optimization:
- max_epoch: 0
- max_update: 0
- clip_norm: 25.0
- sentence_avg: false
- update_freq: [ 1 ]
- lr: [ 0.25 ]
- min_lr: -1.0
- use_bmuf: false
-checkpoint:
- save_dir: "checkpoints"
- restore_file: "checkpoint_last.pt"
- reset_dataloader: false
- reset_lr_scheduler: false
- reset_meters: false
- reset_optimizer: false
- optimizer_overrides: "{}"
- save_interval: 1
- save_interval_updates: 0
- keep_interval_updates: -1
- keep_last_epochs: -1
- keep_best_checkpoints: -1
- no_save: false
- no_epoch_checkpoints: false
- no_last_checkpoints: false
- no_save_optimizer_state: false
- best_checkpoint_metric: "loss"
- maximize_best_checkpoint_metric: false
- patience: -1
- checkpoint_suffix: ""
-bmuf:
- block_lr: 1
- block_momentum: 0.875
- global_sync_iter: 50
- warmup_iterations: 500
- use_nbm: false
- average_sync: false
-defaults:
- - task: language_modeling
- - model: null
- - criterion: null
- - optimizer: null
- - lr_scheduler: null
- - bpe: null
- - tokenizer: null
- - scoring: null
- - generation: null
- - common_eval: null
- - eval_lm: null
diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml
deleted file mode 100644
index 7997b0766e..0000000000
--- a/config/criterion/adaptive_loss.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-# @package _group_
-sentence_avg: ${optimization.sentence_avg}
-ddp_backend: ${distributed_training.ddp_backend}
diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml
deleted file mode 100644
index ad3d4148c2..0000000000
--- a/config/criterion/cross_entropy.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-# @package _group_
-sentence_avg: ${optimization.sentence_avg}
diff --git a/config/lr_scheduler/cosine.yaml b/config/lr_scheduler/cosine.yaml
deleted file mode 100644
index 0f91e0d240..0000000000
--- a/config/lr_scheduler/cosine.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-# @package _group_
-warmup_updates: 0
-warmup_init_lr: -1
-max_lr: 1.0
-t_mult: 1.0
-lr_period_updates: -1
-lr_shrink: 0.1
diff --git a/config/lr_scheduler/inverse_sqrt.yaml b/config/lr_scheduler/inverse_sqrt.yaml
deleted file mode 100644
index 0eac7d88eb..0000000000
--- a/config/lr_scheduler/inverse_sqrt.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-# @package _group_
-warmup_updates: 4000
-warmup_init_lr: -1
diff --git a/config/model/transformer_lm.yaml b/config/model/transformer_lm.yaml
deleted file mode 100644
index 3837ea54e1..0000000000
--- a/config/model/transformer_lm.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-# @package _group_
-activation_fn: "relu"
-dropout: 0.1
-attention_dropout: 0.0
-activation_dropout: 0.0
-relu_dropout: 0.0
-decoder_embed_dim: 512
-decoder_output_dim: 512
-decoder_input_dim: 512
-decoder_ffn_embed_dim: 2048
-decoder_layers: 6
-decoder_attention_heads: 8
-decoder_normalize_before: true
-no_decoder_final_norm: false
-adaptive_softmax_cutoff: null
-adaptive_softmax_dropout: 0
-adaptive_softmax_factor: 4
-no_token_positional_embeddings: false
-share_decoder_input_output_embed: false
-character_embeddings: false
-character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
-character_embedding_dim: 4
-char_embedder_highway_layers: 2
-adaptive_input: false
-adaptive_input_factor: 4
-adaptive_input_cutoff: null
-tie_adaptive_weights: false
-tie_adaptive_proj: false
-decoder_learned_pos: false
-decoder_layerdrop: 0
-decoder_layers_to_keep: null
-layernorm_embedding: false
-no_scale_embedding: false
-quant_noise_pq: 0
-quant_noise_pq_block_size: 8
-quant_noise_scalar: 0
diff --git a/config/optimizer/adam.yaml b/config/optimizer/adam.yaml
deleted file mode 100644
index e5264f895e..0000000000
--- a/config/optimizer/adam.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-# @package _group_
-adam_betas: "(0.9, 0.999)"
-adam_eps: 1.0e-8
-weight_decay: 0
-use_old_adam: false
diff --git a/config/optimizer/nag.yaml b/config/optimizer/nag.yaml
deleted file mode 100644
index 4ab2745686..0000000000
--- a/config/optimizer/nag.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-# @package _group_
-momentum: 0.99
-weight_decay: 0.0
diff --git a/config/task/language_modeling.yaml b/config/task/language_modeling.yaml
deleted file mode 100644
index 58a2ad1358..0000000000
--- a/config/task/language_modeling.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-# @package _group_
-data: ???
-sample_break_mode: "none"
-tokens_per_sample: 1024
-output_dictionary_size: -1
-self_target: false
-future_target: false
-past_target: false
-add_bos_token: false
-max_target_positions: null
diff --git a/docs/_static/theme_overrides.css b/docs/_static/theme_overrides.css
deleted file mode 100644
index 2a07641936..0000000000
--- a/docs/_static/theme_overrides.css
+++ /dev/null
@@ -1,9 +0,0 @@
-.wy-table-responsive table td kbd {
- white-space: nowrap;
-}
-.wy-table-responsive table td {
- white-space: normal !important;
-}
-.wy-table-responsive {
- overflow: visible !important;
-}
diff --git a/docs/conf.py b/docs/conf.py
index 440784bfae..0bc049f802 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -55,7 +55,7 @@
copyright = "Facebook AI Research (FAIR)"
author = "Facebook AI Research (FAIR)"
-github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
+github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -88,43 +88,7 @@
# -- Options for HTML output ----------------------------------------------
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = "sphinx_rtd_theme"
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-#
-# html_theme_options = {}
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ["_static"]
-
-html_context = {
- "css_files": [
- "_static/theme_overrides.css", # override wide tables in RTD theme
- ],
-}
-
-# Custom sidebar templates, must be a dictionary that maps document names
-# to template names.
-#
-# This is required for the alabaster theme
-# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
-# html_sidebars = {
-# '**': [
-# 'about.html',
-# 'navigation.html',
-# 'relations.html', # needs 'show_related': True theme option to display
-# 'searchbox.html',
-# 'donate.html',
-# ]
-# }
-
+html_theme = "classic"
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
diff --git a/docs/getting_started.rst b/docs/getting_started.rst
index fa5971dd31..745ad7763c 100644
--- a/docs/getting_started.rst
+++ b/docs/getting_started.rst
@@ -90,7 +90,7 @@ well for the IWSLT 2014 dataset:
> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
- --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
+ --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
@@ -170,21 +170,31 @@ The easiest way to launch jobs is with the `torch.distributed.launch
For example, to train a large English-German Transformer model on 2 nodes each
with 8 GPUs (in total 16 GPUs), run the following command on each node,
-replacing ``node_rank=0`` with ``node_rank=1`` on the second node:
+replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
+sure to update ``--master_addr`` to the IP address of the first node:
.. code-block:: console
> python -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
- --master_port=1234 \
+ --master_port=12345 \
$(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
- --lr 0.0005 --min-lr 1e-09 \
+ --lr 0.0005 \
--dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584 \
- --fp16 --distributed-no-spawn
+ --max-epoch 70 \
+ --fp16
+
+On SLURM clusters, fairseq will automatically detect the number of nodes and
+GPUs, but a port number must be provided:
+
+.. code-block:: console
+
+ > salloc --gpus=16 --nodes 2 (...)
+ > srun fairseq-train --distributed-port 12345 (...).
Sharding very large datasets
----------------------------
diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md
index 0973cd279e..6a15298382 100644
--- a/docs/hydra_integration.md
+++ b/docs/hydra_integration.md
@@ -1,111 +1,284 @@
+## Hydra
+[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
+framework that simplifies the development of research and other complex
+applications. The key feature is the ability to dynamically create a
+hierarchical configuration by composition and override it through config files
+and the command line. The name Hydra comes from its ability to run multiple
+similar jobs - much like a Hydra with multiple heads.
-## Hydra
+## Motivation
+
+Until recently, all components in fairseq were configured through a shared
+`args` namespace that was created at application startup. Components declared
+their own `add_args` method to update the argparse parser, hoping that the names
+would not clash with arguments from other components. While this model works for
+smaller applications, as fairseq grew and became integrated into other
+applications, this became problematic. In order to determine how to configure
+each component, one needed to (a) examine what args were added by this component,
+and (b) read the code to figure out which shared arguments it used that were
+added in other places. Reproducing models involved sharing commands that often
+contained dozens of command line switches.
+
+The model described above is still supported by fairseq for backward
+compatibility, but will be deprecated some time in the future.
+
+New components in fairseq should now create a dataclass that encapsulates all
+parameters required to configure this component. The dataclass is registered
+along with the component, and fairseq takes care of constructing and providing
+this configuration object to the component's constructor. Note that sharing
+parameters can optionally still work, but one has to explicitly point to the
+"source of truth" (see inheritance example below). These changes make components
+in fairseq more independent and re-usable by other applications: all that is
+needed to create a component is to initialize its dataclass and overwrite some
+of the defaults.
+
+While configuring fairseq through command line (using either the legacy argparse
+based or the new Hydra based entry points) is still fully supported, you can now
+take advantage of configuring fairseq completely or piece-by-piece through
+hierarchical YAML configuration files. These files can also be shipped as
+examples that others can use to run an identically configured job.
+
+Additionally, Hydra has a rich and growing [library of
+plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
+provide functionality such as hyperparameter sweeping (including using bayesian
+optimization through the [Ax](https://github.com/facebook/Ax) library), job
+launching across various platforms, and more.
+
+## Creating or migrating components
-Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.
+In general, each new (or updated) component should provide a companion
+[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are
+typically located in the same file as the component and are passed as arguments
+to the `register_*()` functions. Top-level configs that should be present in
+every fairseq application are placed in the
+[global](fairseq/dataclass/configs.py) config file and added to the
+`FairseqConfig` object.
-## Train models with hydra interface
+Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
+classes are decorated with a `@dataclass` decorator, and typically inherit from
+`FairseqDataclass` (which adds some functionality for backward compatibility).
+Each field must have a type, and generally has metadata (such as a help string)
+and a default value. Only primitive types or other config objects are allowed as
+data types for each field.
-#### Provide parameters in `.yaml` files
-For example, if we'd like to train a language model with transformer, we could provide parameters in yaml files. Note that the modules used (task, model, criterion, optimizer, lr scheduler) in training must be migrated with hydra interface already (See session below).
+#### Example:
-- Provide top level choices on which generic parameter file, and which modules to use: `config/config.yaml`, this will look like for example:
+```python
+from dataclasses import dataclass, field
+from fairseq.dataclass import FairseqDataclass
+@dataclass
+class InteractiveConfig(FairseqDataclass):
+ buffer_size: int = field(
+ default=0,
+ metadata={
+ "help": "read this many sentences into a buffer before processing them"
+ },
+ )
+ input: str = field(
+ default="-",
+ metadata={"help": "file to read from; use - for stdin"},
+ )
```
-defaults:
- - task: language_modeling
- - model: transformer_lm
- - criterion: cross_entropy
- - optimizer: adam
- - lr_scheduler: inverse_sqrt
+
+### Inheriting values
+
+Some components require sharing a value. For example, a learning rate scheduler
+and an optimizer may both need to know the initial learning rate value. One can
+declare a field that, by default, will inherit its value from another config
+node in the same hierarchy:
+
+```python
+@dataclass
+class FairseqAdamConfig(FairseqDataclass):
+ ...
+ lr: List[float] = II("optimization.lr")
+ ...
```
-- Provide generic parameters common across different jobs: `config.yaml`
-- Provide task parameters: `config/task/language_modeling.yaml`
-- Provide model parameters: `config/model/transformer_lm.yaml`
-- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
-- Provide optimizer parameters: `config/optimizer/adam.yaml`
-- Provide lr_scheduler parameters `config/lr_scheduler/inverse_sqrt.yaml`
+`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
+the value one can use in a YAML config file or through command line to achieve
+the same effect. Note that this assumes that there is an "optimization" config
+object in the root config and it has a field called "lr".
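+
+As a rough illustration (this is plain OmegaConf, not fairseq-specific code),
+the interpolation that `II("optimization.lr")` expands to resolves like this:
+
+```python
+# Minimal sketch: II("optimization.lr") is just the string "${optimization.lr}",
+# which OmegaConf resolves against the root config on access.
+from omegaconf import II, OmegaConf
+
+cfg = OmegaConf.create(
+    {
+        "optimization": {"lr": [0.0005]},
+        "optimizer": {"lr": II("optimization.lr")},  # same as "${optimization.lr}"
+    }
+)
+print(cfg.optimizer.lr)  # [0.0005], inherited from optimization.lr
+```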
+
+### Tasks and Models
-#### Command line overriding
-`train_hydra.py` is the main entry point for training with hydra interface. If we specify all parameters we want in `.yaml` files, then we could simply use command:
+Creating Tasks and Models works the same as before, except that legacy
+implementations now inherit from `LegacyFairseq*` base classes, while new
+components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
+to the `register_*()` functions.
+#### Task example:
+
+```python
+@dataclass
+class LanguageModelingConfig(FairseqDataclass):
+ data: Optional[str] = field(
+ default=None, metadata={"help": "path to data directory"}
+ )
+ ...
+
+@register_task("language_modeling", dataclass=LanguageModelingConfig)
+class LanguageModelingTask(FairseqTask):
+ ...
+ @classmethod
+ def setup_task(cls, cfg: LanguageModelingConfig):
+ ...
```
-# task.data is requested field marked by `???` in yaml
-python fairseq_cli/train_hydra.py \
-task.data=/private/home/abaevski/data/wiki103 \
+
+#### Model example:
+
+```python
+@dataclass
+class TransformerLanguageModelConfig(FairseqDataclass):
+ activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
+ default="relu", metadata={"help": "activation function to use"}
+ )
+ dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
+ ...
+
+@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
+class TransformerLanguageModel(FairseqLanguageModel):
+ ...
+ @classmethod
+ def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
+ ...
```
-Alternatively, if we need to override certain params from the command line, we could do so as below (note the structure of where each parameter sits)
+### Other components
+
+Other components work as before, but they now take their configuration dataclass
+as the only constructor argument:
+
+```python
+@dataclass
+class MosesTokenizerConfig(FairseqDataclass):
+ source_lang: str = field(default="en", metadata={"help": "source language"})
+ ...
+@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
+class MosesTokenizer(object):
+ def __init__(self, cfg: MosesTokenizerConfig):
+ ...
```
-python fairseq_cli/train_hydra.py
-task=language_modeling \
-task.data=/private/home/abaevski/data/wiki103 \
-task.tokens_per_sample=512 \
-task.sample_break_mode=none \
-model=transformer_lm \
-model.share_decoder_input_output_embed=true \
-model.dropout=0.1 \
-optimizer=adam \
-optimizer.adam_betas="'(0.9, 0.98)'" \
-optimizer.weight_decay=0.01 \
-lr_scheduler=inverse_sqrt \
-lr_scheduler.warmup_updates=4000 \
-lr_scheduler.warmup_init_lr=1e-07 \
-criterion=cross_entropy \
-common.fp16=true \
-common.log_format=json \
-common.log_interval=1 \
-dataset.max_tokens=1024 \
-dataset.num_workers=4 \
-optimization.update_freq=[16] \
-optimization.max_update=50000 \
-optimization.clip_norm=0.0 \
-optimization.lr=[0.0005] \
-checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
-checkpoint.save_interval_updates=10
+
+Note that if you are adding a new registry for a new set of components, you need
+to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
+
+```python
+@dataclass
+class FairseqConfig(object):
+ ...
+ my_new_registry: Any = None
```
-## Migrate existing/Creating new modules to hydra interface
+## Training with `fairseq-hydra-train`
+
+To fully take advantage of configuration flexibility offered by Hydra, you may
+want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
+tools such as `fairseq-train` will remain supported for the foreseeable future
+but will be deprecated eventually.
-In each of the modules we want to migrated/create with hydra interface, fundamentally we need to
+On startup, Hydra will create a configuration object that contains a hierarchy
+of all the necessary dataclasses populated with their default values in the
+code. The default values are overwritten by values found in YAML files in
+`fairseq/config` directory (which currently sets minimal defaults) and then
+further overwritten by values provided through command line arguments.
-- Provide a dataclass that layouts the parameters used in the module.
+Some of the most common use cases are shown below:
-- Modify the builder and/or constructor that previously takes `argparse.Namespace` argument `args`, into taking `omegaconf.DictConfig` config objects. At this moment we allow `Union[omegaconf.DictConfig, argparse.Namespace]` to support compatibility.
+### 1. Override default values through command line:
-- For `add_args()`, we need to extract argument from the dataclass defined in the same file, and append them into `parser`. This is also to support compatibility. This is simply supported with `gen_parser_from_dataclass` API, see examples files below.
+```shell script
+$ fairseq-hydra-train \
+ distributed_training.distributed_world_size=1 \
+ dataset.batch_size=2 \
+ task.data=data-bin \
+ model=transformer_lm/transformer_lm_gpt \
+ task=language_modeling \
+ optimization.max_update=5000
+```
+
+Note that along with explicitly providing values for parameters such as
+`dataset.batch_size`, this also tells Hydra to overlay configuration found in
+`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
+values in the dataclass. If you want to train a model without specifying a
+particular architecture you can simply specify `model=transformer_lm`. This only
+works for migrated tasks and models.
-#### Migrated examples:
+### 2. Replace bundled configs with an external config:
-- Task: `fairseq/tasks/language_modeling.py`
+```shell script
+$ fairseq-hydra-train \
+ --config-dir /path/to/external/configs \
+ --config-name wiki103
+```
-- Model: `fairseq/models/transformer_lm.py`
+where `/path/to/external/configs/wiki103.yaml` contains:
-- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py`
+```yaml
+# @package _group_
-- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py`
+model:
+ _name: transformer_lm
+distributed_training:
+ distributed_world_size: 1
+dataset:
+ batch_size: 2
+task:
+ _name: language_modeling
+ data: /path/to/data
+ add_bos_token: false
+ max_target_positions: 1024
+optimization:
+ max_update: 50000
+ lr: [ 0.25 ]
+criterion: cross_entropy
+optimizer: adam
+lr_scheduler:
+ _name: cosine
+```
-- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py`
+Note that the bundled configs from the `fairseq/config` directory are not used
+here; however, the defaults from each dataclass will still be used (unless
+overwritten by your external config).
+
+Additionally you can choose to break up your configs by creating a directory
+structure in the same location as your main config file, with the names of the
+top-level fields (such as "model", "dataset", etc), and placing config files
+with meaningful names that would populate that specific section of your
+top-level config file (for example, you might have
+`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
+can then specify the correct configuration via command line, defaults in the
+main config, or even launch all of them as a sweep (see Hydra documentation on
+how to do this).
-## Interpolate parameters across different places
+### 3. Add an external config directory to Hydra search path:
-## Support of legacy interface
-If you still like to pass legacy style arguments in command line, `fairseq_cli/train.py` can support this. Internally it coverted `args` into hydra config objects whenever there are migrated modules aligned.
+This allows combining default configuration (including using any bundled config
+files), while specifying your own config files for some parts of the
+configuration.
+```shell script
+$ fairseq-hydra-train \
+ distributed_training.distributed_world_size=1 \
+ dataset.batch_size=2 \
+ task.data=/path/to/data/ \
+ model=transformer_lm/2_layers \
+ task=language_modeling \
+ optimization.max_update=5000 \
+ --config-dir /path/to/external/configs
```
-python fairseq_cli/train.py --task language_modeling \
-/private/home/abaevski/data/wiki103 \
---save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
---arch transformer_lm --share-decoder-input-output-embed \
---dropout 0.1 \
---optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
---lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
---tokens-per-sample 512 --sample-break-mode none \
---max-tokens 1024 --update-freq 16 \
---fp16 \
---max-update 50000 --log-format json --log-interval 1 --num-workers 4 \
---save-interval-updates 10
+
+where `/path/to/external/configs` has the following structure:
+```
+.
++-- model
+| +-- transformer_lm
+| | +-- 2_layers.yaml
```
+
+and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
+`decoder_layers` set to 2. You can add other configs to configure other
+components as well.
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index c734a1f04f..0000000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-sphinx<2.0
-sphinx-argparse
diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst
index b02fec0489..de099f08f5 100644
--- a/docs/tutorial_classifying_names.rst
+++ b/docs/tutorial_classifying_names.rst
@@ -208,7 +208,7 @@ following contents::
import torch
from fairseq.data import Dictionary, LanguagePairDataset
- from fairseq.tasks import FairseqTask, register_task
+ from fairseq.tasks import LegacyFairseqTask, register_task
@register_task('simple_classification')
diff --git a/examples/MMPT/.gitignore b/examples/MMPT/.gitignore
new file mode 100644
index 0000000000..70a255dc91
--- /dev/null
+++ b/examples/MMPT/.gitignore
@@ -0,0 +1,139 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+runs
+data
+pretrained_models
+projects/mmfusion_*
+log_test
+third-party
+python_log
+slurm_snapshot_code
+lightning_logs
+demos
diff --git a/examples/MMPT/CONFIG.md b/examples/MMPT/CONFIG.md
new file mode 100644
index 0000000000..bbd1403dfa
--- /dev/null
+++ b/examples/MMPT/CONFIG.md
@@ -0,0 +1,41 @@
+### Config Files Explained
+
+Taking `projects/mfmmlm.yaml` for example, which runs pretraining using a masked frame model (MFM) and a masked language model (MLM) on a single BERT:
+
+```yaml
+project_dir: mfmmlm # specify the project dir for this baseline.
+run_task:
+ - how2.yaml # run pretraining on how2 when launching `projects/taskmfmmlm.yaml`
+ - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml] # run fine-tuning tasks.
+base_dir: task # a global template folder to specify each training task.
+task_group:
+  pretrain: # section for pretraining. Most baselines differ in this section.
+ task_list:
+ - how2.yaml # reconfig `projects/task/how2.yaml`
+ dataset:
+ aligner: MFMMLMAligner # overwrite the aligner for MFMMLM training task.
+ model:
+ model_cls: MMFusionMFMMLM # overwrite the model, which constructs negative examples for MFM on-the-fly.
+ loss:
+ loss_cls: MFMMLM # overwrite the loss as MFMMLM, which combines MFM and MLM together.
+      fairseq: # all fairseq args can be specified under this name.
+ dataset:
+ batch_size: 128
+  finetune: # section for fine-tuning tasks; mostly we don't need to change anything here, since we want to see how pretraining contributes to finetuning.
+ task_list: # specify the list of downstream tasks, e.g., copy `projects/task/vtt.yaml` to `projects/mfmmlm`.
+ - vtt.yaml
+ - vttqa.yaml
+ - youcook.yaml
+ - youcookcap.yaml
+ - crosstask.yaml
+ - coin.yaml
+ test: # section for testing.
+ task_list:
+ - test_vtt.yaml
+ - test_vttqa.yaml
+ - test_youcook.yaml
+ - test_youcookcap.yaml
+ - test_crosstask.yaml
+ - test_crosstask_zs.yaml
+ - test_coin.yaml
+```
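+
+Under the hood, `locallaunch.py` (see `Pipeline._overwrite_task`) overlays each
+`task_group` section onto the corresponding base config under `projects/task/`
+with OmegaConf and saves the result into the project dir. A rough, standalone
+sketch of that overlay (paths are illustrative):
+
+```python
+# Sketch only: merge a task_group overwrite onto a base task config.
+from omegaconf import OmegaConf
+
+base = OmegaConf.load("projects/task/how2.yaml")  # base pretraining config
+overwrite = OmegaConf.create(
+    {
+        "dataset": {"aligner": "MFMMLMAligner"},
+        "model": {"model_cls": "MMFusionMFMMLM"},
+        "loss": {"loss_cls": "MFMMLM"},
+        "fairseq": {"dataset": {"batch_size": 128}},
+    }
+)
+merged = OmegaConf.merge(base, overwrite)  # values in `overwrite` win
+OmegaConf.save(config=merged, f="projects/mfmmlm/how2.yaml")
+```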
diff --git a/examples/MMPT/DATASET.md b/examples/MMPT/DATASET.md
new file mode 100644
index 0000000000..930403eb36
--- /dev/null
+++ b/examples/MMPT/DATASET.md
@@ -0,0 +1,34 @@
+# Dataset
+
+We understand video data are challenging to download and process. For videos, we provide our preprocessing scripts under `scripts/video_feature_extractor` (deeply adapted from `https://github.com/antoine77340/video_feature_extractor`); for text, we provide pre-tokenization scripts under `scripts/text_token_extractor`.
+
+### S3D Feature Extraction
+We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
+
+We implement a `PathBuilder` to automatically track video ids, source video paths to their feature locations (you may need `conda install -c anaconda pandas`). Decoding may need `pip install ffmpeg-python`.
+
+### Howto100M
+[Howto100M](https://www.di.ens.fr/willow/research/howto100m/) is a large-scale video pre-training dataset. You may download the videos yourself and run our preprocessing scripts.
+
+Several key differences of our preprocessing from existing papers: (1) we use `raw_caption.json` instead of `caption.json` to have pure self-supervision on text (`caption.json` has manual removal of stop words); (2) we remove partially duplicated texts that were originally designed for real-time readability (see `mmpt/processors/dedupprocessor.py`); (3) we then shard video/text features using `ShardedTensor` in `mmpt/utils/shardedtensor.py` for fast loading during training (faster than `h5py`).
+
+#### Steps
+##### video
+To extract video features, edit and run `bash scripts/video_feature_extractor/how2/s3d.sh` (consider running this on multiple machines; by default, we store features in fp16 to save space and speed up training).
+
+Split the available video ids into `data/how2/how2_s3d_train.lst` and `data/how2/how2_s3d_val.lst`.
+
+Lastly, pack video features into `ShardedTensor` using `python scripts/video_feature_extractor/shard_feature.py`.
+
+##### text
+Clean captions using `python -m mmpt.processors.dedupprocessor`.
+
+Tokenize the deduplicated captions `data/how2/raw_caption_dedup.pkl` into sharded numpy arrays:
+```
+python scripts/text_token_extractor/pretokenization.py scripts/text_token_extractor/configs/bert-base-uncased.yaml
+```
+
+### Youcook, MSRVTT etc.
+We use the versions of Youcook and MSRVTT that come with Howto100M and MILNCE. Please download the data to `data/youcook` and `data/msrvtt` accordingly; you can also check `projects/task/youcook.yaml`, `projects/task/vtt.yaml`, etc. for details.
+We extract features for Youcook and MSRVTT similarly to the first step for Howto100M, but we read the text directly from the metadata and perform on-the-fly tokenization.
+
diff --git a/examples/MMPT/README.md b/examples/MMPT/README.md
new file mode 100644
index 0000000000..4a84819d9d
--- /dev/null
+++ b/examples/MMPT/README.md
@@ -0,0 +1,166 @@
+# VideoCLIP and VLM
+
+You just found this toolkit for multimodal video understanding! It contains implementations of two recent multi-modal video understanding papers, [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkits that are typically lacking in existing codebases. The toolkit is designed to contain generic, performance-tuned components that can potentially be adapted to other frameworks (we initially use fairseq).
+
+VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence labeling style tasks.
+
+
+
+VLM is a masked language model style pre-training using only one encoder with masked modality model (MMM) for retrieval/generation/sequence labeling style tasks.
+
+
+
+### News
+[Oct. 2021] Initial release of implementation for the following papers:
+[VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et. al., EMNLP 2021)
+[VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et. al., ACL Findings 2021)
+
+
+### Installation
+We aim to minimize the dependency of this repo on other packages.
+We use fairseq as the main trainer (there is no model/dataset dependency on fairseq; we will support other trainers in the future):
+```
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+pip install -e . # also optionally follow fairseq README for apex installation for fp16 training.
+export MKL_THREADING_LAYER=GNU # fairseq may need this for numpy.
+```
+
+Then install this toolkit:
+```
+cd examples/MMPT # MMPT can be in any folder, not necessarily under fairseq/examples.
+pip install -e .
+```
+
+The code was developed under Python=3.8.8, PyTorch=1.8, CUDA=11.0 with fairseq=1.0.0a0+af0389f, and tested under Python=3.8.8, PyTorch=1.9, CUDA=11.0, fairseq=1.0.0a0+8e7bc73 at code release.
+Most models require `transformers==3.4` for API compatibility (`pip install transformers==3.4`).
+In addition, some downstream tasks may need `conda install pandas`.
+
+
+### Usage
+#### Download Checkpoints
+We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
+
+Download VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip` or VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`.
+
+#### Demo of Inference
+Run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to generate all `.yaml`s for VideoCLIP.
+
+```python
+import torch
+
+from mmpt.models import MMPTModel
+
+
+model, tokenizer, aligner = MMPTModel.from_pretrained(
+ "projects/retri/videoclip/how2.yaml")
+
+model.eval()
+
+
+# B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d)
+video_frames = torch.randn(1, 2, 30, 224, 224, 3)
+caps, cmasks = aligner._build_text_seq(
+ tokenizer("some text", add_special_tokens=False)["input_ids"]
+)
+
+caps, cmasks = caps[None, :], cmasks[None, :] # bsz=1
+
+with torch.no_grad():
+ output = model(video_frames, caps, cmasks, return_score=True)
+print(output["score"]) # dot-product
+```
+
+#### Data Preparation
+See [dataset](DATASET.md) for each dataset.
+
+#### Global Config for Training Pipeline
+We organize a global config file for a training/testing pipeline under `projects` (see a detailed [explanation](CONFIG.md)). For example, VideoCLIP is in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`.
+
+We wrap all commands in `locallaunch.py` and `mmpt_cli/localjob.py`. You can check the concrete commands with `--dryrun` and then drop it for an actual run.
+
+First, running `python locallaunch.py projects/retri/videoclip.yaml --dryrun` will generate the configs for pre-training, zero-shot evaluation, fine-tuning and testing of VideoCLIP under `projects/retri/videoclip`.
+
+Each (training or evaluation) process is then configured by a concrete config file (we save all complex arguments, including fairseq args, into the concrete config file for reproducibility). For example, to run zero-shot evaluation on Youcook:
+```
+python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict # zero-shot evaluation.
+python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun # fine-tuning: use --dryrun to check cmds and drop it to make an actual run; local_small will run on two gpus (as in paper).
+python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict # testing on fine-tuned model.
+```
+
+Pretraining can be run as:
+```
+python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun # check then drop --dryrun; the paper's runs used local_big with 8 gpus.
+```
+You may need to change `--jobtype`; check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-GPU/multi-node pre-training.
+
+Detailed instructions for pretraining and fine-tuning can be found in the [pretraining instructions](pretraining.md) and the [finetuning instructions](endtask.md).
+
+
+### Development
+Several components of this toolkit can be re-used for future research (and also our ongoing research).
+
+#### Framework Wrapper
+We currently only support fairseq, but most components can easily fit into other frameworks like huggingface. This repo is a fairseq `--user-dir` with fairseq wrappers. For example, `mmpt/tasks` includes a `FairseqMMTTask`, which manages `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, and `mmpt/losses` with `FairseqCriterion`.
+
+#### Processors
+**Multi**modal research introduces complexity in modality alignment, from the different input sources all the way to the losses. Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle various data preprocessing and loading needs, **alleviating** the need for multiple `torch.utils.data.Dataset` classes (which can be tricky for ablation studies).
+Processors can also be decoupled from `torch.utils.data.Dataset` for offline preprocessing instead of on-the-fly data preprocessing.
+
+We decouple `mmpt.MMDataset` into 4 types of processors: `MetaProcessor`, `VideoProcessor`, `TextProcessor` and `Aligner`. They can be configured in the `dataset` field of a config file (e.g., see `projects/task/how2.yaml`).
+`MetaProcessor` loads the metadata of a dataset, e.g., all video ids of the how2 dataset.
+`VideoProcessor` loads the video features of a dataset, e.g., S3D features for each second of a video.
+`TextProcessor` loads the text (features), e.g., BERT pre-tokenized text clips for the how2 dataset (with `start`/`end` timestamps and `cap` for `token_ids`).
+`Aligner` is the core class that prepares a training example for a given baseline, e.g., sampling a clip, masking tokens for MLM, etc.
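+
+The sketch below (hypothetical `Dummy*` processors, not shipped with the toolkit) shows the protocol `mmpt.datasets.MMDataset` expects from these four objects:
+
+```python
+# Toy processors implementing the interface used by MMDataset
+# (see mmpt/datasets/mmdataset.py); for illustration only.
+from mmpt.datasets import MMDataset
+
+
+class DummyMetaProcessor:
+    split = "train"
+
+    def __init__(self, video_ids):
+        self.video_ids = video_ids
+
+    def __len__(self):
+        return len(self.video_ids)
+
+    def __getitem__(self, idx):
+        # return (video_id, text_id); here they are the same id.
+        return self.video_ids[idx], self.video_ids[idx]
+
+
+class DummyVideoProcessor:
+    def __call__(self, video_id):
+        return [[0.0] * 512]  # e.g., one second of S3D features.
+
+
+class DummyTextProcessor:
+    def __call__(self, text_id):
+        return {"start": [0], "end": [1], "cap": [[101, 102]]}
+
+
+class DummyAligner:
+    def __call__(self, video_id, video_feature, text_feature):
+        # combine one video and one text clip into a training example.
+        return {"video_id": video_id, "vfeats": video_feature, "caps": text_feature["cap"]}
+
+
+dataset = MMDataset(
+    DummyMetaProcessor(["vid0", "vid1"]),
+    DummyVideoProcessor(),
+    DummyTextProcessor(),
+    DummyAligner(),
+)
+print(dataset[0]["video_id"])  # "vid0"
+```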
+
+#### Performance-tuned Components
+To speed up pre-training, this toolkit uses sharded features stored in mmaped numpy arrays, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from the MARGE paper). This reduces the IO load for multi-GPU training by avoiding loading all features of a video into memory each time, and `ShardedTensor` ensures features are stored in contiguous disk space for near-random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`.
+
+
+### Citation
+If this codebase is useful for your work, please cite the following papers:
+
+```BibTeX
+@inproceedings{xu-etal-2021-videoclip,
+ title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding",
+ author = "Xu, Hu and
+ Ghosh, Gargi and
+ Huang, Po-Yao and
+ Okhonko, Dmytro and
+ Aghajanyan, Armen and
+ Metze, Florian and
+ Zettlemoyer, Luke and
+ Feichtenhofer, Christoph",
+ booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
+ month = nov,
+ year = "2021",
+ address = "Online",
+ publisher = "Association for Computational Linguistics",
+}
+
+@inproceedings{xu-etal-2021-vlm,
+ title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding",
+ author = "Xu, Hu and
+ Ghosh, Gargi and
+ Huang, Po-Yao and
+ Arora, Prahal and
+ Aminzadeh, Masoumeh and
+ Feichtenhofer, Christoph and
+ Metze, Florian and
+ Zettlemoyer, Luke",
+ booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
+ month = aug,
+ year = "2021",
+ address = "Online",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2021.findings-acl.370",
+ doi = "10.18653/v1/2021.findings-acl.370",
+ pages = "4227--4239",
+}
+```
+
+### Bug Reports
+This repo is in its initial stage; bug reports are welcome at huxu@fb.com.
+
+### Copyright
+The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Evaluation Codes/Models: Howto100M and HuggingFace Transformers are licensed under the Apache2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under the BSD-3; DiDeMo is licensed under the BSD-2 license.
diff --git a/examples/MMPT/endtask.md b/examples/MMPT/endtask.md
new file mode 100644
index 0000000000..7690955327
--- /dev/null
+++ b/examples/MMPT/endtask.md
@@ -0,0 +1,41 @@
+# Zero-shot Transfer and Finetuning
+
+(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
+All finetuning datasets (specifically `processors`) are defined in `mmpt.processors.dsprocessor`.
+Given the complexity of the different types of finetuning tasks, each task may have its own meta/video/text/aligner processors and `mmpt/evaluators/{Predictor,Metric}`.
+
+### Tasks
+
+Currently, we support 5 end datasets: `MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo` with the following tasks:
+text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`;
+video captioning: `Youcook`;
+video question answering: `MSRVTT-QA`.
+
+To add your own dataset, you can specify the corresponding processors and configure them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`.
+
+### Zero-shot Transfer (no Training)
+Zero-shot transfer runs the pre-trained model (e.g., VideoCLIP) directly on testing data. Configs matching the pattern `projects/task/*_zs_*.yaml` are dedicated to zero-shot transfer.
+
+### Fine-tuning
+
+The training of a downstream task is similar to pretraining, except you may need to specify the `restore_file` in `fairseq.checkpoint` and reset optimizers; see `projects/task/ft.yaml`, which is included by `projects/task/vtt.yaml`.
+
+We typically do finetuning on 2 GPUs (`local_small`).
+
+### Testing
+For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`.
+
+We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and are expected to use `RetrievalPredictor`. You may need to define your own type of predictor and specify it in the `predictor` field of a testing config.
+
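+A hypothetical minimal predictor (the base class lives in `mmpt/evaluators/predictor.py`; the `outputs["score"]` key is an assumption about your model's output) might look like:
+
+```python
+# Sketch of a custom predictor; collects one score array per batch.
+import numpy as np
+
+from mmpt.evaluators.predictor import Predictor
+
+
+class MyScorePredictor(Predictor):
+    def __call__(self, outputs):
+        # `outputs` is the model output dict updated with the input batch;
+        # self.full_scores is initialized by the base predict_loop().
+        self.full_scores.append(outputs["score"].cpu().numpy())
+
+    def finalize(self, output_file=None):
+        # keys here must match your metric's compute_metrics() signature.
+        return {"outputs": np.concatenate(self.full_scores)}
+```
+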
+Each task may also have its own metric for evaluation. This can be created by subclassing `mmpt.evaluators.Metric` and specified in the `metric` field of a testing config.
+
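+A hypothetical minimal metric (mirroring `QAMetric` in `mmpt/evaluators/metric.py`; the name `MeanScoreMetric` is illustrative) could be:
+
+```python
+# Sketch of a custom metric subclass.
+import numpy as np
+
+from mmpt.evaluators.metric import Metric
+
+
+class MeanScoreMetric(Metric):
+    def __init__(self, config, metric_names=["mean_score"]):
+        super().__init__(config, metric_names)
+
+    def compute_metrics(self, outputs, **kwargs):
+        return {"mean_score": float(np.mean(outputs))}
+
+    def print_computed_metrics(self, metrics):
+        print("mean_score: {:.4f}".format(metrics["mean_score"]))
+```
+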
+Launching a test is as simple as launching training: just specify the path of a testing config:
+```python locallaunch.py projects/mfmmlm/test_vtt.yaml```
+Testing is launched locally by default since prediction is computationally less expensive.
+
+### Third-party Libraries
+We list the following finetuning tasks that require third-party libraries.
+
+Youcook captioning: `https://github.com/Maluuba/nlg-eval`
+
+CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`)
diff --git a/examples/MMPT/locallaunch.py b/examples/MMPT/locallaunch.py
new file mode 100644
index 0000000000..e20fd816fa
--- /dev/null
+++ b/examples/MMPT/locallaunch.py
@@ -0,0 +1,148 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+import os
+
+from omegaconf import OmegaConf
+
+from mmpt.utils import recursive_config, overwrite_dir
+from mmpt_cli.localjob import LocalJob
+
+
+class JobLauncher(object):
+ JOB_CONFIG = {
+ "local": LocalJob,
+ }
+
+ def __init__(self, yaml_file):
+ self.yaml_file = yaml_file
+ job_key = "local"
+
+ if yaml_file.endswith(".yaml"):
+ config = recursive_config(yaml_file)
+ if config.task_type is not None:
+ job_key = config.task_type.split("_")[0]
+ else:
+ raise ValueError("unknown extension of job file:", yaml_file)
+ self.job_key = job_key
+
+ def __call__(self, job_type=None, dryrun=False):
+ if job_type is not None:
+ self.job_key = job_type.split("_")[0]
+ print("[JobLauncher] job_key", self.job_key)
+ job = JobLauncher.JOB_CONFIG[self.job_key](
+ self.yaml_file, job_type=job_type, dryrun=dryrun)
+ return job.submit()
+
+
+class Pipeline(object):
+ """a job that loads yaml config."""
+
+ def __init__(self, fn):
+ """
+ load a yaml config of a job and save generated configs as yaml for each task.
+ return: a list of files to run as specified by `run_task`.
+ """
+ if fn.endswith(".py"):
+ # a python command.
+ self.backend = "python"
+ self.run_yamls = [fn]
+ return
+
+ job_config = recursive_config(fn)
+ if job_config.base_dir is None: # single file job config.
+ self.run_yamls = [fn]
+ return
+
+ self.project_dir = os.path.join("projects", job_config.project_dir)
+ self.run_dir = os.path.join("runs", job_config.project_dir)
+
+ if job_config.run_task is not None:
+ run_yamls = []
+ for stage in job_config.run_task:
+ # each stage can have multiple tasks running in parallel.
+ if OmegaConf.is_list(stage):
+ stage_yamls = []
+ for task_file in stage:
+ stage_yamls.append(
+ os.path.join(self.project_dir, task_file))
+ run_yamls.append(stage_yamls)
+ else:
+ run_yamls.append(os.path.join(self.project_dir, stage))
+ self.run_yamls = run_yamls
+ configs_to_save = self._overwrite_task(job_config)
+ self._save_configs(configs_to_save)
+
+ def __getitem__(self, idx):
+ yaml_files = self.run_yamls[idx]
+ if isinstance(yaml_files, list):
+ return [JobLauncher(yaml_file) for yaml_file in yaml_files]
+ return [JobLauncher(yaml_files)]
+
+ def __len__(self):
+ return len(self.run_yamls)
+
+ def _save_configs(self, configs_to_save: dict):
+ # save
+ os.makedirs(self.project_dir, exist_ok=True)
+ for config_file in configs_to_save:
+ config = configs_to_save[config_file]
+ print("saving", config_file)
+ OmegaConf.save(config=config, f=config_file)
+
+ def _overwrite_task(self, job_config):
+ configs_to_save = {}
+ self.base_project_dir = os.path.join("projects", job_config.base_dir)
+ self.base_run_dir = os.path.join("runs", job_config.base_dir)
+
+ for config_sets in job_config.task_group:
+ overwrite_config = job_config.task_group[config_sets]
+ if (
+ overwrite_config.task_list is None
+ or len(overwrite_config.task_list) == 0
+ ):
+ print(
+ "[warning]",
+ job_config.task_group,
+ "has no task_list specified.")
+ # we don't want this added to a final config.
+ task_list = overwrite_config.pop("task_list", None)
+ for config_file in task_list:
+ config_file_path = os.path.join(
+ self.base_project_dir, config_file)
+ config = recursive_config(config_file_path)
+ # overwrite it.
+ if overwrite_config:
+ config = OmegaConf.merge(config, overwrite_config)
+ overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
+ save_file_path = os.path.join(self.project_dir, config_file)
+ configs_to_save[save_file_path] = config
+ return configs_to_save
+
+
+def main(args):
+ job_type = args.jobtype if args.jobtype else None
+ # parse multiple pipelines.
+ pipelines = [Pipeline(fn) for fn in args.yamls.split(",")]
+
+ for pipe_id, pipeline in enumerate(pipelines):
+ if not hasattr(pipeline, "project_dir"):
+ for job in pipeline[0]:
+ job(job_type=job_type, dryrun=args.dryrun)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("yamls", type=str)
+ parser.add_argument(
+ "--dryrun",
+ action="store_true",
+ help="run config and prepare to submit without launch the job.",
+ )
+ parser.add_argument(
+ "--jobtype", type=str, default="",
+ help="force to run jobs as specified.")
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/MMPT/mmpt/__init__.py b/examples/MMPT/mmpt/__init__.py
new file mode 100644
index 0000000000..6ff86ddd5c
--- /dev/null
+++ b/examples/MMPT/mmpt/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+try:
+ # fairseq user dir
+ from .datasets import FairseqMMDataset
+ from .losses import FairseqCriterion
+ from .models import FairseqMMModel
+ from .tasks import FairseqMMTask
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/datasets/__init__.py b/examples/MMPT/mmpt/datasets/__init__.py
new file mode 100644
index 0000000000..2578235e17
--- /dev/null
+++ b/examples/MMPT/mmpt/datasets/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .mmdataset import *
+
+try:
+ from .fairseqmmdataset import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/datasets/fairseqmmdataset.py b/examples/MMPT/mmpt/datasets/fairseqmmdataset.py
new file mode 100644
index 0000000000..02c49141db
--- /dev/null
+++ b/examples/MMPT/mmpt/datasets/fairseqmmdataset.py
@@ -0,0 +1,57 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+TODO (huxu): fairseq wrapper class for all dataset you defined: mostly MMDataset.
+"""
+
+from collections import OrderedDict
+
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+from fairseq.data import FairseqDataset, data_utils
+
+
+class FairseqMMDataset(FairseqDataset):
+ """
+ A wrapper class for MMDataset for fairseq.
+ """
+
+ def __init__(self, mmdataset):
+ if not isinstance(mmdataset, Dataset):
+ raise TypeError("mmdataset must be of type `torch.utils.data.dataset`.")
+ self.mmdataset = mmdataset
+
+ def set_epoch(self, epoch, **unused):
+ super().set_epoch(epoch)
+ self.epoch = epoch
+
+ def __getitem__(self, idx):
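+        # numpy_seed makes any numpy randomness in the processors deterministic per (epoch, idx).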
+ with data_utils.numpy_seed(43211, self.epoch, idx):
+ return self.mmdataset[idx]
+
+ def __len__(self):
+ return len(self.mmdataset)
+
+ def collater(self, samples):
+ if hasattr(self.mmdataset, "collator"):
+ return self.mmdataset.collator(samples)
+ if len(samples) == 0:
+ return {}
+ if isinstance(samples[0], dict):
+ batch = OrderedDict()
+ for key in samples[0]:
+ if samples[0][key] is not None:
+ batch[key] = default_collate([sample[key] for sample in samples])
+ return batch
+ else:
+ return default_collate(samples)
+
+ def size(self, index):
+ """dummy implementation: we don't use --max-tokens"""
+ return 1
+
+ def num_tokens(self, index):
+ """dummy implementation: we don't use --max-tokens"""
+ return 1
diff --git a/examples/MMPT/mmpt/datasets/mmdataset.py b/examples/MMPT/mmpt/datasets/mmdataset.py
new file mode 100644
index 0000000000..3d07283f91
--- /dev/null
+++ b/examples/MMPT/mmpt/datasets/mmdataset.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from collections import OrderedDict
+
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from ..utils import set_seed
+
+
+class MMDataset(Dataset):
+ """
+ A generic multi-modal dataset.
+ Args:
+ `meta_processor`: a meta processor,
+ handling loading meta data and return video_id and text_id.
+ `video_processor`: a video processor,
+ handling e.g., decoding, loading .np files.
+ `text_processor`: a text processor,
+ handling e.g., tokenization.
+ `aligner`: combine the video and text feature
+ as one training example.
+ """
+
+ def __init__(
+ self,
+ meta_processor,
+ video_processor,
+ text_processor,
+ align_processor,
+ ):
+ self.split = meta_processor.split
+ self.meta_processor = meta_processor
+ self.video_processor = video_processor
+ self.text_processor = text_processor
+ self.align_processor = align_processor
+
+ def __len__(self):
+ return len(self.meta_processor)
+
+ def __getitem__(self, idx):
+ if self.split == "test":
+ set_seed(idx)
+ video_id, text_id = self.meta_processor[idx]
+ video_feature = self.video_processor(video_id)
+ text_feature = self.text_processor(text_id)
+ output = self.align_processor(video_id, video_feature, text_feature)
+ # TODO (huxu): the following is for debug purpose.
+ output.update({"idx": idx})
+ return output
+
+ def collater(self, samples):
+ """This collator is deprecated.
+ set self.collator = MMDataset.collater.
+ see collator in FairseqMMDataset.
+ """
+
+ if len(samples) == 0:
+ return {}
+ if isinstance(samples[0], dict):
+ batch = OrderedDict()
+ for key in samples[0]:
+ if samples[0][key] is not None:
+ batch[key] = default_collate(
+ [sample[key] for sample in samples])
+ # if torch.is_tensor(batch[key]):
+ # print(key, batch[key].size())
+ # else:
+ # print(key, len(batch[key]))
+ return batch
+ else:
+ return default_collate(samples)
+
+ def print_example(self, output):
+ print("[one example]", output["video_id"])
+ if (
+ hasattr(self.align_processor, "subsampling")
+ and self.align_processor.subsampling is not None
+ and self.align_processor.subsampling > 1
+ ):
+ for key in output:
+ if torch.is_tensor(output[key]):
+ output[key] = output[key][0]
+
+ # search tokenizer to translate ids back.
+ tokenizer = None
+ if hasattr(self.text_processor, "tokenizer"):
+ tokenizer = self.text_processor.tokenizer
+ elif hasattr(self.align_processor, "tokenizer"):
+ tokenizer = self.align_processor.tokenizer
+ if tokenizer is not None:
+ caps = output["caps"].tolist()
+ if isinstance(caps[0], list):
+ caps = caps[0]
+ print("caps", tokenizer.decode(caps))
+ print("caps", tokenizer.convert_ids_to_tokens(caps))
+
+ for key, value in output.items():
+ if torch.is_tensor(value):
+ if len(value.size()) >= 3: # attention_mask.
+ print(key, value.size())
+ print(key, "first", value[0, :, :])
+ print(key, "last", value[-1, :, :])
+ else:
+ print(key, value)
+ print("[end of one example]")
diff --git a/examples/MMPT/mmpt/evaluators/__init__.py b/examples/MMPT/mmpt/evaluators/__init__.py
new file mode 100644
index 0000000000..2d06b9d797
--- /dev/null
+++ b/examples/MMPT/mmpt/evaluators/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .metric import *
+from .evaluator import *
+
+
+# experimental.
+try:
+ from .expmetric import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/evaluators/evaluator.py b/examples/MMPT/mmpt/evaluators/evaluator.py
new file mode 100644
index 0000000000..94d9c5ec9a
--- /dev/null
+++ b/examples/MMPT/mmpt/evaluators/evaluator.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import glob
+import numpy as np
+
+from . import metric as metric_path
+from . import predictor as predictor_path
+
+
+class Evaluator(object):
+ """
+ perform evaluation on a single (downstream) task.
+ make this both offline and online.
+ TODO(huxu) saving evaluation results.
+ """
+
+ def __init__(self, config, eval_dataloader=None):
+ if config.metric is None:
+ raise ValueError("config.metric is", config.metric)
+ metric_cls = getattr(metric_path, config.metric)
+ self.metric = metric_cls(config)
+ if config.predictor is None:
+ raise ValueError("config.predictor is", config.predictor)
+ predictor_cls = getattr(predictor_path, config.predictor)
+ self.predictor = predictor_cls(config)
+ self.eval_dataloader = eval_dataloader
+
+ def __call__(self):
+ try:
+ print(self.predictor.pred_dir)
+ for pred_file in glob.glob(
+ self.predictor.pred_dir + "/*_merged.npy"):
+ outputs = np.load(pred_file)
+ results = self.metric.compute_metrics(outputs)
+ self.metric.print_computed_metrics(results)
+
+ outputs = np.load(os.path.join(
+ self.predictor.pred_dir, "merged.npy"))
+ results = self.metric.compute_metrics(outputs)
+ return {"results": results, "metric": self.metric}
+ except FileNotFoundError:
+ print("\n[missing]", self.predictor.pred_dir)
+ return {}
+
+ def evaluate(self, model, eval_dataloader=None, output_file="merged"):
+ if eval_dataloader is None:
+ eval_dataloader = self.eval_dataloader
+ outputs = self.predictor.predict_loop(
+ model, eval_dataloader, output_file)
+ results = self.metric.compute_metrics(**outputs)
+ return results
diff --git a/examples/MMPT/mmpt/evaluators/metric.py b/examples/MMPT/mmpt/evaluators/metric.py
new file mode 100644
index 0000000000..163724bb25
--- /dev/null
+++ b/examples/MMPT/mmpt/evaluators/metric.py
@@ -0,0 +1,313 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import json
+
+
+class Metric(object):
+ def __init__(self, config, metric_names):
+ self.metric_names = metric_names
+
+ def best_metric(self, metric):
+ return metric[self.metric_names[0]]
+
+ def save_metrics(self, fn, metrics):
+ with open(fn, "w") as fw:
+            json.dump(metrics, fw)
+
+ def print_computed_metrics(self, metrics):
+ raise NotImplementedError
+
+
+class RetrievalMetric(Metric):
+ """
+ this is modified from `howto100m/metrics.py`.
+ History of changes:
+ refactor as a class.
+ add metric_key in __init__
+ """
+
+ def __init__(self, config, metric_names=["R1", "R5", "R10", "MR"]):
+ super().__init__(config, metric_names)
+ self.error = False # TODO(huxu): add to config to print error.
+
+ def compute_metrics(self, outputs, texts, **kwargs):
+ x = outputs
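+        # Rank of each ground-truth (diagonal) score: sort the negated
+        # similarity matrix row-wise and locate where the diagonal entry
+        # lands; rank 0 means the matching pair scored highest.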
+ sx = np.sort(-x, axis=1)
+ d = np.diag(-x)
+ d = d[:, np.newaxis]
+ ind = sx - d
+ ind = np.where(ind == 0)
+ ind = ind[1]
+ metrics = {}
+ metrics["R1"] = float(np.sum(ind == 0)) / len(ind)
+ metrics["R5"] = float(np.sum(ind < 5)) / len(ind)
+ metrics["R10"] = float(np.sum(ind < 10)) / len(ind)
+ metrics["MR"] = np.median(ind) + 1
+
+ max_idx = np.argmax(outputs, axis=1)
+ if self.error:
+ # print top-20 errors.
+ error = []
+ for ex_idx in range(20):
+ error.append((texts[ex_idx], texts[max_idx[ex_idx]]))
+ metrics["error"] = error
+ return metrics
+
+ def print_computed_metrics(self, metrics):
+ r1 = metrics["R1"]
+ r5 = metrics["R5"]
+ r10 = metrics["R10"]
+ mr = metrics["MR"]
+ print(
+ "R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}".format(
+ r1, r5, r10, mr
+ )
+ )
+ if "error" in metrics:
+ print(metrics["error"])
+
+
+class DiDeMoMetric(Metric):
+ """
+ History of changes:
+ python 2.x to python 3.x.
+ merge utils.py into eval to save one file.
+ reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+ Code to evaluate your results on the DiDeMo dataset.
+ """
+ def __init__(self, config, metric_names=["rank1", "rank5", "miou"]):
+ super().__init__(config, metric_names)
+
+ def compute_metrics(self, outputs, targets, **kwargs):
+ assert len(outputs) == len(targets)
+ rank1, rank5, miou = self._eval_predictions(outputs, targets)
+ metrics = {
+ "rank1": rank1,
+ "rank5": rank5,
+ "miou": miou
+ }
+ return metrics
+
+ def print_computed_metrics(self, metrics):
+ rank1 = metrics["rank1"]
+ rank5 = metrics["rank5"]
+ miou = metrics["miou"]
+ # print("Average rank@1: %f" % rank1)
+ # print("Average rank@5: %f" % rank5)
+ # print("Average iou: %f" % miou)
+
+ print(
+ "Average rank@1: {:.4f} Average rank@5: {:.4f} Average iou: {:.4f}".format(
+ rank1, rank5, miou
+ )
+ )
+
+ def _iou(self, pred, gt):
+ intersection = max(0, min(pred[1], gt[1]) + 1 - max(pred[0], gt[0]))
+ union = max(pred[1], gt[1]) + 1 - min(pred[0], gt[0])
+ return float(intersection)/union
+
+ def _rank(self, pred, gt):
+ return pred.index(tuple(gt)) + 1
+
+ def _eval_predictions(self, segments, data):
+ '''
+ Inputs:
+ segments: For each item in the ground truth data, rank possible video segments given the description and video.
+            In DiDeMo, there are 21 possible moments extracted for each video so the list of video segments will be of length 21.
+            The first video segment should be the video segment that best corresponds to the text query.
+            There are 4180 sentences in the validation data, so when evaluating a model on the val dataset,
+            segments should be a list of length 4180, and each item in segments should be a list of length 21.
+ data: ground truth data
+ '''
+ average_ranks = []
+ average_iou = []
+ for s, d in zip(segments, data):
+ pred = s[0]
+ ious = [self._iou(pred, t) for t in d['times']]
+ average_iou.append(np.mean(np.sort(ious)[-3:]))
+            ranks = [self._rank(s, t) for t in d['times'] if tuple(t) in s]  # skip ground-truth (s, e) pairs not present in the prediction.
+ average_ranks.append(np.mean(np.sort(ranks)[:3]))
+ rank1 = np.sum(np.array(average_ranks) <= 1)/float(len(average_ranks))
+ rank5 = np.sum(np.array(average_ranks) <= 5)/float(len(average_ranks))
+ miou = np.mean(average_iou)
+
+ # print("Average rank@1: %f" % rank1)
+ # print("Average rank@5: %f" % rank5)
+ # print("Average iou: %f" % miou)
+ return rank1, rank5, miou
+
+
+class NLGMetric(Metric):
+ def __init__(
+ self,
+ config,
+ metric_names=[
+ "Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4",
+ "METEOR", "ROUGE_L", "CIDEr"
+ ]
+ ):
+ super().__init__(config, metric_names)
+ # please install NLGEval from `https://github.com/Maluuba/nlg-eval`
+ from nlgeval import NLGEval
+ self.nlg = NLGEval()
+
+ def compute_metrics(self, outputs, targets, **kwargs):
+ return self.nlg.compute_metrics(
+ hyp_list=outputs, ref_list=targets)
+
+ def print_computed_metrics(self, metrics):
+ Bleu_1 = metrics["Bleu_1"]
+ Bleu_2 = metrics["Bleu_2"]
+ Bleu_3 = metrics["Bleu_3"]
+ Bleu_4 = metrics["Bleu_4"]
+ METEOR = metrics["METEOR"]
+ ROUGE_L = metrics["ROUGE_L"]
+ CIDEr = metrics["CIDEr"]
+
+ print(
+ "Bleu_1: {:.4f} - Bleu_2: {:.4f} - Bleu_3: {:.4f} - Bleu_4: {:.4f} - METEOR: {:.4f} - ROUGE_L: {:.4f} - CIDEr: {:.4f}".format(
+ Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr
+ )
+ )
+
+
+class QAMetric(Metric):
+ def __init__(
+ self,
+ config,
+ metric_names=["acc"]
+ ):
+ super().__init__(config, metric_names)
+
+ def compute_metrics(self, outputs, targets, **kwargs):
+ from sklearn.metrics import accuracy_score
+ return {"acc": accuracy_score(targets, outputs)}
+
+ def print_computed_metrics(self, metrics):
+ print("acc: {:.4f}".format(metrics["acc"]))
+
+
+class COINActionSegmentationMetric(Metric):
+ """
+ COIN dataset listed 3 repos for Action Segmentation.
+ Action Sets, NeuralNetwork-Viterbi, TCFPN-ISBA.
+ The first and second are the same.
+ https://github.com/alexanderrichard/action-sets/blob/master/eval.py
+
+ Future reference for the third:
+ `https://github.com/Zephyr-D/TCFPN-ISBA/blob/master/utils/metrics.py`
+ """
+ def __init__(self, config, metric_name=["frame_acc"]):
+ super().__init__(config, metric_name)
+
+    def compute_metrics(self, outputs, targets):
+        n_errors = sum(outputs != targets)
+        n_frames = len(targets)
+        return {"frame_acc": 1.0 - float(n_errors) / n_frames}
+
+ def print_computed_metrics(self, metrics):
+ fa = metrics["frame_acc"]
+ print("frame accuracy:", fa)
+
+
+class CrossTaskMetric(Metric):
+ def __init__(self, config, metric_names=["recall"]):
+ super().__init__(config, metric_names)
+
+ def compute_metrics(self, outputs, targets, **kwargs):
+ """refactored from line 166:
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
+
+ recalls = self._get_recalls(Y_true=targets, Y_pred=outputs)
+ results = {}
+ for task, rec in recalls.items():
+ results[str(task)] = rec
+
+ avg_recall = np.mean(list(recalls.values()))
+ results["recall"] = avg_recall
+ return results
+
+ def print_computed_metrics(self, metrics):
+ print('Recall: {0:0.3f}'.format(metrics["recall"]))
+ for task in metrics:
+ if task != "recall":
+ print('Task {0}. Recall = {1:0.3f}'.format(
+ task, metrics[task]))
+
+ def _get_recalls(self, Y_true, Y_pred):
+ """refactored from
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
+
+ step_match = {task: 0 for task in Y_true.keys()}
+ step_total = {task: 0 for task in Y_true.keys()}
+ for task, ys_true in Y_true.items():
+ ys_pred = Y_pred[task]
+ for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())):
+ y_true = ys_true[vid]
+ y_pred = ys_pred[vid]
+ step_total[task] += (y_true.sum(axis=0) > 0).sum()
+ step_match[task] += (y_true*y_pred).sum()
+ recalls = {
+ task: step_match[task] / n for task, n in step_total.items()}
+ return recalls
+
+
+class ActionRecognitionMetric(Metric):
+ def __init__(
+ self,
+ config,
+ metric_names=["acc", "acc_splits", "r1_splits", "r5_splits", "r10_splits"]
+ ):
+ super().__init__(config, metric_names)
+
+ def compute_metrics(self, outputs, targets, splits, **kwargs):
+ all_video_embd = outputs
+ labels = targets
+ split1, split2, split3 = splits
+ accs = []
+ r1s = []
+ r5s = []
+ r10s = []
+ for split in range(3):
+ if split == 0:
+ s = split1
+ elif split == 1:
+ s = split2
+ else:
+ s = split3
+
+ X_pred = all_video_embd[np.where(s == 2)[0]]
+ label_test = labels[np.where(s == 2)[0]]
+ logits = X_pred
+ X_pred = np.argmax(X_pred, axis=1)
+ acc = np.sum(X_pred == label_test) / float(len(X_pred))
+ accs.append(acc)
+ # compute recall.
+ sorted_pred = (-logits).argsort(axis=-1)
+ label_test_sp = label_test.reshape(-1, 1)
+
+ r1 = np.mean((sorted_pred[:, :1] == label_test_sp).sum(axis=1), axis=0)
+ r5 = np.mean((sorted_pred[:, :5] == label_test_sp).sum(axis=1), axis=0)
+ r10 = np.mean((sorted_pred[:, :10] == label_test_sp).sum(axis=1), axis=0)
+ r1s.append(r1)
+ r5s.append(r5)
+ r10s.append(r10)
+
+ return {"acc": accs[0], "acc_splits": accs, "r1_splits": r1s, "r5_splits": r5s, "r10_splits": r10s}
+
+ def print_computed_metrics(self, metrics):
+ for split, acc in enumerate(metrics["acc_splits"]):
+ print("Top 1 accuracy on split {}: {}; r1 {}; r5 {}; r10 {}".format(
+ split + 1, acc,
+ metrics["r1_splits"][split],
+ metrics["r5_splits"][split],
+ metrics["r10_splits"][split],
+ )
+ )
diff --git a/examples/MMPT/mmpt/evaluators/predictor.py b/examples/MMPT/mmpt/evaluators/predictor.py
new file mode 100644
index 0000000000..2ffef6ab47
--- /dev/null
+++ b/examples/MMPT/mmpt/evaluators/predictor.py
@@ -0,0 +1,595 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import random
+import json
+import numpy as np
+import torch
+import pickle
+import math
+
+from tqdm import tqdm
+
+
+class Predictor(object):
+ """this base class is used to save predictions to disk
+ (and being called by a evaluator later).
+ Predictor has minimum support of single gpu prediction.
+ """
+ def __init__(self, config):
+ self.pred_dir = None # on-the-fly eval does not save the results.
+ if hasattr(config, "eval") and config.eval is not None:
+ self.pred_dir = config.eval.save_path
+ os.makedirs(self.pred_dir, exist_ok=True)
+
+ def __call__(self, outputs):
+ """extract the prediction and save it."""
+ raise NotImplementedError
+
+ def predict_loop(self, model, eval_dataloader, output_file=None):
+ """on-the-fly prediction on a single gpu."""
+ self.full_scores = []
+ model.eval()
+ model = model.to(0)
+ with torch.no_grad():
+ for data in eval_dataloader:
+ data = self.to_ctx(data)
+ outputs = model(**data)
+ outputs.update(data)
+ self(outputs)
+ return self.finalize(output_file)
+
+ def finalize(self, output_file):
+ pass
+
+ def to_ctx(self, data, ctx=0, dtype=None):
+ if isinstance(data, dict):
+ for key in data:
+ if torch.is_tensor(data[key]):
+ if dtype is not None and data[key].dtype == torch.float32:
+ data[key] = data[key].to(dtype)
+ data[key] = data[key].to(ctx)
+ return data
+ else:
+ raise ValueError("non-dict type of batch is not supported yet.")
+
+
+class NLGPredictor(Predictor):
+ """Predicting Text from MMFusion models."""
+ """TODO: make a context."""
+ def __init__(self, config):
+ super().__init__(config)
+ from transformers import AutoTokenizer
+
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ config.dataset.bert_name,
+ bos_token="[CLS]", eos_token="[SEP]")
+ self.bos_token_id = self.tokenizer.bos_token_id
+ self.eos_token_id = self.tokenizer.eos_token_id
+
+ def predict_loop(self, model, eval_dataloader, output_file=None):
+ """TODO: refactor base classes."""
+ ctx = 0
+ outputs = {"outputs": [], "targets": [[]]}
+ model.eval()
+ model = model.to(ctx)
+ with torch.no_grad():
+ for data in tqdm(eval_dataloader):
+ data = self.to_ctx(data, ctx)
+ self(data, model, outputs)
+ return self.finalize(outputs, output_file)
+
+ def __call__(self, data, model, outputs):
+ data.update({
+ "bos_token_id": self.bos_token_id,
+ "eos_token_id": self.eos_token_id
+ })
+
+ output = model.generate(**data)
+ assert len(output) == len(data["ref"])
+ for idx, _output in enumerate(output):
+ generated_text = self.tokenizer.decode(
+ _output, skip_special_tokens=True)
+ if generated_text == "":
+ generated_text = "none"
+ outputs["outputs"].append(generated_text)
+ outputs["targets"][0].append(data["ref"][idx])
+ if random.random() < 0.001:
+ print("_output", _output)
+ print("generated_text", generated_text)
+ print("ref", data["ref"][idx])
+
+ def finalize(self, outputs, output_file=None):
+ if output_file is not None:
+ with open(os.path.join(
+ self.pred_dir, output_file + ".json"), "w") as fw:
+ json.dump(outputs, fw, indent=4)
+ return outputs
+
+
+class RetrievalPredictor(Predictor):
+ """generated `pooled_video` and `pooled_text`."""
+ def __init__(self, config):
+ super().__init__(config)
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ config.dataset.bert_name)
+
+ def predict_loop(
+ self,
+ model,
+ eval_dataloader,
+ output_file="retrieval.npy"
+ ):
+ """on-the-fly prediction on a single gpu."""
+ full_scores = []
+ texts = []
+ model.eval()
+ model = model.cuda()
+ with torch.no_grad():
+ for data in eval_dataloader:
+ # convert to dict.
+ if not isinstance(data, dict):
+ data = {
+ "caps": data[0],
+ "cmasks": data[1],
+ "vfeats": data[2],
+ "vmasks": data[3],
+ "video_id": data[4]
+ }
+ data = self.to_ctx(data)
+ outputs = model(**data)
+ outputs.update(data)
+ self(outputs, full_scores)
+ for _cap in data["caps"]:
+ texts.append(
+ self.tokenizer.decode(_cap, skip_special_tokens=True)
+ )
+
+ return self.finalize(full_scores, texts, output_file)
+
+ def __call__(self, sample, full_scores):
+ scores = self._get_pooled_outputs(sample)
+ self._append_scores(scores, full_scores)
+
+ def finalize(self, full_scores, texts, output_file=None):
+ outputs = self._aggregate_scores(full_scores)
+ if output_file is not None:
+ np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
+ return {"outputs": outputs, "texts": texts}
+
+ def _get_pooled_outputs(self, outputs):
+ if "pooled_video" in outputs:
+ return outputs["pooled_video"], outputs["pooled_text"]
+ else:
+ raise ValueError("unknown format of outputs.")
+
+ def _append_scores(self, scores, full_scores):
+ assert len(scores) == 2
+ if len(full_scores) == 0:
+ full_scores.append([])
+ full_scores.append([])
+ full_scores[0].append(scores[0].cpu().detach().numpy())
+ full_scores[1].append(scores[1].cpu().detach().numpy())
+
+ def _aggregate_scores(self, scores):
+ assert len(scores) == 2
+ video_hidden = np.concatenate(scores[0], axis=0)
+ text_hidden = np.concatenate(scores[1], axis=0)
+ # clear up.
+ self.full_scores = []
+ return np.matmul(text_hidden, video_hidden.T)
+
+
+class QAPredictor(Predictor):
+ """generated `pooled_video` and `pooled_text`."""
+ def __init__(self, config):
+ super().__init__(config)
+ """predictor maintains scores and aggregate them."""
+
+ def predict_loop(self, model, eval_dataloader, output_file="qa.npy"):
+ """on-the-fly prediction on a single gpu."""
+ self.full_scores = []
+ model.eval()
+ model = model.cuda()
+ with torch.no_grad():
+ for data in eval_dataloader:
+ # reshape ans and dup video 5 times.
+ v_len = data["vfeats"].size(1)
+ hidden_size = data["vfeats"].size(2)
+ data["vfeats"] = data["vfeats"].unsqueeze(1).repeat(1, 5, 1, 1).view(-1, v_len, hidden_size)
+ data["vmasks"] = data["vmasks"].unsqueeze(1).repeat(1, 5, 1).view(-1, v_len)
+
+ t_len = data["caps"].size(-1)
+ data["caps"] = data["caps"].view(-1, t_len)
+ data["cmasks"] = data["cmasks"].view(-1, t_len)
+
+ data = self.to_ctx(data)
+ outputs = model(**data)
+ outputs.update(data)
+ self(outputs)
+ return self.finalize(output_file)
+
+ def __call__(self, sample):
+ hidden_size = sample["pooled_video"].size(-1)
+ pooled_video = sample["pooled_video"].view(-1, 5, hidden_size)
+ pooled_text = sample["pooled_text"].view(-1, 5, hidden_size)
+ scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1))
+ scores = scores.argmax(-1)
+ self._append_scores(scores[:, 0], sample["answers"], self.full_scores)
+
+ def finalize(self, output_file=None):
+ outputs, targets = self._aggregate_scores(self.full_scores)
+ if output_file is not None:
+ np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
+ return {"outputs": outputs, "targets": targets}
+
+ def _append_scores(self, scores, answers, full_scores):
+ if len(full_scores) == 0:
+ full_scores.append([])
+ full_scores.append([])
+ full_scores[0].append(scores.cpu().detach().numpy())
+ full_scores[1].append(answers.cpu().detach().numpy())
+
+ def _aggregate_scores(self, scores):
+ assert len(scores) == 2
+ outputs = np.concatenate(scores[0], axis=0)
+ targets = np.concatenate(scores[1], axis=0)
+ # clear up.
+ self.full_scores = []
+ return outputs, targets
+
+
+class CrossTaskPredictor(Predictor):
+ """
+    CrossTaskPredictor computes the average of logits
+    over overlapping sliding windows.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ self.lsm = torch.nn.LogSoftmax(dim=1)
+ self.max_video_len = config.dataset.max_video_len
+ self.sliding_window = config.dataset.sliding_window
+ self.sliding_window_size = config.dataset.sliding_window_size
+ self.annotation_path = config.dataset.annotation_path
+
+ def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
+ """refactored from line 144:
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py
+ """
+ ctx = 0
+ model.eval()
+ model = model.to(ctx)
+        # this is not a loss; it just computes the negative log-probability.
+ Y_pred = {}
+ Y_true = {}
+ with torch.no_grad():
+ for batch in eval_dataloader:
+ self(batch, model, Y_pred, Y_true)
+ return self.finalize(Y_pred, Y_true, output_file)
+
+ def __call__(self, sample, model, Y_pred, Y_true):
+ # please install dp from `https://github.com/DmZhukov/CrossTask`
+ from dp import dp
+ vid, task = sample['video_id'][0], sample['task'][0]
+ sample = self.to_ctx(sample)
+ # compute the average logits over sliding windows.
+ output = model(**sample)
+ batch_logits = output["logits"].cpu()
+
+ video_len = sample["video_len"][0]
+
+ # the following version is slow.
+ logits = torch.zeros((video_len, batch_logits.size(1)))
+ logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
+ # use the same loop as aligner to recover.
+ batch_logit_idx = 0
+ for window_start in range(0, video_len, self.sliding_window):
+ video_end = min(video_len - window_start, self.sliding_window_size)
+ logits[window_start: window_start + video_end] += batch_logits[
+ batch_logit_idx: batch_logit_idx + video_end]
+ batch_logit_idx += video_end
+ logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
+
+ if (video_len - window_start) <= self.sliding_window_size:
+ break
+
+ logits /= logits_counts
+ assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
+
+ O = self.lsm(logits)
+ y = np.zeros(O.size(), dtype=np.float32)
+ dp(y, -O.detach().cpu().numpy())
+ if task not in Y_pred:
+ Y_pred[task] = {}
+ Y_pred[task][vid] = y
+ annot_path = os.path.join(
+ self.annotation_path, task+'_'+vid+'.csv')
+ if os.path.exists(annot_path):
+ if task not in Y_true:
+ Y_true[task] = {}
+ Y_true[task][vid] = self._read_assignment(
+ *y.shape, annot_path)
+
+ def finalize(self, Y_pred, Y_true, output_file=None):
+ if output_file is not None:
+ with open(
+ os.path.join(self.pred_dir, output_file + ".pkl"),
+ "wb") as fw:
+ pickle.dump(
+ {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
+ protocol=pickle.HIGHEST_PROTOCOL)
+ return {"outputs": Y_pred, "targets": Y_true}
+
+ def _read_assignment(self, T, K, path):
+ """
+ refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
+        How to interpret the constraints on the loss being minimized:
+        lambd is a big number;
+        self.lambd * C is a big number for all valid positions (the csv stores the invalid ones)
+
+ def forward(self, O, Y, C):
+ return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
+
+        This loads the csv file and fills in the step column from the start row to the end row.
+ """
+
+ Y = np.zeros([T, K], dtype=np.uint8)
+ with open(path, 'r') as f:
+ for line in f:
+ step, start, end = line.strip().split(',')
+ start = int(math.floor(float(start)))
+ end = int(math.ceil(float(end)))
+ step = int(step) - 1
+ Y[start:end, step] = 1
+ return Y
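+    # Illustrative csv row: "3,12.0,15.6" marks step 3 (stored 1-indexed) as
+    # active for frames 12..15, i.e. Y[12:16, 2] = 1 after flooring the start
+    # and ceiling the end time.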
+
+
+class COINPredictor(Predictor):
+ """
+ COINPredictor is similar to CrossTask on sliding windows.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ self.max_video_len = config.dataset.max_video_len
+ self.sliding_window = config.dataset.sliding_window
+ self.sliding_window_size = config.dataset.sliding_window_size
+
+ def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
+ """refactored from line 144:
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py
+ """
+ ctx = 0
+ model.eval()
+ model = model.to(ctx)
+        # this is not a loss; it just computes the negative log-probability.
+ Y_pred = []
+ Y_true = []
+ with torch.no_grad():
+ for batch in eval_dataloader:
+ self(batch, model, Y_pred, Y_true)
+ return self.finalize(Y_pred, Y_true, output_file)
+
+ def __call__(self, sample, model, Y_pred, Y_true):
+ sample = self.to_ctx(sample)
+ # compute the average logits over sliding windows.
+ output = model(**sample)
+ logits = self._merge_windows(sample, output)
+ Y_pred.append(logits.argmax(dim=1))
+ Y_true.append(sample["video_targets"].squeeze(0).cpu())
+
+ def _merge_windows(self, sample, output):
+ targets = sample["targets"].reshape(-1).cpu()
+ valid_mask = targets != -100
+ targets = targets[valid_mask]
+ batch_logits = output["logits"].cpu()
+ batch_logits = batch_logits.reshape(-1, batch_logits.size(-1))
+ batch_logits = batch_logits[valid_mask]
+
+ video_len = sample["video_len"][0]
+
+ # the following version is slow.
+ logits = torch.zeros((video_len, batch_logits.size(1)))
+ logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
+ # use the same loop as aligner to recover.
+ batch_logit_idx = 0
+ for window_start in range(0, video_len, self.sliding_window):
+ video_end = min(video_len - window_start, self.sliding_window_size)
+ logits[window_start: window_start + video_end] += batch_logits[
+ batch_logit_idx: batch_logit_idx + video_end]
+ batch_logit_idx += video_end
+ logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
+ if (video_len - window_start) <= self.sliding_window_size:
+ break
+ logits /= logits_counts
+ assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
+ return logits
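+    # Illustrative example of the averaging above: with video_len=48,
+    # sliding_window=16 and sliding_window_size=32, the loop visits
+    # window_start=0 (frames 0..31) and window_start=16 (frames 16..47), so
+    # frames 16..31 are covered twice and their logits are averaged via
+    # logits_counts.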
+
+ def finalize(self, Y_pred, Y_true, output_file=None):
+ Y_pred = torch.cat(Y_pred, dim=0).numpy()
+ Y_true = torch.cat(Y_true, dim=0).numpy()
+ assert len(Y_pred) == len(Y_true)
+
+ error_mask = Y_pred != Y_true
+ print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
+ print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
+
+ if output_file is not None:
+ with open(
+ os.path.join(self.pred_dir, output_file + ".pkl"),
+ "wb") as fw:
+ pickle.dump(
+ {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
+ protocol=pickle.HIGHEST_PROTOCOL)
+ return {"outputs": Y_pred, "targets": Y_true}
+
+
+class COINZSPredictor(COINPredictor):
+ """
+ COINZSPredictor for COIN zero-shot prediction.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.dataset_config = config.dataset
+
+ def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
+ """refactored from line 144:
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py
+ """
+ ctx = 0
+ model.eval()
+ model = model.to(ctx)
+
+ with torch.no_grad():
+ outputs = eval_dataloader.dataset.meta_processor.meta_text_labels(
+ self.dataset_config)
+ outputs = self.to_ctx(outputs, ctx)
+ label_hidden_states = model.forward_text(**outputs).cpu()
+ label_sim = label_hidden_states @ label_hidden_states.t()
+ num_labels = label_sim.size(0)
+ eye_mask = ~torch.eye(num_labels, dtype=torch.bool)
+ label_sim = label_sim.masked_select(eye_mask).view(num_labels, num_labels - 1)
+ lbd = label_sim.max()
+
+        # this is not a loss; it just computes the negative log-probability.
+ Y_pred = []
+ Y_true = []
+ with torch.no_grad():
+ for batch in eval_dataloader:
+ self(batch, label_hidden_states, model, lbd, Y_pred, Y_true)
+ return self.finalize(Y_pred, Y_true, output_file)
+
+ def reshape_subsample(self, sample):
+ for key in sample:
+ if torch.is_tensor(sample[key]):
+ sample[key] = self.flat_subsample(sample[key])
+ return sample
+
+ def flat_subsample(self, tensor):
+ if len(tensor.size()) > 1 and tensor.size(0) == 1:
+ tensor = tensor.squeeze(0)
+ return tensor
+
+ def __call__(self, sample, label_hidden_states, model, lbd, Y_pred, Y_true):
+ sample = self.reshape_subsample(sample)
+ sample = self.to_ctx(sample)
+ # compute the average logits over sliding windows.
+ sample["output_hidden_states"] = True
+ video_outputs = model.forward_video(**sample).cpu()
+ output = {"logits": video_outputs[:, 1:sample["vmasks"].size(1)+1] @ label_hidden_states.t()}
+ logits = self._merge_windows(sample, output)
+ # logic of zero-shot for sequence labeling.
+ logits_argmax = logits.argmax(dim=1) + 1 # 0 is "O" label.
+ logits_max = logits.max(dim=1)[0]
+
+ pred = torch.zeros_like(logits_argmax)
+ label_select = logits_max > lbd # 73 or 74
+ pred[label_select] = logits_argmax[label_select]
+
+ Y_pred.append(pred)
+ Y_true.append(sample["video_targets"].squeeze(0).cpu())
+
+ def finalize(self, Y_pred, Y_true, output_file=None):
+ Y_pred = torch.cat(Y_pred, dim=0).numpy()
+ Y_true = torch.cat(Y_true, dim=0).numpy()
+ assert len(Y_pred) == len(Y_true)
+
+ error_mask = Y_pred != Y_true
+ print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
+ print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
+
+ if output_file is not None:
+ with open(
+ os.path.join(self.pred_dir, output_file + ".pkl"),
+ "wb") as fw:
+ pickle.dump(
+ {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
+ protocol=pickle.HIGHEST_PROTOCOL)
+ return {"outputs": Y_pred, "targets": Y_true}
+
+
+class DiDeMoPredictor(Predictor):
+ """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+ https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ # load targets.
+ with open(config.dataset.test_path) as data_file:
+ self.test_data = json.load(data_file)
+
+ def predict_loop(self, model, eval_dataloader, output_file="didemo.npy"):
+ """
+ TODO: two solutions here.
+ """
+ import itertools
+ # 21 chunks.
+ self.possible_segments = [(0,0), (1,1), (2,2), (3,3), (4,4), (5,5)]
+ for i in itertools.combinations(range(6), 2):
+ self.possible_segments.append(i)
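+        # 6 single-chunk segments plus C(6, 2) = 15 (start, end) pairs with
+        # start < end give the 21 candidate segments mentioned above.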
+ # pick segments from a video.
+
+ """on-the-fly prediction on a single gpu."""
+ self.full_scores = []
+ model.eval()
+ model = model.cuda()
+ with torch.no_grad():
+ for data in eval_dataloader:
+ # TODO special forwarding logic here.
+ data = self.to_ctx(data)
+ data["output_hidden_states"] = True
+ hidden_video = model.forward_video(**data)
+ data["output_hidden_states"] = False
+ pooled_text = model.forward_text(**data)
+ outputs = {
+ "hidden_video": hidden_video,
+ "pooled_text": pooled_text
+ }
+ outputs.update(data)
+ self(outputs)
+ return self.finalize(output_file)
+
+ def __call__(self, sample):
+ # TODO: make an index select from self.possible_segments.
+ hidden_video = sample["hidden_video"]
+ pooled_text = sample["pooled_text"]
+ vmasks = sample["vmasks"]
+ # probably maintain valid results here.
+
+ hidden_video = hidden_video[:, 1:-1, :]
+ # probably maintain valid results here.
+ pooled_video = []
+ for s, e in self.possible_segments:
+ pooled_video.append(
+ torch.mean(
+ hidden_video[:, int(s*5):int((e+1)*5), :],
+ dim=1, keepdim=True)
+ )
+ pooled_video = torch.cat(pooled_video, dim=1)
+ scores = torch.bmm(
+ pooled_video, pooled_text.unsqueeze(-1)).squeeze(-1).cpu()
+
+ ranks = scores.argsort(dim=-1, descending=True)
+
+ for batch_idx, rank in enumerate(ranks):
+ rank_of_moment = []
+ for m_idx, moment in enumerate(rank):
+ s, e = self.possible_segments[moment.item()]
+ if torch.any(
+ vmasks[batch_idx, int(s*5):int((e+1)*5)]
+ ):
+ rank_of_moment.append((s, e))
+ self.full_scores.append(rank_of_moment)
+
+ def finalize(self, output_file=None):
+ outputs = self._aggregate_scores(self.full_scores)
+ if output_file is not None:
+ np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
+ return {"outputs": outputs, "targets": self.test_data}
+
+ def _aggregate_scores(self, scores):
+ self.full_scores = []
+ return scores
diff --git a/examples/MMPT/mmpt/losses/__init__.py b/examples/MMPT/mmpt/losses/__init__.py
new file mode 100644
index 0000000000..8dc32c96d2
--- /dev/null
+++ b/examples/MMPT/mmpt/losses/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .loss import *
+from .nce import *
+
+try:
+ from .fairseqmmloss import *
+except ImportError:
+ pass
+
+try:
+ from .expnce import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/losses/fairseqmmloss.py b/examples/MMPT/mmpt/losses/fairseqmmloss.py
new file mode 100644
index 0000000000..a95e5ecf45
--- /dev/null
+++ b/examples/MMPT/mmpt/losses/fairseqmmloss.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+TODO (huxu): a general fairseq criterion for all your pre-defined losses.
+"""
+
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.logging import metrics
+
+
+@register_criterion("mmloss")
+class MMCriterion(FairseqCriterion):
+ def __init__(self, task):
+ super().__init__(task)
+ # TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
+ self.mmtask = task.mmtask
+
+ def forward(self, model, sample):
+ """Compute the loss for the given sample.
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ outputs = self.mmtask(model, sample)
+
+ loss, loss_scalar, max_len, batch_size, sample_size = (
+ outputs["loss"],
+ outputs["loss_scalar"],
+ outputs["max_len"],
+ outputs["batch_size"],
+ outputs["sample_size"],
+ )
+
+ logging_output = {
+ "loss": loss_scalar,
+ "ntokens": max_len * batch_size, # dummy report.
+ "nsentences": batch_size, # dummy report.
+ "sample_size": sample_size,
+ }
+
+ return loss, 1, logging_output
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+ """since we use NCE, our actual batch_size is 1 per GPU.
+ Then we take the mean of each worker."""
+ loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+ metrics.log_scalar("loss", loss_sum / sample_size, round=3)
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return True
diff --git a/examples/MMPT/mmpt/losses/loss.py b/examples/MMPT/mmpt/losses/loss.py
new file mode 100644
index 0000000000..99c05d067e
--- /dev/null
+++ b/examples/MMPT/mmpt/losses/loss.py
@@ -0,0 +1,87 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch
+
+from torch import nn
+
+
+class Loss(object):
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+# Dummy Loss for testing.
+class DummyLoss(Loss):
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(logits, targets)
+
+
+class DummyK400Loss(Loss):
+ """dummy k400 loss for MViT."""
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(
+ logits, torch.randint(0, 400, (logits.size(0),), device=logits.device))
+
+
+class CrossEntropy(Loss):
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
+
+
+class ArgmaxCrossEntropy(Loss):
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(logits, targets.argmax(dim=1))
+
+
+class BCE(Loss):
+ def __init__(self):
+ self.loss = nn.BCEWithLogitsLoss()
+
+ def __call__(self, logits, targets, **kwargs):
+ targets = targets.squeeze(0)
+ return self.loss(logits, targets)
+
+
+class NLGLoss(Loss):
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, logits, text_label, **kwargs):
+ targets = text_label[text_label != -100]
+ return self.loss(logits, targets)
+
+
+class MSE(Loss):
+ def __init__(self):
+ self.loss = nn.MSELoss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(logits, targets)
+
+
+class L1(Loss):
+ def __init__(self):
+ self.loss = nn.L1Loss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(logits, targets)
+
+
+class SmoothL1(Loss):
+ def __init__(self):
+ self.loss = nn.SmoothL1Loss()
+
+ def __call__(self, logits, targets, **kwargs):
+ return self.loss(logits, targets)
diff --git a/examples/MMPT/mmpt/losses/nce.py b/examples/MMPT/mmpt/losses/nce.py
new file mode 100644
index 0000000000..ed7be8d372
--- /dev/null
+++ b/examples/MMPT/mmpt/losses/nce.py
@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+softmax-based NCE loss, used by this project.
+"""
+
+import torch
+
+from torch import nn
+
+from .loss import Loss
+
+
+class NCE(Loss):
+ def __init__(self):
+ # TODO (huxu): define temperature.
+ self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, align_scores, **kwargs):
+        # note: we reuse the same shape as the cls head in BERT (batch_size, 2)
+        # but NCE only needs one logit
+        # (so we drop all weights for the second, negative logit).
+ align_scores = align_scores[:, :1]
+ # duplicate negative examples
+ batch_size = align_scores.size(0) // 2
+ pos_scores = align_scores[:batch_size]
+ neg_scores = align_scores[batch_size:].view(1, batch_size).repeat(
+ batch_size, 1)
+ scores = torch.cat([pos_scores, neg_scores], dim=1)
+ return self.loss(
+ scores,
+ torch.zeros(
+ (batch_size,),
+ dtype=torch.long,
+ device=align_scores.device),
+ )
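+    # Shape note: `scores` is (batch_size, 1 + batch_size) with the positive
+    # logit in column 0 and the repeated negative logits after it, hence the
+    # all-zero targets.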
+
+
+class T2VContraLoss(Loss):
+ """NCE for MM joint space, on softmax text2video matrix.
+ """
+ def __init__(self):
+ # TODO (huxu): define temperature.
+ self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, pooled_video, pooled_text, **kwargs):
+ batch_size = pooled_video.size(0)
+ logits = torch.mm(pooled_text, pooled_video.transpose(1, 0))
+ targets = torch.arange(
+ batch_size,
+ dtype=torch.long,
+ device=pooled_video.device)
+ return self.loss(logits, targets)
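+    # Here logits[i][j] is the similarity between text i and video j, so the
+    # diagonal entries (targets = arange) are the positive pairs.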
+
+
+class V2TContraLoss(Loss):
+ """NCE for MM joint space, with softmax on video2text matrix."""
+
+ def __init__(self):
+ # TODO (huxu): define temperature.
+ self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, pooled_video, pooled_text, **kwargs):
+ batch_size = pooled_video.size(0)
+ logits = torch.mm(pooled_video, pooled_text.transpose(1, 0))
+ targets = torch.arange(
+ batch_size,
+ dtype=torch.long,
+ device=pooled_video.device)
+ return self.loss(logits, targets)
+
+
+class MMContraLoss(Loss):
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, pooled_video, pooled_text, **kwargs):
+ logits_per_video = pooled_video @ pooled_text.t()
+ logits_per_text = pooled_text @ pooled_video.t()
+
+ targets = torch.arange(
+ pooled_video.size(0),
+ dtype=torch.long,
+ device=pooled_video.device)
+ loss_video = self.loss(logits_per_video, targets)
+ loss_text = self.loss(logits_per_text, targets)
+ return loss_video + loss_text
+
+
+class MTM(Loss):
+ """Combination of MFM and MLM."""
+
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(
+ self,
+ video_logits,
+ text_logits,
+ video_label,
+ text_label,
+ **kwargs
+ ):
+ text_logits = torch.cat([
+ text_logits,
+ torch.zeros(
+ (text_logits.size(0), 1), device=text_logits.device)
+ ], dim=1)
+ vt_logits = torch.cat([video_logits, text_logits], dim=0)
+ # loss for video.
+ video_label = torch.zeros(
+ (video_logits.size(0),),
+ dtype=torch.long,
+ device=video_logits.device
+ )
+
+ # loss for text.
+ text_label = text_label.reshape(-1)
+ labels_mask = text_label != -100
+ selected_text_label = text_label[labels_mask]
+
+ vt_label = torch.cat([video_label, selected_text_label], dim=0)
+ return self.loss(vt_logits, vt_label)
+
+
+class MFMMLM(Loss):
+ """Combination of MFM and MLM."""
+
+ def __init__(self):
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(
+ self,
+ video_logits,
+ text_logits,
+ video_label,
+ text_label,
+ **kwargs
+ ):
+ # loss for video.
+ video_label = torch.zeros(
+ (video_logits.size(0),),
+ dtype=torch.long,
+ device=video_logits.device
+ )
+ masked_frame_loss = self.loss(video_logits, video_label)
+
+ # loss for text.
+ text_label = text_label.reshape(-1)
+ labels_mask = text_label != -100
+ selected_text_label = text_label[labels_mask]
+ masked_lm_loss = self.loss(text_logits, selected_text_label)
+ return masked_frame_loss + masked_lm_loss
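+    # Note (assumption): the all-zero `video_label` above presumes the positive
+    # (original) frame feature is scored at index 0 of each row of
+    # `video_logits`, which depends on how the MFM head in transformermodel.py
+    # orders its candidates.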
diff --git a/examples/MMPT/mmpt/models/__init__.py b/examples/MMPT/mmpt/models/__init__.py
new file mode 100644
index 0000000000..825250cd00
--- /dev/null
+++ b/examples/MMPT/mmpt/models/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .mmfusion import *
+from .transformermodel import *
+from .mmfusionnlg import *
+
+try:
+ from .fairseqmmmodel import *
+except ImportError:
+ pass
+
+try:
+ from .expmmfusion import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/models/fairseqmmmodel.py b/examples/MMPT/mmpt/models/fairseqmmmodel.py
new file mode 100644
index 0000000000..b7dd643693
--- /dev/null
+++ b/examples/MMPT/mmpt/models/fairseqmmmodel.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.models import (
+ BaseFairseqModel,
+ register_model,
+ register_model_architecture
+)
+
+
+@register_model("mmmodel")
+class FairseqMMModel(BaseFairseqModel):
+ """a fairseq wrapper of model built by `task`."""
+
+ @classmethod
+ def build_model(cls, args, task):
+ return FairseqMMModel(task.mmtask.model)
+
+ def __init__(self, mmmodel):
+ super().__init__()
+ self.mmmodel = mmmodel
+
+ def forward(self, *args, **kwargs):
+ return self.mmmodel(*args, **kwargs)
+
+ def upgrade_state_dict_named(self, state_dict, name):
+
+ super().upgrade_state_dict_named(state_dict, name)
+
+ keys_to_delete = []
+
+ for key in state_dict:
+ if key not in self.state_dict():
+ keys_to_delete.append(key)
+ for key in keys_to_delete:
+ print("[INFO]", key, "not used anymore.")
+ del state_dict[key]
+
+ # copy any newly defined parameters.
+ for key in self.state_dict():
+ if key not in state_dict:
+ print("[INFO] adding", key)
+ state_dict[key] = self.state_dict()[key]
+
+
+# a dummy arch; the model itself is configured by the task config.
+@register_model_architecture("mmmodel", "mmarch")
+def mmarch(args):
+ pass
diff --git a/examples/MMPT/mmpt/models/mmfusion.py b/examples/MMPT/mmpt/models/mmfusion.py
new file mode 100644
index 0000000000..2509e26b67
--- /dev/null
+++ b/examples/MMPT/mmpt/models/mmfusion.py
@@ -0,0 +1,926 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+
+from torch import nn
+
+try:
+ from transformers import AutoConfig, AutoTokenizer
+except ImportError:
+ pass
+
+from . import transformermodel
+
+
+class MMPTModel(nn.Module):
+ """An e2e wrapper of inference model.
+ """
+ @classmethod
+ def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
+ import os
+ from ..utils import recursive_config
+ from ..tasks import Task
+ config = recursive_config(config)
+ mmtask = Task.config_task(config)
+ checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
+ mmtask.build_model(checkpoint=checkpoint_path)
+ # TODO(huxu): make the video encoder configurable.
+ from ..processors.models.s3dg import S3D
+ video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
+ video_encoder.load_state_dict(
+ torch.load('pretrained_models/s3d_howto100m.pth'))
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ config.dataset.bert_name, use_fast=config.dataset.use_fast
+ )
+ from ..processors import Aligner
+ aligner = Aligner(config.dataset)
+ return (
+ MMPTModel(config, mmtask.model, video_encoder),
+ tokenizer,
+ aligner
+ )
+
+ def __init__(self, config, model, video_encoder, **kwargs):
+ super().__init__()
+ self.max_video_len = config.dataset.max_video_len
+ self.video_encoder = video_encoder
+ self.model = model
+
+ def forward(self, video_frames, caps, cmasks, return_score=False):
+ bsz = video_frames.size(0)
+ assert bsz == 1, "only bsz=1 is supported now."
+ seq_len = video_frames.size(1)
+ video_frames = video_frames.view(-1, *video_frames.size()[2:])
+ vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
+ vfeats = vfeats['video_embedding']
+ vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
+ padding = torch.zeros(
+ bsz, self.max_video_len - seq_len, vfeats.size(-1))
+ vfeats = torch.cat([vfeats, padding], dim=1)
+ vmasks = torch.cat([
+ torch.ones((bsz, seq_len), dtype=torch.bool),
+ torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool)
+ ],
+ dim=1
+ )
+ output = self.model(caps, cmasks, vfeats, vmasks)
+ if return_score:
+ output = {"score": torch.bmm(
+ output["pooled_video"][:, None, :],
+ output["pooled_text"][:, :, None]
+ ).squeeze(-1).squeeze(-1)}
+ return output
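+
+# Illustrative usage (the config path below is an assumption; see the MMPT
+# README for the real project configs):
+#     model, tokenizer, aligner = MMPTModel.from_pretrained(
+#         "projects/retri/videoclip/how2.yaml")
+#     model.eval()
+#     with torch.no_grad():
+#         output = model(video_frames, caps, cmasks, return_score=True)
+# where `video_frames` is a raw-frame tensor and `caps`/`cmasks` come from the
+# returned tokenizer and aligner.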
+
+
+class MMFusion(nn.Module):
+ """a MMPT wrapper class for MMBert style models.
+ TODO: move isolated mask to a subclass.
+ """
+ def __init__(self, config, **kwargs):
+ super().__init__()
+ transformer_config = AutoConfig.from_pretrained(
+ config.dataset.bert_name)
+ self.hidden_size = transformer_config.hidden_size
+ self.is_train = False
+ if config.dataset.train_path is not None:
+ self.is_train = True
+ # 0 means no iso; 1-12 means iso up to that layer.
+ self.num_hidden_layers = transformer_config.num_hidden_layers
+ self.last_iso_layer = 0
+ if config.dataset.num_iso_layer is not None:
+ self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1
+
+ if config.model.mm_encoder_cls is not None:
+ mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
+ model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
+ model_config.max_video_len = config.dataset.max_video_len
+ # TODO: a general way to add parameter for a model.
+ model_config.use_seg_emb = config.model.use_seg_emb
+ self.mm_encoder = mm_encoder_cls.from_pretrained(
+ config.dataset.bert_name, config=model_config)
+ elif config.model.video_encoder_cls is not None\
+ and config.model.text_encoder_cls is not None:
+ video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
+ model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
+ model_config.max_video_len = config.dataset.max_video_len
+ # TODO: make each model a set of config class.
+ if hasattr(model_config, "num_layers"):
+ model_config.num_layers = config.model.num_hidden_video_layers
+ else:
+ model_config.num_hidden_layers = config.model.num_hidden_video_layers
+ self.video_encoder = video_encoder_cls.from_pretrained(
+ config.dataset.bert_name, config=model_config)
+ # exact same NLP model from Huggingface.
+ text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
+ self.text_encoder = text_encoder_cls.from_pretrained(
+ config.dataset.bert_name)
+ else:
+ raise ValueError("the encoder must be either MM or two backbones.")
+
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ **kwargs
+ ):
+ raise NotImplementedError(
+ "Please derive MMFusion module."
+ )
+
+ def _mm_on_the_fly(
+ self,
+ cmasks,
+ vmasks,
+ attention_mask
+ ):
+ """helper function for mask, seg_ids and token_type_ids."""
+ if attention_mask is None:
+ attention_mask = self._mm_attention_mask(cmasks, vmasks)
+
+ """
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ """
+ token_type_ids = torch.cat(
+ [
+ torch.zeros(
+ (vmasks.size(0), vmasks.size(1) + 2),
+ dtype=torch.long,
+ device=vmasks.device,
+ ),
+ torch.ones(
+ (cmasks.size(0), cmasks.size(1) - 2),
+ dtype=torch.long,
+ device=cmasks.device,
+ ),
+ ],
+ dim=1,
+ )
+ return attention_mask, token_type_ids
+
+ def _mm_attention_mask(self, cmasks, vmasks):
+ assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
+ str(cmasks.size()),
+ str(vmasks.size()),
+ str(cmasks.size(0)),
+ str(vmasks.size(0)),
+ )
+
+ mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
+ if self.last_iso_layer == 0:
+ # hard attention mask.
+ return mm_mask
+ else:
+ # a gpu iso mask; 0 : num_iso_layer is isolated;
+ # num_iso_layer: are MM-fused.
+ # make an iso layer
+ batch_size = cmasks.size(0)
+ iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
+ mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
+ iso_mm_masks = []
+ # hard attention mask.
+ iso_mask = iso_mask[:, None, :, :].repeat(
+ 1, self.last_iso_layer, 1, 1)
+ iso_mm_masks.append(iso_mask)
+ if self.last_iso_layer < self.num_hidden_layers:
+ mm_mask = mm_mask[:, None, :, :].repeat(
+ 1, self.num_hidden_layers - self.last_iso_layer, 1, 1
+ )
+ iso_mm_masks.append(mm_mask)
+ iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
+ return iso_mm_masks
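+    # Shape note: with isolation enabled the returned mask is
+    # (batch, num_hidden_layers, seq_len, seq_len); the first `last_iso_layer`
+    # per-layer masks keep video and text isolated and the remaining layers use
+    # the plain multimodal mask.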
+
+ def _make_iso_mask(self, batch_size, cmasks, vmasks):
+ cls_self_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device),
+ torch.zeros(
+ (batch_size, cmasks.size(1) + vmasks.size(1) - 1),
+ dtype=torch.bool, device=cmasks.device)
+ ], dim=1)
+
+ iso_video_mask = torch.cat(
+ [
+ # [CLS] is not used.
+ torch.zeros(
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device
+ ),
+ vmasks,
+ # assume to be 1.
+ cmasks[:, 1:2],
+ # 2 means [CLS] + [SEP]
+ torch.zeros(
+ (batch_size, cmasks.size(1) - 2),
+ dtype=torch.bool,
+ device=cmasks.device,
+ ),
+ ],
+ dim=1,
+ )
+ iso_text_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, 2 + vmasks.size(1)),
+ dtype=torch.bool,
+ device=cmasks.device,
+ ), # [CLS] is not used.
+ cmasks[:, 2:], # assume to be 1.
+ ],
+ dim=1,
+ )
+ cls_self_mask = cls_self_mask[:, None, :]
+ iso_video_mask = iso_video_mask[:, None, :].repeat(
+ 1, vmasks.size(1) + 1, 1)
+ iso_text_mask = iso_text_mask[:, None, :].repeat(
+ 1, cmasks.size(1) - 2, 1)
+ return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)
+
+ def _pooling_vt_layer(
+ self,
+ layered_sequence_output,
+ cmasks,
+ vmasks
+ ):
+ layer_idx = self.last_iso_layer \
+ if self.last_iso_layer > 0 else self.num_hidden_layers
+ hidden_state = layered_sequence_output[layer_idx]
+ # also output pooled_video and pooled_text.
+ batch_size = cmasks.size(0)
+ # pool the modality.
+ text_offset = vmasks.size(1) + 2 # [CLS] + [SEP]
+ # video tokens + [SEP]
+ video_outputs = hidden_state[:, 1:text_offset]
+ video_attention_mask = torch.cat(
+ [
+ vmasks,
+ torch.ones(
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+ ],
+ dim=1,
+ )
+ assert video_outputs.size(1) == video_attention_mask.size(1)
+ pooled_video = torch.sum(
+ video_outputs * video_attention_mask.unsqueeze(-1), dim=1
+ ) / video_attention_mask.sum(1, keepdim=True)
+ # pooled_video = torch.mean(video_outputs[0], dim=1)
+
+ # text tokens + [SEP]
+ text_attention_mask = cmasks[:, 2:]
+ text_outputs = hidden_state[:, text_offset:]
+ assert text_outputs.size(1) == text_attention_mask.size(1)
+ pooled_text = torch.sum(
+ text_outputs * text_attention_mask.unsqueeze(-1), dim=1
+ ) / text_attention_mask.sum(1, keepdim=True)
+ return pooled_video, pooled_text
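+    # Layout reminder for the pooling above: the fused sequence is
+    # [CLS] v_1 .. v_N [SEP] t_1 .. t_M [SEP], so the video tokens (plus their
+    # trailing [SEP]) start at index 1 and the text starts at
+    # text_offset = N + 2.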
+
+
+class MMFusionMFMMLM(MMFusion):
+ """forward function for MFM and MLM."""
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask=None,
+ video_label=None,
+ text_label=None,
+ **kwargs
+ ):
+        output_hidden_states = not self.is_train
+
+ target_vfeats, non_masked_frame_mask = None, None
+ if video_label is not None:
+ target_vfeats = vfeats.masked_select(
+ video_label.unsqueeze(-1)).view(
+ -1, vfeats.size(-1)
+ )
+ # mask video token.
+ vfeats[video_label] = 0.0
+ non_masked_frame_mask = vmasks.clone()
+ non_masked_frame_mask[video_label] = False
+
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ cmasks, vmasks, attention_mask)
+
+ outputs = self.mm_encoder(
+ input_ids=caps,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ masked_frame_labels=video_label,
+ target_video_hidden_states=target_vfeats,
+ non_masked_frame_mask=non_masked_frame_mask,
+ masked_lm_labels=text_label,
+ output_hidden_states=output_hidden_states,
+ )
+
+ video_logits, text_logits = outputs[0], outputs[1]
+
+ if self.is_train: # return earlier for training.
+ return {
+ "video_logits": video_logits,
+ "text_logits": text_logits,
+ }
+
+ pooled_video, pooled_text = self._pooling_vt_layer(
+ outputs[2], cmasks, vmasks)
+ return {"pooled_video": pooled_video, "pooled_text": pooled_text}
+
+
+class MMFusionMTM(MMFusionMFMMLM):
+    """
+    For reproducibility:
+    the parent's self.mm_encoder is initialized and then discarded (re-created below).
+    """
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+ from .transformermodel import MMBertForMTM
+ model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
+ model_config.max_video_len = config.dataset.max_video_len
+ model_config.use_seg_emb = config.model.use_seg_emb
+ self.mm_encoder = MMBertForMTM.from_pretrained(
+ config.dataset.bert_name, config=model_config)
+
+
+class MMFusionShare(MMFusion):
+ """A retrival wrapper using mm_encoder as both video/text backbone.
+ TODO: move formally.
+ """
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask=None,
+ video_label=None,
+ text_label=None,
+ output_hidden_states=False,
+ **kwargs
+ ):
+ pooled_video = self.forward_video(
+ vfeats,
+ vmasks,
+ caps,
+ cmasks,
+ output_hidden_states
+ )
+
+ pooled_text = self.forward_text(
+ caps,
+ cmasks,
+ output_hidden_states
+ )
+
+ return {"pooled_video": pooled_video, "pooled_text": pooled_text}
+
+ def forward_video(
+ self,
+ vfeats,
+ vmasks,
+ caps,
+ cmasks,
+ output_hidden_states=False,
+ **kwargs
+ ):
+ input_ids = caps[:, :2]
+
+ attention_mask = torch.cat([
+ cmasks[:, :1],
+ vmasks,
+ cmasks[:, 1:2]
+ ], dim=1)
+
+ token_type_ids = torch.zeros(
+ (vmasks.size(0), vmasks.size(1) + 2),
+ dtype=torch.long,
+ device=vmasks.device)
+
+ outputs = self.mm_encoder(
+ input_ids=input_ids,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=True
+ )
+ video_outputs = outputs[0]
+
+ if output_hidden_states:
+ return video_outputs
+
+ batch_size = cmasks.size(0)
+
+ video_attention_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+ vmasks,
+ torch.ones(
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+ ],
+ dim=1,
+ )
+ assert video_outputs.size(1) == video_attention_mask.size(1)
+
+ video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
+ / video_attention_mask.sum(1, keepdim=True)
+
+ pooled_video = torch.bmm(
+ video_outputs.transpose(2, 1),
+ video_attention_mask.unsqueeze(2)
+ ).squeeze(-1)
+ return pooled_video # video_outputs
+
+ def forward_text(
+ self,
+ caps,
+ cmasks,
+ output_hidden_states=False,
+ **kwargs
+ ):
+ input_ids = torch.cat([
+ caps[:, :1], caps[:, 2:],
+ ], dim=1)
+
+ attention_mask = torch.cat([
+ cmasks[:, :1],
+ cmasks[:, 2:]
+ ], dim=1)
+
+ token_type_ids = torch.cat([
+ torch.zeros(
+ (cmasks.size(0), 1),
+ dtype=torch.long,
+ device=cmasks.device),
+ torch.ones(
+ (cmasks.size(0), cmasks.size(1) - 2),
+ dtype=torch.long,
+ device=cmasks.device)
+ ], dim=1)
+
+ outputs = self.mm_encoder(
+ input_ids=input_ids,
+ input_video_embeds=None,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=True
+ )
+ text_outputs = outputs[0]
+
+ if output_hidden_states:
+ return text_outputs
+
+ batch_size = caps.size(0)
+ # text tokens + [SEP]
+ text_attention_mask = torch.cat([
+ torch.zeros(
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device),
+ cmasks[:, 2:]
+ ], dim=1)
+
+ assert text_outputs.size(1) == text_attention_mask.size(1)
+
+ text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
+ / text_attention_mask.sum(1, keepdim=True)
+
+ pooled_text = torch.bmm(
+ text_outputs.transpose(2, 1),
+ text_attention_mask.unsqueeze(2)
+ ).squeeze(-1)
+ return pooled_text # text_outputs
+
+
+class MMFusionSeparate(MMFusionShare):
+ def forward_video(
+ self,
+ vfeats,
+ vmasks,
+ caps,
+ cmasks,
+ output_hidden_states=False,
+ **kwargs
+ ):
+ input_ids = caps[:, :2]
+
+ attention_mask = torch.cat([
+ cmasks[:, :1],
+ vmasks,
+ cmasks[:, 1:2]
+ ], dim=1)
+
+ token_type_ids = torch.zeros(
+ (vmasks.size(0), vmasks.size(1) + 2),
+ dtype=torch.long,
+ device=vmasks.device)
+
+ outputs = self.video_encoder(
+ input_ids=input_ids,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=True
+ )
+ video_outputs = outputs[0]
+
+ if output_hidden_states:
+ return video_outputs
+
+ batch_size = cmasks.size(0)
+
+ video_attention_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+ vmasks,
+ torch.ones(
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+ ],
+ dim=1,
+ )
+ assert video_outputs.size(1) == video_attention_mask.size(1)
+
+ video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
+ / video_attention_mask.sum(1, keepdim=True)
+
+ pooled_video = torch.bmm(
+ video_outputs.transpose(2, 1),
+ video_attention_mask.unsqueeze(2)
+ ).squeeze(-1)
+ return pooled_video # video_outputs
+
+ def forward_text(
+ self,
+ caps,
+ cmasks,
+ output_hidden_states=False,
+ **kwargs
+ ):
+ input_ids = torch.cat([
+ caps[:, :1], caps[:, 2:],
+ ], dim=1)
+
+ attention_mask = torch.cat([
+ cmasks[:, :1],
+ cmasks[:, 2:]
+ ], dim=1)
+        # different from the shared encoder, we use all-zero token type ids.
+ token_type_ids = torch.zeros(
+ (cmasks.size(0), cmasks.size(1) - 1),
+ dtype=torch.long,
+ device=cmasks.device)
+
+ outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=True
+ )
+ text_outputs = outputs[0]
+
+ if output_hidden_states:
+ return text_outputs
+
+ batch_size = caps.size(0)
+ # text tokens + [SEP]
+ text_attention_mask = torch.cat([
+ torch.zeros(
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device),
+ cmasks[:, 2:]
+ ], dim=1)
+
+ assert text_outputs.size(1) == text_attention_mask.size(1)
+
+ text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
+ / text_attention_mask.sum(1, keepdim=True)
+
+ pooled_text = torch.bmm(
+ text_outputs.transpose(2, 1),
+ text_attention_mask.unsqueeze(2)
+ ).squeeze(-1)
+ return pooled_text # text_outputs
+
+
+class MMFusionJoint(MMFusion):
+ """fine-tuning wrapper for retrival task."""
+
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask=None,
+ video_label=None,
+ text_label=None,
+ **kwargs
+ ):
+ # TODO (huxu): other ways to do negative examples; move the following
+ # into your criterion forward.
+ output_hidden_states = True
+
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ cmasks, vmasks, attention_mask)
+
+ separate_forward_split = (
+ None if self.is_train else vmasks.size(1) + 2
+ ) # [CLS] + [SEP]
+
+ outputs = self.mm_encoder(
+ input_ids=caps,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=output_hidden_states,
+ separate_forward_split=separate_forward_split,
+ )
+
+ pooled_video, pooled_text = self._pooling_vt_layer(
+ outputs[2], cmasks, vmasks)
+ return {"pooled_video": pooled_video, "pooled_text": pooled_text}
+
+
+class MMFusionActionSegmentation(MMFusion):
+ """Fine-tuning wrapper for action segmentation.
+ TODO: rename this for VLM.
+ """
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask=None,
+ **kwargs
+ ):
+        # Action segmentation assumes batch_size=1; collapse the leading batch dimension.
+ caps = caps.view(-1, caps.size(-1))
+ cmasks = cmasks.view(-1, cmasks.size(-1))
+ vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
+ vmasks = vmasks.view(-1, vmasks.size(-1))
+
+ # this may not cover all shapes of attention_mask.
+ attention_mask = attention_mask.view(
+ -1, attention_mask.size(2), attention_mask.size(3)) \
+ if attention_mask is not None else None
+
+ # TODO (huxu): other ways to do negative examples; move the following
+ # into your criterion forward.
+ output_hidden_states = True
+
+ # video forwarding, text is dummy; never use attention_mask.
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ cmasks, vmasks, attention_mask)
+
+ logits = self.mm_encoder(
+ input_ids=caps,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=output_hidden_states,
+ )
+ return {"logits": logits[0][:, 1:vmasks.size(1)+1]}
+
+
+class MMFusionActionLocalization(MMFusion):
+ """fine-tuning model for retrival task."""
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config)
+ tokenizer = AutoTokenizer.from_pretrained(
+ config.dataset.bert_name)
+ self.cls_token_id = tokenizer.cls_token_id
+ self.sep_token_id = tokenizer.sep_token_id
+ self.pad_token_id = tokenizer.pad_token_id
+
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask=None,
+ **kwargs
+ ):
+        # Action localization assumes batch_size=1; squeeze it.
+ caps = caps.squeeze(0)
+ cmasks = cmasks.squeeze(0)
+ vfeats = vfeats.squeeze(0)
+ vmasks = vmasks.squeeze(0)
+ attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None
+
+ # TODO (huxu): other ways to do negative examples; move the following
+ # into your criterion forward.
+ output_hidden_states = True
+
+ # a len1 dummy video token.
+ dummy_vfeats = torch.zeros(
+ (caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype)
+ dummy_vmasks = torch.ones(
+ (caps.size(0), 1), dtype=torch.bool,
+ device=vfeats.device)
+
+ dummy_caps = torch.LongTensor(
+ [[self.cls_token_id, self.sep_token_id,
+ self.pad_token_id, self.sep_token_id]],
+ ).to(caps.device).repeat(vfeats.size(0), 1)
+ dummy_cmasks = torch.BoolTensor(
+ [[0, 1, 0, 1]] # pad are valid for attention.
+ ).to(caps.device).repeat(vfeats.size(0), 1)
+
+ # video forwarding, text is dummy; never use attention_mask.
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ dummy_cmasks, vmasks, None)
+
+ outputs = self.mm_encoder(
+ input_ids=dummy_caps,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=output_hidden_states,
+ )
+
+ layer_idx = self.last_iso_layer \
+ if self.last_iso_layer > 0 else self.num_hidden_layers
+
+ video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select(
+ vmasks.unsqueeze(-1)
+ ).view(-1, self.hidden_size)
+
+ # text forwarding, video is dummy
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ cmasks, dummy_vmasks, None)
+
+ outputs = self.mm_encoder(
+ input_ids=caps,
+ input_video_embeds=dummy_vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_hidden_states=output_hidden_states,
+ )
+
+ _, pooled_text = self._pooling_vt_layer(
+ outputs[2], cmasks, dummy_vmasks)
+ # this line is not right.
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
+ return {"logits": logits}
+
+
+# --------------- MMFusionSeparate for end tasks ---------------
+
+class MMFusionSeparateActionSegmentation(MMFusionSeparate):
+ """Fine-tuning wrapper for action segmentation."""
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask=None,
+ **kwargs
+ ):
+        # Action segmentation assumes batch_size=1; collapse the leading batch dimension.
+ caps = caps.view(-1, caps.size(-1))
+ cmasks = cmasks.view(-1, cmasks.size(-1))
+ vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
+ vmasks = vmasks.view(-1, vmasks.size(-1))
+ logits = self.forward_video(
+ vfeats,
+ vmasks,
+ caps,
+ cmasks,
+ output_hidden_states=True
+ )
+ return {"logits": logits[:, 1:vmasks.size(1)+1]}
+
+
+class MMFusionSeparateActionLocalization(MMFusionSeparate):
+ def __init__(self, config, **kwargs):
+ super().__init__(config)
+ tokenizer = AutoTokenizer.from_pretrained(
+ config.dataset.bert_name)
+ self.cls_token_id = tokenizer.cls_token_id
+ self.sep_token_id = tokenizer.sep_token_id
+ self.pad_token_id = tokenizer.pad_token_id
+
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ **kwargs
+ ):
+        # Action localization assumes batch_size=1; squeeze it.
+ caps = caps.squeeze(0)
+ cmasks = cmasks.squeeze(0)
+ vfeats = vfeats.squeeze(0)
+ vmasks = vmasks.squeeze(0)
+
+ # TODO (huxu): other ways to do negative examples; move the following
+ # into your criterion forward.
+ dummy_caps = torch.LongTensor(
+ [[self.cls_token_id, self.sep_token_id,
+ self.pad_token_id, self.sep_token_id]],
+ ).to(caps.device).repeat(vfeats.size(0), 1)
+ dummy_cmasks = torch.BoolTensor(
+ [[0, 1, 0, 1]] # pad are valid for attention.
+ ).to(caps.device).repeat(vfeats.size(0), 1)
+
+ outputs = self.forward_video(
+ vfeats,
+ vmasks,
+ dummy_caps,
+ dummy_cmasks,
+ output_hidden_states=True
+ )
+
+ video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
+ vmasks.unsqueeze(-1)
+ ).view(-1, self.hidden_size)
+
+ pooled_text = self.forward_text(
+ caps,
+ cmasks,
+ output_hidden_states=False
+ )
+
+ # this line is not right.
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
+ return {"logits": logits}
+
+
+class MMFusionShareActionLocalization(MMFusionShare):
+ def __init__(self, config, **kwargs):
+ super().__init__(config)
+ tokenizer = AutoTokenizer.from_pretrained(
+ config.dataset.bert_name)
+ self.cls_token_id = tokenizer.cls_token_id
+ self.sep_token_id = tokenizer.sep_token_id
+ self.pad_token_id = tokenizer.pad_token_id
+
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ **kwargs
+ ):
+        # Action localization assumes batch_size=1; squeeze it.
+ caps = caps.squeeze(0)
+ cmasks = cmasks.squeeze(0)
+ vfeats = vfeats.squeeze(0)
+ vmasks = vmasks.squeeze(0)
+
+ # TODO (huxu): other ways to do negative examples; move the following
+ # into your criterion forward.
+ dummy_caps = torch.LongTensor(
+ [[self.cls_token_id, self.sep_token_id,
+ self.pad_token_id, self.sep_token_id]],
+ ).to(caps.device).repeat(vfeats.size(0), 1)
+ dummy_cmasks = torch.BoolTensor(
+ [[0, 1, 0, 1]] # pad are valid for attention.
+ ).to(caps.device).repeat(vfeats.size(0), 1)
+
+ outputs = self.forward_video(
+ vfeats,
+ vmasks,
+ dummy_caps,
+ dummy_cmasks,
+ output_hidden_states=True
+ )
+
+ video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
+ vmasks.unsqueeze(-1)
+ ).view(-1, self.hidden_size)
+
+ pooled_text = self.forward_text(
+ caps,
+ cmasks,
+ output_hidden_states=False
+ )
+
+ # this line is not right.
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
+ return {"logits": logits}
diff --git a/examples/MMPT/mmpt/models/mmfusionnlg.py b/examples/MMPT/mmpt/models/mmfusionnlg.py
new file mode 100644
index 0000000000..9207e77dab
--- /dev/null
+++ b/examples/MMPT/mmpt/models/mmfusionnlg.py
@@ -0,0 +1,999 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+
+from torch.nn import functional as F
+
+from typing import Optional, Iterable
+
+try:
+ from transformers import BertPreTrainedModel
+ from transformers.modeling_bert import BertOnlyMLMHead
+
+ from transformers.file_utils import ModelOutput
+ from transformers.modeling_outputs import CausalLMOutput
+ from transformers.generation_utils import (
+ BeamHypotheses,
+ top_k_top_p_filtering
+ )
+except ImportError:
+ pass
+
+from .mmfusion import MMFusion
+from .transformermodel import MMBertModel
+from ..modules import VideoTokenMLP
+
+
+class MMFusionNLG(MMFusion):
+ def __init__(self, config, **kwargs):
+ super().__init__(config)
+ if config.model.max_decode_length is not None:
+ self.max_length = min(
+ config.model.max_decode_length,
+ config.dataset.max_len - config.dataset.max_video_len - 3
+ )
+ else:
+ self.max_length = \
+ config.dataset.max_len - config.dataset.max_video_len - 3
+ self.gen_param = config.gen_param if config.gen_param is not None \
+ else {}
+
+ def forward(
+ self,
+ caps,
+ cmasks,
+ vfeats,
+ vmasks,
+ attention_mask,
+ video_label=None,
+ text_label=None,
+ **kwargs
+ ):
+ """use pre-trained LM header for generation."""
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ cmasks, vmasks, attention_mask)
+
+ outputs = self.mm_encoder(
+ input_ids=caps,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ masked_lm_labels=text_label,
+ )
+ return {"logits": outputs[0]}
+
+ @torch.no_grad()
+ def generate(
+ self,
+ caps, cmasks, vfeats, vmasks,
+ attention_mask=None,
+ bos_token_id=None,
+ eos_token_id=None,
+ **kwargs
+ ):
+ # a simplified interface from
+ # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html#GenerationMixin.generate
+
+ # caps now only has
+ # [CLS], [SEP] (for the video) and [CLS] (as the bos token)
+ assert caps.size(1) == 3
+
+ attention_mask, token_type_ids = self._mm_on_the_fly(
+ cmasks, vmasks, attention_mask)
+
+ output = self.mm_encoder.generate(
+ input_ids=caps,
+ input_video_embeds=vfeats,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ max_length=self.max_length,
+ **self.gen_param
+ )
+ return output
+
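+# A minimal usage sketch (hypothetical variable names; the actual bos/eos token
+# ids and the batch tensors come from the task config and dataloader):
+#
+#   model = MMFusionNLG(config)
+#   tokens = model.generate(
+#       caps, cmasks, vfeats, vmasks,
+#       bos_token_id=bos_id, eos_token_id=eos_id,
+#   )
+#   text = tokenizer.decode(tokens[0], skip_special_tokens=True)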
+
+class MMBertForNLG(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.bert = MMBertModel(config)
+ self.videomlp = VideoTokenMLP(config)
+ # we do not use `BertGenerationOnlyLMHead`
+ # because we want to reuse the pretrained MLM head weights.
+ self.cls = BertOnlyMLMHead(config)
+ self.hidden_size = config.hidden_size
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def forward(
+ self,
+ input_ids=None,
+ input_video_embeds=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ masked_lm_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ # similar to MMBertForMFMMLM without MFM.
+ video_tokens = self.videomlp(input_video_embeds)
+ outputs = self.bert(
+ input_ids,
+ video_tokens,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ prediction_scores = None
+ if masked_lm_labels is not None:
+ text_offset = input_video_embeds.size(1) + 1 # [CLS]
+ # recover caps format: [CLS] [SEP] text [SEP]
+ text_sequence_output = torch.cat(
+ [sequence_output[:, :1], sequence_output[:, text_offset:]],
+ dim=1
+ )
+
+ # only compute the selected (masked) tokens during training to speed up.
+ hidden_size = text_sequence_output.size(-1)
+ # masked_lm_labels = masked_lm_labels.reshape(-1)
+ labels_mask = masked_lm_labels != -100
+
+ selected_text_output = text_sequence_output.masked_select(
+ labels_mask.unsqueeze(-1)
+ ).view(-1, hidden_size)
+ prediction_scores = self.cls(selected_text_output)
+
+ if not return_dict:
+ output = (
+ prediction_scores,
+ ) + outputs[2:]
+ return output
+
+ # for generation.
+ text_offset = input_video_embeds.size(1) + 2 # [CLS] + [SEP]
+ text_sequence_output = sequence_output[:, text_offset:]
+ prediction_scores = self.cls(text_sequence_output)
+ return CausalLMOutput(
+ loss=None,
+ logits=prediction_scores,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ input_video_embeds,
+ attention_mask=None,
+ token_type_ids=None,
+ **model_kwargs
+ ):
+ # must return a dictionary.
+ seq_len = input_ids.size(1) + input_video_embeds.size(1)
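+ # trim the (possibly per-layer) attention mask and token_type_ids so they
+ # match the current text + video length as decoding grows input_ids.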
+ if attention_mask is not None:
+ if len(attention_mask.size()) == 4:
+ attention_mask = attention_mask[:, :, :seq_len, :seq_len]
+ elif len(attention_mask.size()) == 3:
+ attention_mask = attention_mask[:, :seq_len, :seq_len]
+ else:
+ attention_mask = attention_mask[:, :seq_len]
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids[:, :seq_len]
+
+ return {
+ "input_ids": input_ids,
+ "input_video_embeds": input_video_embeds,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ }
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ max_length: Optional[int] = None,
+ min_length: Optional[int] = None,
+ do_sample: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ num_beams: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ repetition_penalty: Optional[float] = None,
+ bad_words_ids: Optional[Iterable[int]] = None,
+ bos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[int] = None,
+ length_penalty: Optional[float] = None,
+ no_repeat_ngram_size: Optional[int] = None,
+ num_return_sequences: Optional[int] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_start_token_id: Optional[int] = None,
+ use_cache: Optional[bool] = None,
+ **model_kwargs
+ ) -> torch.LongTensor:
+ r"""
+ Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
+ beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
+ Adapted in part from `Facebook's XLM beam search code
+ `__.
+ Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
+ attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
+ indicated are the default values from the model's config.
+ Most of these parameters are explained in more detail in `this blog post
+ `__.
+ Parameters:
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ The sequence used as a prompt for the generation. If :obj:`None` the method initializes
+ it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
+ decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only
+ decoder_start_token_id is passed as the first token to the decoder.
+ max_length (:obj:`int`, `optional`, defaults to 20):
+ The maximum length of the sequence to be generated.
+ min_length (:obj:`int`, `optional`, defaults to 10):
+ The minimum length of the sequence to be generated.
+ do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to use sampling; use greedy decoding otherwise.
+ early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
+ num_beams (:obj:`int`, `optional`, defaults to 1):
+ Number of beams for beam search. 1 means no beam search.
+ temperature (:obj:`float`, `optional`, defaults to 1.0):
+ The value used to module the next token probabilities.
+ top_k (:obj:`int`, `optional`, defaults to 50):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ top_p (:obj:`float`, `optional`, defaults to 1.0):
+ If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
+ higher are kept for generation.
+ repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
+ The parameter for repetition penalty. 1.0 means no penalty. See `this paper
+ `__ for more details.
+ pad_token_id (:obj:`int`, `optional`):
+ The id of the `padding` token.
+ bos_token_id (:obj:`int`, `optional`):
+ The id of the `beginning-of-sequence` token.
+ eos_token_id (:obj:`int`, `optional`):
+ The id of the `end-of-sequence` token.
+ length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+ Exponential penalty to the length. 1.0 means no penalty.
+ Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
+ order to encourage the model to produce longer sequences.
+ no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+ If set to int > 0, all ngrams of that size can only occur once.
+ bad_words_ids(:obj:`List[int]`, `optional`):
+ List of token ids that are not allowed to be generated. In order to get the tokens of the words that
+ should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
+ num_return_sequences(:obj:`int`, `optional`, defaults to 1):
+ The number of independently computed returned sequences for each element in the batch.
+ attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
+ tokens that are not masked, and 0 for masked tokens.
+ If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ decoder_start_token_id (:obj:`int`, `optional`):
+ If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+ use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not the model should use the past last key/values attentions (if applicable to the model) to
+ speed up decoding.
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
+ Return:
+ :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
+ The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
+ shorter if all batches finished early due to the :obj:`eos_token_id`.
+ Examples::
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
+ outputs = model.generate(max_length=40) # do greedy decoding
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
+ input_context = 'The dog'
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+ outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
+ for i in range(3): # 3 output sequences were generated
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
+ input_context = 'The dog'
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+ outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
+ for i in range(3): # 3 output sequences were generated
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
+ input_context = 'My cute dog' # input context for generation
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
+ """
+
+ # We cannot generate if the model does not have a LM head
+ if self.get_output_embeddings() is None:
+ raise AttributeError(
+ "You tried to generate sequences with a model that does not have a LM Head."
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
+ )
+
+ max_length = max_length if max_length is not None else self.config.max_length
+ min_length = min_length if min_length is not None else self.config.min_length
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
+ temperature = temperature if temperature is not None else self.config.temperature
+ top_k = top_k if top_k is not None else self.config.top_k
+ top_p = top_p if top_p is not None else self.config.top_p
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+ no_repeat_ngram_size = (
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
+ )
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
+ num_return_sequences = (
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+ )
+ decoder_start_token_id = (
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+ )
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0] # overridden by the input batch_size
+ else:
+ batch_size = 1
+
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
+ assert temperature > 0, "`temperature` should be strictly positive."
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
+ assert input_ids is not None or (
+ isinstance(bos_token_id, int) and bos_token_id >= 0
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+ assert pad_token_id is None or (
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
+ ), "`pad_token_id` should be a positive integer."
+ assert (eos_token_id is None) or (
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
+ ), "`eos_token_id` should be a positive integer."
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
+ assert (
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
+ ), "`no_repeat_ngram_size` should be a positive integer."
+ assert (
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
+ ), "`num_return_sequences` should be a strictly positive integer."
+ assert (
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
+
+ if input_ids is None:
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
+ "you should either supply a context to complete as `input_ids` input "
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
+ )
+ input_ids = torch.full(
+ (batch_size, 1),
+ bos_token_id,
+ dtype=torch.long,
+ device=next(self.parameters()).device,
+ )
+ else:
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+
+ # do not allow duplicate outputs when doing greedy decoding
+ if do_sample is False:
+ if num_beams == 1:
+ # no_beam_search greedy generation conditions
+ assert (
+ num_return_sequences == 1
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+ else:
+ # beam_search greedy generation conditions
+ assert (
+ num_beams >= num_return_sequences
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
+ # create attention mask if necessary
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+ attention_mask = input_ids.ne(pad_token_id).long()
+ elif attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
+ # attention_mask is created
+ if pad_token_id is None and eos_token_id is not None:
+ print(
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
+ )
+ pad_token_id = eos_token_id
+
+ # vocab size
+ if hasattr(self.config, "vocab_size"):
+ vocab_size = self.config.vocab_size
+ elif (
+ self.config.is_encoder_decoder
+ and hasattr(self.config, "decoder")
+ and hasattr(self.config.decoder, "vocab_size")
+ ):
+ vocab_size = self.config.decoder.vocab_size
+ else:
+ raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
+
+ # set effective batch size and effective batch multiplier according to do_sample
+ if do_sample:
+ effective_batch_size = batch_size * num_return_sequences
+ effective_batch_mult = num_return_sequences
+ else:
+ effective_batch_size = batch_size
+ effective_batch_mult = 1
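+ # when sampling, each example is replicated num_return_sequences times up
+ # front; beam search keeps one copy per example and picks the returned
+ # sequences from its beams at the end.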
+
+ if self.config.is_encoder_decoder:
+ if decoder_start_token_id is None:
+ # see if BOS token can be used for decoder_start_token_id
+ if bos_token_id is not None:
+ decoder_start_token_id = bos_token_id
+ elif (
+ hasattr(self.config, "decoder")
+ and hasattr(self.config.decoder, "bos_token_id")
+ and self.config.decoder.bos_token_id is not None
+ ):
+ decoder_start_token_id = self.config.decoder.bos_token_id
+ else:
+ raise ValueError(
+ "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+ )
+
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+
+ # get encoder and store encoder outputs
+ encoder = self.get_encoder()
+ encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
+
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
+ if num_return_sequences > 1 or num_beams > 1:
+ # TODO: make this a call-back function.
+ # input_ids=caps,
+ # input_video_embeds=vfeats,
+ # attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ input_video_embeds = model_kwargs.pop("input_video_embeds", None)
+ token_type_ids = model_kwargs.pop("token_type_ids", None)
+
+ input_ids_len = input_ids.shape[-1]
+ input_ids = input_ids.unsqueeze(1).expand(
+ batch_size, effective_batch_mult * num_beams, input_ids_len)
+
+ input_video_embeds_len, input_video_embeds_hidden = input_video_embeds.size(1), input_video_embeds.size(2)
+ input_video_embeds = input_video_embeds.unsqueeze(1).expand(
+ batch_size, effective_batch_mult * num_beams, input_video_embeds_len, input_video_embeds_hidden)
+
+ attention_mask_from_len, attention_mask_to_len = attention_mask.size(1), attention_mask.size(2)
+ attention_mask = attention_mask.unsqueeze(1).expand(
+ batch_size, effective_batch_mult * num_beams, attention_mask_from_len, attention_mask_to_len
+ )
+
+ token_type_ids_len = token_type_ids.size(1)
+ token_type_ids = token_type_ids.unsqueeze(1).expand(
+ batch_size, effective_batch_mult * num_beams, token_type_ids_len
+ )
+
+ # contiguous ...
+ input_ids = input_ids.contiguous().view(
+ effective_batch_size * num_beams, input_ids_len
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+
+ input_video_embeds = input_video_embeds.contiguous().view(
+ effective_batch_size * num_beams, input_video_embeds_len, input_video_embeds_hidden)
+
+ attention_mask = attention_mask.contiguous().view(
+ effective_batch_size * num_beams, attention_mask_from_len, attention_mask_to_len
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+
+ token_type_ids = token_type_ids.contiguous().view(
+ effective_batch_size * num_beams, token_type_ids_len
+ )
+
+ model_kwargs["input_video_embeds"] = input_video_embeds
+ model_kwargs["token_type_ids"] = token_type_ids
+
+ if self.config.is_encoder_decoder:
+ device = next(self.parameters()).device
+ if decoder_input_ids is not None:
+ # give initial decoder input ids
+ input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
+ else:
+ # create empty decoder input_ids
+ input_ids = torch.full(
+ (effective_batch_size * num_beams, 1),
+ decoder_start_token_id,
+ dtype=torch.long,
+ device=device,
+ )
+ cur_len = input_ids.shape[-1]
+
+ assert (
+ batch_size == encoder_outputs.last_hidden_state.shape[0]
+ ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
+
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
+ expanded_batch_idxs = (
+ torch.arange(batch_size)
+ .view(-1, 1)
+ .repeat(1, num_beams * effective_batch_mult)
+ .view(-1)
+ .to(input_ids.device)
+ )
+
+ # expand encoder_outputs
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
+ 0, expanded_batch_idxs
+ )
+
+ # save encoder_outputs in `model_kwargs`
+ model_kwargs["encoder_outputs"] = encoder_outputs
+
+ else:
+ cur_len = input_ids.shape[-1]
+
+ assert (
+ cur_len < max_length
+ ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
+
+ if num_beams > 1:
+ output = self._generate_beam_search(
+ input_ids,
+ cur_len=cur_len,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=do_sample,
+ early_stopping=early_stopping,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ batch_size=effective_batch_size,
+ num_return_sequences=num_return_sequences,
+ length_penalty=length_penalty,
+ num_beams=num_beams,
+ vocab_size=vocab_size,
+ attention_mask=attention_mask,
+ use_cache=use_cache,
+ model_kwargs=model_kwargs,
+ )
+ else:
+ output = self._generate_no_beam_search(
+ input_ids,
+ cur_len=cur_len,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ batch_size=effective_batch_size,
+ attention_mask=attention_mask,
+ use_cache=use_cache,
+ model_kwargs=model_kwargs,
+ )
+
+ return output
+
+ def _generate_beam_search(
+ self,
+ input_ids,
+ cur_len,
+ max_length,
+ min_length,
+ do_sample,
+ early_stopping,
+ temperature,
+ top_k,
+ top_p,
+ repetition_penalty,
+ no_repeat_ngram_size,
+ bad_words_ids,
+ pad_token_id,
+ eos_token_id,
+ batch_size,
+ num_return_sequences,
+ length_penalty,
+ num_beams,
+ vocab_size,
+ attention_mask,
+ use_cache,
+ model_kwargs,
+ ):
+ """Generate sequences for each example with beam search."""
+
+ # generated hypotheses
+ generated_hyps = [
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
+ for _ in range(batch_size)
+ ]
+
+ # scores for each sentence in the beam
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+
+ # for greedy decoding, make sure only tokens of the first beam are considered so that all beams do not sample the exact same tokens
+ if do_sample is False:
+ beam_scores[:, 1:] = -1e9
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
+
+ # cache compute states
+ past = None
+
+ # done sentences
+ done = [False for _ in range(batch_size)]
+
+ while cur_len < max_length:
+ model_inputs = self.prepare_inputs_for_generation(
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+ )
+ outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
+ next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
+
+ # if model has past, then set the past variable to speed up decoding
+ if "past_key_values" in outputs:
+ past = outputs.past_key_values
+ elif "mems" in outputs:
+ past = outputs.mems
+
+ if self.config.is_encoder_decoder and do_sample is False:
+ # TODO (PVP) still a bit hacky here - there might be a better solution
+ next_token_logits = self.adjust_logits_during_generation(
+ next_token_logits, cur_len=cur_len, max_length=max_length
+ )
+
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
+
+ scores = self.postprocess_next_token_scores(
+ scores=scores,
+ input_ids=input_ids,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ cur_len=cur_len,
+ min_length=min_length,
+ max_length=max_length,
+ eos_token_id=eos_token_id,
+ repetition_penalty=repetition_penalty,
+ batch_size=batch_size,
+ num_beams=num_beams,
+ )
+
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
+ scores.shape, (batch_size * num_beams, vocab_size)
+ )
+
+ if do_sample:
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
+ # Temperature
+ if temperature != 1.0:
+ _scores = _scores / temperature
+ # Top-p/top-k filtering
+ _scores = top_k_top_p_filtering(
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+ ) # (batch_size * num_beams, vocab_size)
+ # re-organize to group the beam together to sample from all beam_idxs
+ _scores = _scores.contiguous().view(
+ batch_size, num_beams * vocab_size
+ ) # (batch_size, num_beams * vocab_size)
+
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+ probs = F.softmax(_scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
+ # Compute next scores
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
+ # sort the sampled vector to make sure that the first num_beams samples are the best
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
+
+ else:
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
+
+ # re-organize to group the beams together (we keep the top hypotheses across beams)
+ next_scores = next_scores.view(
+ batch_size, num_beams * vocab_size
+ ) # (batch_size, num_beams * vocab_size)
+
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
+
+ # next batch beam content
+ next_batch_beam = []
+
+ # for each sentence
+ for batch_idx in range(batch_size):
+
+ # if we are done with this sentence, add a pad token
+ if done[batch_idx]:
+ assert (
+ len(generated_hyps[batch_idx]) >= num_beams
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+ assert (
+ eos_token_id is not None and pad_token_id is not None
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
+ continue
+
+ # next sentence beam content, this will get added to next_batch_beam
+ next_sent_beam = []
+
+ # next tokens for this sentence
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
+ ):
+ # get beam and token IDs
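+ # beam_token_id indexes the flattened (num_beams * vocab_size) score
+ # matrix, so // and % recover the source beam and the vocab token.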
+ beam_id = beam_token_id // vocab_size
+ token_id = beam_token_id % vocab_size
+
+ effective_beam_id = batch_idx * num_beams + beam_id
+ # add to generated hypotheses if end of sentence
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
+ # if beam_token does not belong to top num_beams tokens, it should not be added
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+ if is_beam_token_worse_than_top_num_beams:
+ continue
+ generated_hyps[batch_idx].add(
+ input_ids[effective_beam_id].clone(),
+ beam_token_score.item(),
+ )
+ else:
+ # add next predicted token since it is not eos_token
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
+
+ # once the beam for next step is full, don't add more tokens to it.
+ if len(next_sent_beam) == num_beams:
+ break
+
+ # Check if we are done so that we can save a pad step if all(done)
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+ next_scores[batch_idx].max().item(), cur_len
+ )
+
+ # update next beam content
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
+ next_batch_beam.extend(next_sent_beam)
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
+
+ # stop when we are done with each sentence
+ if all(done):
+ break
+
+ # sanity check / prepare next batch
+ assert len(next_batch_beam) == batch_size * num_beams
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
+
+ # re-order batch and update current length
+ input_ids = input_ids[beam_idx, :]
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
+ cur_len = cur_len + 1
+
+ # re-order internal states
+ if past is not None:
+ past = self._reorder_cache(past, beam_idx)
+
+ # extend attention_mask for new generated input if only decoder
+ # (huxu): move out since we trim attention_mask by ourselves.
+ # if self.config.is_encoder_decoder is False:
+ # attention_mask = torch.cat(
+ # [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ # )
+
+ # finalize all open beam hypotheses and add to generated hypotheses
+ for batch_idx in range(batch_size):
+ if done[batch_idx]:
+ continue
+
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
+ if eos_token_id is not None and all(
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
+ ):
+ assert torch.all(
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
+ next_scores[:, :num_beams][batch_idx],
+ beam_scores.view(batch_size, num_beams)[batch_idx],
+ )
+
+ # need to add best num_beams hypotheses to generated hyps
+ for beam_id in range(num_beams):
+ effective_beam_id = batch_idx * num_beams + beam_id
+ final_score = beam_scores[effective_beam_id].item()
+ final_tokens = input_ids[effective_beam_id]
+ generated_hyps[batch_idx].add(final_tokens, final_score)
+
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
+ # select the best hypotheses
+ sent_lengths = input_ids.new(output_batch_size)
+ best = []
+
+ # retrieve best hypotheses
+ for i, hypotheses in enumerate(generated_hyps):
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+ for j in range(output_num_return_sequences_per_batch):
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
+ best_hyp = sorted_hyps.pop()[1]
+ sent_lengths[effective_batch_idx] = len(best_hyp)
+ best.append(best_hyp)
+
+ # prepare for adding eos
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
+ decoded = input_ids.new(output_batch_size, sent_max_len)
+ # shorter batches are padded if needed
+ if sent_lengths.min().item() != sent_lengths.max().item():
+ assert pad_token_id is not None, "`pad_token_id` has to be defined"
+ decoded.fill_(pad_token_id)
+
+ # fill with hypotheses and eos_token_id if the latter fits in
+ for i, hypo in enumerate(best):
+ decoded[i, : sent_lengths[i]] = hypo
+ if sent_lengths[i] < max_length:
+ decoded[i, sent_lengths[i]] = eos_token_id
+
+ return decoded
+
+ def _generate_no_beam_search(
+ self,
+ input_ids,
+ cur_len,
+ max_length,
+ min_length,
+ do_sample,
+ temperature,
+ top_k,
+ top_p,
+ repetition_penalty,
+ no_repeat_ngram_size,
+ bad_words_ids,
+ pad_token_id,
+ eos_token_id,
+ batch_size,
+ attention_mask,
+ use_cache,
+ model_kwargs,
+ ):
+ """Generate sequences for each example without beam search (num_beams == 1).
+ All returned sequences are generated independently.
+ """
+ # length of generated sentences / unfinished sentences
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
+
+ past = None
+ while cur_len < max_length:
+ model_inputs = self.prepare_inputs_for_generation(
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+ )
+
+ outputs = self(**model_inputs, return_dict=True)
+ next_token_logits = outputs.logits[:, -1, :]
+ scores = self.postprocess_next_token_scores(
+ scores=next_token_logits,
+ input_ids=input_ids,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ cur_len=cur_len,
+ min_length=min_length,
+ max_length=max_length,
+ eos_token_id=eos_token_id,
+ repetition_penalty=repetition_penalty,
+ batch_size=batch_size,
+ num_beams=1,
+ )
+
+ # if model has past, then set the past variable to speed up decoding
+ if "past_key_values" in outputs:
+ past = outputs.past_key_values
+ elif "mems" in outputs:
+ past = outputs.mems
+
+ if do_sample:
+ # Temperature (higher temperature => more likely to sample low probability tokens)
+ if temperature != 1.0:
+ scores = scores / temperature
+ # Top-p/top-k filtering
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
+ # Sample
+ probs = F.softmax(next_token_logscores, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ # Greedy decoding
+ next_token = torch.argmax(next_token_logits, dim=-1)
+
+ # print(next_token_logits[0,next_token[0]], next_token_logits[0,eos_token_id])
+
+ # update generations and finished sentences
+ if eos_token_id is not None:
+ # pad finished sentences if eos_token_id exists
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
+ else:
+ tokens_to_add = next_token
+
+ # add token and increase length by one
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
+ cur_len = cur_len + 1
+
+ if eos_token_id is not None:
+ eos_in_sents = tokens_to_add == eos_token_id
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
+ # unfinished_sents is set to zero if eos in sentence
+ unfinished_sents.mul_((~eos_in_sents).long())
+
+ # stop when there is an EOS token in each sentence, or if we exceed the maximum length
+ if unfinished_sents.max() == 0:
+ break
+
+ # extend attention_mask for new generated input if only decoder
+ # if self.config.is_encoder_decoder is False:
+ # attention_mask = torch.cat(
+ # [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ # )
+
+ return input_ids
diff --git a/examples/MMPT/mmpt/models/transformermodel.py b/examples/MMPT/mmpt/models/transformermodel.py
new file mode 100644
index 0000000000..6acc419f09
--- /dev/null
+++ b/examples/MMPT/mmpt/models/transformermodel.py
@@ -0,0 +1,734 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch
+
+from torch import nn
+
+try:
+ from transformers.modeling_bert import (
+ BertPreTrainedModel,
+ BertModel,
+ BertEncoder,
+ BertPredictionHeadTransform,
+ )
+except ImportError:
+ pass
+
+from ..modules import VideoTokenMLP, MMBertEmbeddings
+
+
+# --------------- fine-tuning models ---------------
+class MMBertForJoint(BertPreTrainedModel):
+ """A BertModel with isolated attention mask to separate modality."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.videomlp = VideoTokenMLP(config)
+ self.bert = MMBertModel(config)
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ input_video_embeds=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ next_sentence_label=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ separate_forward_split=None,
+ ):
+ return_dict = (
+ return_dict if return_dict is not None
+ else self.config.use_return_dict
+ )
+ video_tokens = self.videomlp(input_video_embeds)
+
+ outputs = self.bert(
+ input_ids,
+ video_tokens,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ separate_forward_split=separate_forward_split,
+ )
+
+ return outputs
+
+
+class MMBertForTokenClassification(BertPreTrainedModel):
+ """A BertModel similar to MMJointUni, with extra wrapper layer
+ to be fine-tuned from other pretrained MMFusion model."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.videomlp = VideoTokenMLP(config)
+ self.bert = MMBertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # TODO(huxu): 779 is the number of classes for COIN: move to config?
+ self.classifier = nn.Linear(config.hidden_size, 779)
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ input_video_embeds=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ next_sentence_label=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ separate_forward_split=None,
+ ):
+ return_dict = (
+ return_dict if return_dict is not None
+ else self.config.use_return_dict
+ )
+
+ video_tokens = self.videomlp(input_video_embeds)
+ outputs = self.bert(
+ input_ids,
+ video_tokens,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ separate_forward_split=separate_forward_split,
+ )
+
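+ # per-position logits over the 779 COIN classes:
+ # (batch, text + video sequence length, 779)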
+ return (self.classifier(outputs[0]),)
+
+
+# ------------ pre-training models ----------------
+
+class MMBertForEncoder(BertPreTrainedModel):
+ """A BertModel for Contrastive Learning."""
+ def __init__(self, config):
+ super().__init__(config)
+ self.videomlp = VideoTokenMLP(config)
+ self.bert = MMBertModel(config)
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ input_video_embeds=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = (
+ return_dict if return_dict is not None
+ else self.config.use_return_dict
+ )
+ if input_video_embeds is not None:
+ video_tokens = self.videomlp(input_video_embeds)
+ else:
+ video_tokens = None
+
+ outputs = self.bert(
+ input_ids,
+ video_tokens,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return outputs
+
+
+class MMBertForMFMMLM(BertPreTrainedModel):
+ """A BertModel with shared prediction head on MFM-MLM."""
+ def __init__(self, config):
+ super().__init__(config)
+ self.videomlp = VideoTokenMLP(config)
+ self.bert = MMBertModel(config)
+ self.cls = MFMMLMHead(config)
+ self.hidden_size = config.hidden_size
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def forward(
+ self,
+ input_ids=None,
+ input_video_embeds=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ masked_frame_labels=None,
+ target_video_hidden_states=None,
+ non_masked_frame_mask=None,
+ masked_lm_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = (
+ return_dict if return_dict is not None
+ else self.config.use_return_dict
+ )
+ if input_video_embeds is not None:
+ video_tokens = self.videomlp(input_video_embeds)
+ else:
+ video_tokens = None
+
+ if target_video_hidden_states is not None:
+ target_video_hidden_states = self.videomlp(
+ target_video_hidden_states)
+
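+ # embeddings of the frames that were NOT masked; the prediction head
+ # scores masked-frame predictions against them, presumably as negatives.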
+ non_masked_frame_hidden_states = video_tokens.masked_select(
+ non_masked_frame_mask.unsqueeze(-1)
+ ).view(-1, self.hidden_size)
+
+ outputs = self.bert(
+ input_ids,
+ video_tokens,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ mfm_scores, prediction_scores = None, None
+ if masked_frame_labels is not None and masked_lm_labels is not None:
+ # split the sequence.
+ text_offset = masked_frame_labels.size(1) + 1 # [CLS]
+ video_sequence_output = sequence_output[
+ :, 1:text_offset
+ ] # remove [SEP] since it is not in video_label.
+ text_sequence_output = torch.cat(
+ [sequence_output[:, :1], sequence_output[:, text_offset:]],
+ dim=1
+ )
+
+ hidden_size = video_sequence_output.size(-1)
+ selected_video_output = video_sequence_output.masked_select(
+ masked_frame_labels.unsqueeze(-1)
+ ).view(-1, hidden_size)
+
+ # only compute the selected (masked) tokens during training to speed up.
+ hidden_size = text_sequence_output.size(-1)
+ # masked_lm_labels = masked_lm_labels.reshape(-1)
+ labels_mask = masked_lm_labels != -100
+
+ selected_text_output = text_sequence_output.masked_select(
+ labels_mask.unsqueeze(-1)
+ ).view(-1, hidden_size)
+ mfm_scores, prediction_scores = self.cls(
+ selected_video_output,
+ target_video_hidden_states,
+ non_masked_frame_hidden_states,
+ selected_text_output,
+ )
+
+ output = (
+ mfm_scores,
+ prediction_scores,
+ ) + outputs
+ return output
+
+
+class BertMFMMLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly
+ # resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(
+ self,
+ video_hidden_states=None,
+ target_video_hidden_states=None,
+ non_masked_frame_hidden_states=None,
+ text_hidden_states=None,
+ ):
+ video_logits, text_logits = None, None
+ if video_hidden_states is not None:
+ video_hidden_states = self.transform(video_hidden_states)
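+ # the masked-frame score (against its target embedding) goes in column 0
+ # and scores against the unmasked frames fill the remaining columns, so a
+ # criterion with target index 0 gives an NCE-style masked-frame objective.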
+ non_masked_frame_logits = torch.mm(
+ video_hidden_states,
+ non_masked_frame_hidden_states.transpose(1, 0)
+ )
+ masked_frame_logits = torch.bmm(
+ video_hidden_states.unsqueeze(1),
+ target_video_hidden_states.unsqueeze(-1),
+ ).squeeze(-1)
+ video_logits = torch.cat(
+ [masked_frame_logits, non_masked_frame_logits], dim=1
+ )
+
+ if text_hidden_states is not None:
+ text_hidden_states = self.transform(text_hidden_states)
+ text_logits = self.decoder(text_hidden_states)
+ return video_logits, text_logits
+
+
+class MFMMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertMFMMLMPredictionHead(config)
+
+ def forward(
+ self,
+ video_hidden_states=None,
+ target_video_hidden_states=None,
+ non_masked_frame_hidden_states=None,
+ text_hidden_states=None,
+ ):
+ video_logits, text_logits = self.predictions(
+ video_hidden_states,
+ target_video_hidden_states,
+ non_masked_frame_hidden_states,
+ text_hidden_states,
+ )
+ return video_logits, text_logits
+
+
+class MMBertForMTM(MMBertForMFMMLM):
+ def __init__(self, config):
+ BertPreTrainedModel.__init__(self, config)
+ self.videomlp = VideoTokenMLP(config)
+ self.bert = MMBertModel(config)
+ self.cls = MTMHead(config)
+ self.hidden_size = config.hidden_size
+ self.init_weights()
+
+
+class BertMTMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+ self.decoder = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=False)
+
+ def forward(
+ self,
+ video_hidden_states=None,
+ target_video_hidden_states=None,
+ non_masked_frame_hidden_states=None,
+ text_hidden_states=None,
+ ):
+ non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
+ video_logits, text_logits = None, None
+ if video_hidden_states is not None:
+ video_hidden_states = self.transform(video_hidden_states)
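+ # MTM scores each masked frame against its target embedding, the unmasked
+ # frames, and the text vocabulary, forming one joint logit vector.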
+
+ masked_frame_logits = torch.bmm(
+ video_hidden_states.unsqueeze(1),
+ target_video_hidden_states.unsqueeze(-1),
+ ).squeeze(-1)
+
+ non_masked_frame_logits = torch.mm(
+ video_hidden_states,
+ non_masked_frame_hidden_states
+ )
+ video_on_vocab_logits = self.decoder(video_hidden_states)
+ video_logits = torch.cat([
+ masked_frame_logits,
+ non_masked_frame_logits,
+ video_on_vocab_logits], dim=1)
+
+ if text_hidden_states is not None:
+ text_hidden_states = self.transform(text_hidden_states)
+ # text vocabulary logits come first so the labels do not need to be shifted.
+ text_on_vocab_logits = self.decoder(text_hidden_states)
+ text_on_video_logits = torch.mm(
+ text_hidden_states,
+ non_masked_frame_hidden_states
+ )
+ text_logits = torch.cat([
+ text_on_vocab_logits,
+ text_on_video_logits
+ ], dim=1)
+
+ return video_logits, text_logits
+
+
+class MTMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertMTMPredictionHead(config)
+
+ def forward(
+ self,
+ video_hidden_states=None,
+ target_video_hidden_states=None,
+ non_masked_frame_hidden_states=None,
+ text_hidden_states=None,
+ ):
+ video_logits, text_logits = self.predictions(
+ video_hidden_states,
+ target_video_hidden_states,
+ non_masked_frame_hidden_states,
+ text_hidden_states,
+ )
+ return video_logits, text_logits
+
+
+class MMBertModel(BertModel):
+ """MMBertModel has MMBertEmbedding to support video tokens."""
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ # overwrite the embeddings (and encoder) with multimodal-aware versions
+ self.embeddings = MMBertEmbeddings(config)
+ self.encoder = MultiLayerAttentionMaskBertEncoder(config)
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ input_video_embeds=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ separate_forward_split=None,
+ ):
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None
+ else self.config.use_return_dict
+ )
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids "
+ "and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ if input_video_embeds is not None:
+ input_shape = (
+ input_ids.size(0),
+ input_ids.size(1) + input_video_embeds.size(1),
+ )
+ else:
+ input_shape = (
+ input_ids.size(0),
+ input_ids.size(1),
+ )
+ elif inputs_embeds is not None:
+ if input_video_embeds is not None:
+ input_shape = (
+ inputs_embeds.size(0),
+ inputs_embeds.size(1) + input_video_embeds.size(1),
+ )
+ else:
+ input_shape = (
+ inputs_embeds.size(0),
+ inputs_embeds.size(1),
+ )
+ else:
+ raise ValueError(
+ "You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None \
+ else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(
+ input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions
+ # [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case
+ # we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = \
+ self.get_extended_attention_mask(
+ attention_mask, input_shape, device)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to
+ # [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (
+ encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(
+ encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or
+ # [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape
+ # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+
+ head_mask = self.get_head_mask(
+ head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids,
+ input_video_embeds,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+
+ if separate_forward_split is not None:
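+ # forward the first `separate_forward_split` positions and the remainder
+ # as two independent encoder passes (each with its diagonal block of the
+ # mask), then concatenate the hidden states back along the sequence dim.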
+ split_embedding_output = \
+ embedding_output[:, :separate_forward_split]
+ split_extended_attention_mask = extended_attention_mask[
+ :, :, :, :separate_forward_split, :separate_forward_split
+ ]
+ split_encoder_outputs = self.encoder(
+ split_embedding_output,
+ attention_mask=split_extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ assert (
+ len(split_encoder_outputs) <= 2
+ ), "we do not support merge on attention for now."
+ encoder_outputs = []
+ encoder_outputs.append([split_encoder_outputs[0]])
+ if len(split_encoder_outputs) == 2:
+ encoder_outputs.append([])
+ for _all_hidden_states in split_encoder_outputs[1]:
+ encoder_outputs[-1].append([_all_hidden_states])
+
+ split_embedding_output = \
+ embedding_output[:, separate_forward_split:]
+ split_extended_attention_mask = extended_attention_mask[
+ :, :, :, separate_forward_split:, separate_forward_split:
+ ]
+
+ split_encoder_outputs = self.encoder(
+ split_embedding_output,
+ attention_mask=split_extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ assert (
+ len(split_encoder_outputs) <= 2
+ ), "we do not support merge on attention for now."
+ encoder_outputs[0].append(split_encoder_outputs[0])
+ encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
+ if len(split_encoder_outputs) == 2:
+ for layer_idx, _all_hidden_states in enumerate(
+ split_encoder_outputs[1]
+ ):
+ encoder_outputs[1][layer_idx].append(_all_hidden_states)
+ encoder_outputs[1][layer_idx] = torch.cat(
+ encoder_outputs[1][layer_idx], dim=1
+ )
+ encoder_outputs = tuple(encoder_outputs)
+ else:
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+ pooled_output = (
+ self.pooler(sequence_output) if self.pooler is not None else None
+ )
+
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ def get_extended_attention_mask(self, attention_mask, input_shape, device):
+ """This is borrowed from `modeling_utils.py` with the support of
+ multi-layer attention masks.
+ The second dim is expected to be number of layers.
+ See `MMAttentionMaskProcessor`.
+ Makes broadcastable attention and causal masks so that future
+ and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to,
+ zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, \
+ with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions
+ # [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable
+ # to all heads.
+ if attention_mask.dim() == 4:
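+ # 4-D masks are (batch, num_layers, from_seq, to_seq); insert a broadcast
+ # head dimension to get (batch, num_layers, 1, from_seq, to_seq).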
+ extended_attention_mask = attention_mask[:, :, None, :, :]
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) \
+ * -10000.0
+ return extended_attention_mask
+ else:
+ return super().get_extended_attention_mask(
+ attention_mask, input_shape, device
+ )
+
+
+class MultiLayerAttentionMaskBertEncoder(BertEncoder):
+ """extend BertEncoder with the capability of
+ multiple layers of attention mask."""
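+    # Mask convention (illustrative): a 5-D mask of shape
+    # [batch, num_layers, 1, from_seq, to_seq] is sliced per layer as
+    # attention_mask[:, i]; a 4-D mask is shared across all layers.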
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_attention_mask = (
+ attention_mask[:, i, :, :, :]
+ if attention_mask.dim() == 5
+ else attention_mask
+ )
+
+ if getattr(self.config, "gradient_checkpointing", False):
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ layer_attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ layer_attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_attentions]
+ if v is not None
+ )
diff --git a/examples/MMPT/mmpt/modules/__init__.py b/examples/MMPT/mmpt/modules/__init__.py
new file mode 100644
index 0000000000..4c78594c21
--- /dev/null
+++ b/examples/MMPT/mmpt/modules/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .mm import *
+
+try:
+ from .expmm import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/modules/mm.py b/examples/MMPT/mmpt/modules/mm.py
new file mode 100644
index 0000000000..5d9777371a
--- /dev/null
+++ b/examples/MMPT/mmpt/modules/mm.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+
+from torch import nn
+
+try:
+ from transformers.modeling_bert import (
+ BertEmbeddings,
+ ACT2FN,
+ )
+except ImportError:
+ pass
+
+
+class VideoTokenMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ input_dim = config.input_dim if hasattr(config, "input_dim") else 512
+ self.linear1 = nn.Linear(input_dim, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
+
+ def forward(self, hidden_states):
+ hidden_states = self.linear1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ hidden_states = self.linear2(hidden_states)
+ return hidden_states
+
+
+class MMBertEmbeddings(BertEmbeddings):
+ def __init__(self, config):
+ super().__init__(config)
+ self.max_video_len = config.max_video_len
+ if hasattr(config, "use_seg_emb") and config.use_seg_emb:
+ """the original VLM paper uses seg_embeddings for temporal space.
+ although not used it changed the randomness of initialization.
+ we keep it for reproducibility.
+ """
+ self.seg_embeddings = nn.Embedding(256, config.hidden_size)
+
+ def forward(
+ self,
+ input_ids,
+ input_video_embeds,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ ):
+ input_tensor = input_ids if input_ids is not None else inputs_embeds
+ if input_video_embeds is not None:
+ input_shape = (
+ input_tensor.size(0),
+ input_tensor.size(1) + input_video_embeds.size(1),
+ )
+ else:
+ input_shape = (input_tensor.size(0), input_tensor.size(1))
+
+ if position_ids is None:
+ """
+            Auto-skip position embeddings for the text-only case.
+            use cases:
+            (1) action localization and segmentation:
+                a length-1 dummy video token is fed in, so the text part
+                must skip input_video_embeds.size(1) positions to get the
+                right position_ids for the video [SEP] and the remaining
+                text tokens.
+            (2) MMFusionShare with two forward passes:
+                in `forward_text`, input_video_embeds is None, so the
+                video [SEP] position must be skipped.
+
+ # video_len + 1: [CLS] + video_embed
+ # self.max_video_len + 1: [SEP] for video.
+ # self.max_video_len + 2: [SEP] for video.
+ # self.max_video_len + input_ids.size(1): rest for text.
+ """
+ if input_video_embeds is not None:
+ video_len = input_video_embeds.size(1)
+ starting_offset = self.max_video_len + 1 # video [SEP]
+ ending_offset = self.max_video_len + input_ids.size(1)
+ else:
+ video_len = 0
+ starting_offset = self.max_video_len + 2 # first text token.
+ ending_offset = self.max_video_len + input_ids.size(1) + 1
+ position_ids = torch.cat([
+ self.position_ids[:, :video_len + 1],
+ self.position_ids[:, starting_offset:ending_offset]
+ ], dim=1)
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(
+ input_shape, dtype=torch.long, device=self.position_ids.device
+ )
+
+ """
+ the format of input_ids is [CLS] [SEP] caption [SEP] padding.
+        the goal is to build [CLS] video tokens [SEP] caption [SEP] padding.
+ """
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ if input_video_embeds is not None:
+ inputs_mm_embeds = torch.cat([
+ inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
+ ], dim=1)
+ else:
+ # text only for `MMFusionShare`.
+ inputs_mm_embeds = inputs_embeds
+
+ position_embeddings = self.position_embeddings(position_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+ embeddings = inputs_mm_embeds + position_embeddings
+ embeddings += token_type_embeddings
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class AlignHead(nn.Module):
+ """this will load pre-trained weights for NSP, which is desirable."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, dropout_pooled_output):
+ logits = self.seq_relationship(dropout_pooled_output)
+ return logits
diff --git a/examples/MMPT/mmpt/modules/retri.py b/examples/MMPT/mmpt/modules/retri.py
new file mode 100644
index 0000000000..d1b288f8e5
--- /dev/null
+++ b/examples/MMPT/mmpt/modules/retri.py
@@ -0,0 +1,429 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import numpy as np
+import pickle
+import time
+
+try:
+ import faiss
+except ImportError:
+ pass
+
+from collections import defaultdict
+
+from ..utils import get_local_rank, print_on_rank0
+
+
+class VectorRetriever(object):
+ """
+    How2 Video Retriever.
+ Reference usage of FAISS:
+ https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
+ """
+
+ def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
+ if db_type == "flatl2":
+ quantizer = faiss.IndexFlatL2(hidden_size) # the other index
+ self.db = faiss.IndexIVFFlat(
+ quantizer, hidden_size, cent, faiss.METRIC_L2)
+ elif db_type == "pq":
+ self.db = faiss.index_factory(
+ hidden_size, f"IVF{cent}_HNSW32,PQ32"
+ )
+ else:
+ raise ValueError("unknown type of db", db_type)
+ self.train_thres = cent * examples_per_cent_to_train
+ self.train_cache = []
+ self.train_len = 0
+ self.videoid_to_vectoridx = {}
+ self.vectoridx_to_videoid = None
+ self.make_direct_maps_done = False
+
+ def make_direct_maps(self):
+ faiss.downcast_index(self.db).make_direct_map()
+
+ def __len__(self):
+ return self.db.ntotal
+
+ def save(self, out_dir):
+ faiss.write_index(
+ self.db,
+ os.path.join(out_dir, "faiss_idx")
+ )
+ with open(
+ os.path.join(
+ out_dir, "videoid_to_vectoridx.pkl"),
+ "wb") as fw:
+ pickle.dump(
+ self.videoid_to_vectoridx, fw,
+ protocol=pickle.HIGHEST_PROTOCOL
+ )
+
+ def load(self, out_dir):
+ fn = os.path.join(out_dir, "faiss_idx")
+ self.db = faiss.read_index(fn)
+ with open(
+ os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
+ self.videoid_to_vectoridx = pickle.load(fr)
+
+ def add(self, hidden_states, video_ids, last=False):
+ assert len(hidden_states) == len(video_ids), "{}, {}".format(
+ str(len(hidden_states)), str(len(video_ids)))
+ assert len(hidden_states.shape) == 2
+ assert hidden_states.dtype == np.float32
+
+ valid_idx = []
+ for idx, video_id in enumerate(video_ids):
+ if video_id not in self.videoid_to_vectoridx:
+ valid_idx.append(idx)
+ self.videoid_to_vectoridx[video_id] = \
+ len(self.videoid_to_vectoridx)
+
+ hidden_states = hidden_states[valid_idx]
+ if not self.db.is_trained:
+ self.train_cache.append(hidden_states)
+ self.train_len += hidden_states.shape[0]
+ if self.train_len < self.train_thres:
+ return
+ self.finalize_training()
+ else:
+ self.db.add(hidden_states)
+
+ def finalize_training(self):
+ hidden_states = np.concatenate(self.train_cache, axis=0)
+ del self.train_cache
+ local_rank = get_local_rank()
+ if local_rank == 0:
+ start = time.time()
+ print("training db on", self.train_thres, "/", self.train_len)
+ self.db.train(hidden_states[:self.train_thres])
+ if local_rank == 0:
+ print("training db for", time.time() - start)
+ self.db.add(hidden_states)
+
+ def search(
+ self,
+ query_hidden_states,
+ orig_dist,
+ ):
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
+ raise ValueError(
+ "cannot search: size mismatch in-between index and db",
+ len(self.videoid_to_vectoridx),
+ self.db.ntotal
+ )
+
+ if self.vectoridx_to_videoid is None:
+ self.vectoridx_to_videoid = {
+ self.videoid_to_vectoridx[videoid]: videoid
+ for videoid in self.videoid_to_vectoridx
+ }
+ assert len(self.vectoridx_to_videoid) \
+ == len(self.videoid_to_vectoridx)
+
+ # MultilingualFaissDataset uses the following; not sure the purpose.
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+ queried_dist, index = self.db.search(query_hidden_states, 1)
+ queried_dist, index = queried_dist[:, 0], index[:, 0]
+
+ outputs = np.array(
+ [self.vectoridx_to_videoid[_index]
+ if _index != -1 else (-1, -1, -1) for _index in index],
+ dtype=np.int32)
+ outputs[queried_dist <= orig_dist] = -1
+ return outputs
+
+ def search_by_video_ids(
+ self,
+ video_ids,
+ retri_factor
+ ):
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
+ raise ValueError(
+ len(self.videoid_to_vectoridx),
+ self.db.ntotal
+ )
+
+ if not self.make_direct_maps_done:
+ self.make_direct_maps()
+
+ if self.vectoridx_to_videoid is None:
+ self.vectoridx_to_videoid = {
+ self.videoid_to_vectoridx[videoid]: videoid
+ for videoid in self.videoid_to_vectoridx
+ }
+ assert len(self.vectoridx_to_videoid) \
+ == len(self.videoid_to_vectoridx)
+
+ query_hidden_states = []
+ vector_ids = []
+ for video_id in video_ids:
+ vector_id = self.videoid_to_vectoridx[video_id]
+ vector_ids.append(vector_id)
+ query_hidden_state = self.db.reconstruct(vector_id)
+ query_hidden_states.append(query_hidden_state)
+ query_hidden_states = np.stack(query_hidden_states)
+
+ # MultilingualFaissDataset uses the following; not sure the reason.
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+ _, index = self.db.search(query_hidden_states, retri_factor)
+ outputs = []
+ for sample_idx, sample in enumerate(index):
+ # the first video_id is always the video itself.
+ cands = [video_ids[sample_idx]]
+ for vector_idx in sample:
+ if vector_idx >= 0 \
+ and vector_ids[sample_idx] != vector_idx:
+ cands.append(
+ self.vectoridx_to_videoid[vector_idx]
+ )
+ outputs.append(cands)
+ return outputs
+
+
+class VectorRetrieverDM(VectorRetriever):
+ """
+ with direct map.
+    How2 Video Retriever.
+ Reference usage of FAISS:
+ https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ cent,
+ db_type,
+ examples_per_cent_to_train
+ ):
+ super().__init__(
+ hidden_size, cent, db_type, examples_per_cent_to_train)
+ self.make_direct_maps_done = False
+
+ def make_direct_maps(self):
+ faiss.downcast_index(self.db).make_direct_map()
+ self.make_direct_maps_done = True
+
+ def search(
+ self,
+ query_hidden_states,
+ orig_dist,
+ ):
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
+ raise ValueError(
+ len(self.videoid_to_vectoridx),
+ self.db.ntotal
+ )
+
+ if not self.make_direct_maps_done:
+ self.make_direct_maps()
+ if self.vectoridx_to_videoid is None:
+ self.vectoridx_to_videoid = {
+ self.videoid_to_vectoridx[videoid]: videoid
+ for videoid in self.videoid_to_vectoridx
+ }
+ assert len(self.vectoridx_to_videoid) \
+ == len(self.videoid_to_vectoridx)
+
+ # MultilingualFaissDataset uses the following; not sure the reason.
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+ queried_dist, index = self.db.search(query_hidden_states, 1)
+ outputs = []
+ for sample_idx, sample in enumerate(index):
+ # and queried_dist[sample_idx] < thres \
+ if sample >= 0 \
+ and queried_dist[sample_idx] < orig_dist[sample_idx]:
+ outputs.append(self.vectoridx_to_videoid[sample])
+ else:
+ outputs.append(None)
+ return outputs
+
+ def search_by_video_ids(
+ self,
+ video_ids,
+ retri_factor=8
+ ):
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
+ raise ValueError(
+ len(self.videoid_to_vectoridx),
+ self.db.ntotal
+ )
+
+ if not self.make_direct_maps_done:
+ self.make_direct_maps()
+ if self.vectoridx_to_videoid is None:
+ self.vectoridx_to_videoid = {
+ self.videoid_to_vectoridx[videoid]: videoid
+ for videoid in self.videoid_to_vectoridx
+ }
+ assert len(self.vectoridx_to_videoid) \
+ == len(self.videoid_to_vectoridx)
+
+ query_hidden_states = []
+ vector_ids = []
+ for video_id in video_ids:
+ vector_id = self.videoid_to_vectoridx[video_id]
+ vector_ids.append(vector_id)
+ query_hidden_state = self.db.reconstruct(vector_id)
+ query_hidden_states.append(query_hidden_state)
+ query_hidden_states = np.stack(query_hidden_states)
+
+ # MultilingualFaissDataset uses the following; not sure the reason.
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+ _, index = self.db.search(query_hidden_states, retri_factor)
+ outputs = []
+ for sample_idx, sample in enumerate(index):
+ # the first video_id is always the video itself.
+ cands = [video_ids[sample_idx]]
+ for vector_idx in sample:
+ if vector_idx >= 0 \
+ and vector_ids[sample_idx] != vector_idx:
+ cands.append(
+ self.vectoridx_to_videoid[vector_idx]
+ )
+ outputs.append(cands)
+ return outputs
+
+
+class MMVectorRetriever(VectorRetrieverDM):
+ """
+    multimodal vector retriever:
+    text retrieves video or video retrieves text.
+ """
+
+ def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
+ super().__init__(
+ hidden_size, cent, db_type, examples_per_cent_to_train)
+ video_db = self.db
+ super().__init__(
+ hidden_size, cent, db_type, examples_per_cent_to_train)
+ text_db = self.db
+ self.db = {"video": video_db, "text": text_db}
+ self.video_to_videoid = defaultdict(list)
+
+ def __len__(self):
+ assert self.db["video"].ntotal == self.db["text"].ntotal
+ return self.db["video"].ntotal
+
+ def make_direct_maps(self):
+ faiss.downcast_index(self.db["video"]).make_direct_map()
+ faiss.downcast_index(self.db["text"]).make_direct_map()
+
+ def save(self, out_dir):
+ faiss.write_index(
+ self.db["video"],
+ os.path.join(out_dir, "video_faiss_idx")
+ )
+ faiss.write_index(
+ self.db["text"],
+ os.path.join(out_dir, "text_faiss_idx")
+ )
+
+ with open(
+ os.path.join(
+ out_dir, "videoid_to_vectoridx.pkl"),
+ "wb") as fw:
+ pickle.dump(
+ self.videoid_to_vectoridx, fw,
+ protocol=pickle.HIGHEST_PROTOCOL
+ )
+
+ def load(self, out_dir):
+ fn = os.path.join(out_dir, "video_faiss_idx")
+ video_db = faiss.read_index(fn)
+ fn = os.path.join(out_dir, "text_faiss_idx")
+ text_db = faiss.read_index(fn)
+ self.db = {"video": video_db, "text": text_db}
+ with open(
+ os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
+ self.videoid_to_vectoridx = pickle.load(fr)
+ self.video_to_videoid = defaultdict(list)
+
+ def add(self, hidden_states, video_ids):
+ """hidden_states is a pair `(video, text)`"""
+ assert len(hidden_states) == len(video_ids), "{}, {}".format(
+ str(len(hidden_states)), str(len(video_ids)))
+ assert len(hidden_states.shape) == 3
+ assert len(self.video_to_videoid) == 0
+
+ valid_idx = []
+ for idx, video_id in enumerate(video_ids):
+ if video_id not in self.videoid_to_vectoridx:
+ valid_idx.append(idx)
+ self.videoid_to_vectoridx[video_id] = \
+ len(self.videoid_to_vectoridx)
+
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states[valid_idx]
+
+ hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
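+        # hidden_states: [batch, 2, hidden] -> [2, batch, hidden], so row 0
+        # feeds the video index and row 1 feeds the text index below.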
+ if not self.db["video"].is_trained:
+ self.train_cache.append(hidden_states)
+ train_len = batch_size * len(self.train_cache)
+ if train_len < self.train_thres:
+ return
+
+ hidden_states = np.concatenate(self.train_cache, axis=1)
+ del self.train_cache
+ self.db["video"].train(hidden_states[0, :self.train_thres])
+ self.db["text"].train(hidden_states[1, :self.train_thres])
+ self.db["video"].add(hidden_states[0])
+ self.db["text"].add(hidden_states[1])
+
+ def get_clips_by_video_id(self, video_id):
+ if not self.video_to_videoid:
+            # build the index without shadowing the `video_id` argument.
+            for vid, video_clip, text_clip in self.videoid_to_vectoridx:
+                self.video_to_videoid[vid].append(
+                    (vid, video_clip, text_clip))
+ return self.video_to_videoid[video_id]
+
+ def search(
+ self,
+ video_ids,
+ target_modality,
+ retri_factor=8
+ ):
+ if len(self.videoid_to_vectoridx) != len(self):
+ raise ValueError(
+ len(self.videoid_to_vectoridx),
+ len(self)
+ )
+
+ if not self.make_direct_maps_done:
+ self.make_direct_maps()
+ if self.vectoridx_to_videoid is None:
+ self.vectoridx_to_videoid = {
+ self.videoid_to_vectoridx[videoid]: videoid
+ for videoid in self.videoid_to_vectoridx
+ }
+ assert len(self.vectoridx_to_videoid) \
+ == len(self.videoid_to_vectoridx)
+
+ src_modality = "text" if target_modality == "video" else "video"
+
+ query_hidden_states = []
+ vector_ids = []
+ for video_id in video_ids:
+ vector_id = self.videoid_to_vectoridx[video_id]
+ vector_ids.append(vector_id)
+ query_hidden_state = self.db[src_modality].reconstruct(vector_id)
+ query_hidden_states.append(query_hidden_state)
+ query_hidden_states = np.stack(query_hidden_states)
+
+ # MultilingualFaissDataset uses the following; not sure the reason.
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+ _, index = self.db[target_modality].search(
+ query_hidden_states, retri_factor)
+ outputs = []
+ for sample_idx, sample in enumerate(index):
+ cands = []
+ for vector_idx in sample:
+ if vector_idx >= 0:
+ cands.append(
+ self.vectoridx_to_videoid[vector_idx]
+ )
+ outputs.append(cands)
+ return outputs
diff --git a/examples/MMPT/mmpt/modules/vectorpool.py b/examples/MMPT/mmpt/modules/vectorpool.py
new file mode 100644
index 0000000000..d2b23d2da8
--- /dev/null
+++ b/examples/MMPT/mmpt/modules/vectorpool.py
@@ -0,0 +1,246 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch
+import os
+import numpy as np
+import pickle
+
+from . import retri
+from ..utils import get_local_rank
+
+
+class VectorPool(object):
+ """
+ Base class of retrieval space.
+ """
+
+ def __init__(self, config):
+ from transformers import AutoConfig
+ self.hidden_size = AutoConfig.from_pretrained(
+ config.dataset.bert_name).hidden_size
+ self.retriever_cls = getattr(retri, config.retriever_cls)
+
+ def __call__(self, sample, **kwargs):
+ raise NotImplementedError
+
+ def build_retriver(
+ self,
+ retriever_cls=None,
+ hidden_size=None,
+ centroids=512,
+ db_type="flatl2",
+ examples_per_cent_to_train=48
+ ):
+        """build and return a retriever."""
+ self.retriver = retriever_cls(
+ hidden_size, centroids, db_type, examples_per_cent_to_train)
+ return self.retriver
+
+ def __repr__(self):
+ if hasattr(self, "retriver"):
+ retriver_name = str(len(self.retriver))
+ else:
+ retriver_name = "no retriver field yet"
+ return self.__class__.__name__ \
+ + "(" + retriver_name + ")"
+
+
+class VideoVectorPool(VectorPool):
+ """
+ average clips of a video as video representation.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ self.build_retriver(self.retriever_cls, self.hidden_size)
+
+ def __call__(self, sample, subsampling, **kwargs):
+ hidden_states = (
+ sample["pooled_video"] + sample["pooled_text"]) / 2.
+ hidden_states = hidden_states.view(
+ -1, subsampling,
+ hidden_states.size(-1))
+ hidden_states = torch.mean(hidden_states, dim=1)
+ hidden_states = hidden_states.cpu().detach().numpy()
+ video_ids = []
+ for offset_idx, video_id in enumerate(sample["video_id"]):
+ if isinstance(video_id, tuple) and len(video_id) == 3:
+ # a sharded video_id.
+ video_id = video_id[0]
+ video_ids.append(video_id)
+ assert len(video_ids) == len(hidden_states)
+ self.retriver.add(
+ hidden_states.astype("float32"),
+ video_ids
+ )
+
+
+class DistributedVectorPool(VectorPool):
+ """
+ support sync of multiple gpus/nodes.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ self.out_dir = os.path.join(
+ config.fairseq.checkpoint.save_dir,
+ "retri")
+ os.makedirs(self.out_dir, exist_ok=True)
+ self.hidden_states = []
+ self.video_ids = []
+
+ def build_retriver(
+ self,
+ retriever_cls=None,
+ hidden_size=None,
+ centroids=4096,
+ db_type="flatl2",
+ examples_per_cent_to_train=48
+ ):
+ if retriever_cls is None:
+ retriever_cls = self.retriever_cls
+ if hidden_size is None:
+ hidden_size = self.hidden_size
+ """merge results from multiple gpus and return a retriver.."""
+ if torch.distributed.is_initialized():
+ self.save()
+ # sync saving.
+ torch.distributed.barrier()
+ world_size = torch.distributed.get_world_size()
+ else:
+ world_size = 1
+ self.retriver = retriever_cls(
+ hidden_size, centroids, db_type, examples_per_cent_to_train)
+ # each gpu process has its own retriever.
+ for local_rank in range(world_size):
+ if get_local_rank() == 0:
+ print("load local_rank", local_rank)
+ hidden_states, video_ids = self.load(local_rank)
+ hidden_states = hidden_states.astype("float32")
+ self.retriver.add(hidden_states, video_ids)
+ return self.retriver
+
+ def load(self, local_rank):
+ hidden_states = np.load(
+ os.path.join(
+ self.out_dir,
+ "hidden_state" + str(local_rank) + ".npy"
+ )
+ )
+
+ with open(
+ os.path.join(
+ self.out_dir, "video_id" + str(local_rank) + ".pkl"),
+ "rb") as fr:
+ video_ids = pickle.load(fr)
+ return hidden_states, video_ids
+
+ def save(self):
+ hidden_states = np.vstack(self.hidden_states)
+ assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
+ len(hidden_states),
+ len(self.video_ids)
+ )
+ local_rank = torch.distributed.get_rank() \
+ if torch.distributed.is_initialized() else 0
+
+ np.save(
+ os.path.join(
+ self.out_dir,
+ "hidden_state" + str(local_rank) + ".npy"),
+ hidden_states)
+
+ with open(
+ os.path.join(
+ self.out_dir,
+ "video_id" + str(local_rank) + ".pkl"),
+ "wb") as fw:
+ pickle.dump(
+ self.video_ids,
+ fw,
+ protocol=pickle.HIGHEST_PROTOCOL
+ )
+
+
+class DistributedVideoVectorPool(DistributedVectorPool):
+ """
+ average clips of a video as video representation.
+ """
+ def __call__(self, sample, subsampling, **kwargs):
+ hidden_states = (
+ sample["pooled_video"] + sample["pooled_text"]) / 2.
+ hidden_states = hidden_states.view(
+ -1, subsampling,
+ hidden_states.size(-1))
+ hidden_states = torch.mean(hidden_states, dim=1)
+ hidden_states = hidden_states.cpu().detach().numpy()
+ video_ids = []
+ for offset_idx, video_id in enumerate(sample["video_id"]):
+ if isinstance(video_id, tuple) and len(video_id) == 3:
+ # a sharded video_id.
+ video_id = video_id[0]
+ video_ids.append(video_id)
+ assert len(video_ids) == len(hidden_states)
+ self.hidden_states.append(hidden_states)
+ self.video_ids.extend(video_ids)
+
+
+# ------------ the following are deprecated --------------
+
+class TextClipVectorPool(VectorPool):
+ def __init__(self, config):
+ from transformers import AutoConfig
+ hidden_size = AutoConfig.from_pretrained(
+ config.dataset.bert_name).hidden_size
+ retriever_cls = getattr(retri, config.retriever_cls)
+ self.build_retriver(retriever_cls, hidden_size)
+
+ def __call__(self, sample, **kwargs):
+ clip_meta = sample["clip_meta"].cpu()
+ assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
+ text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]]
+
+ if hasattr(self, "retriver"):
+ # build_retriver is called.
+ self.retriver.add(
+ sample["pooled_text"].cpu().numpy().astype("float32"),
+ text_meta
+ )
+ else:
+ raise NotImplementedError
+
+
+class MMClipVectorPool(VectorPool):
+ """
+ Multimodal Clip-level vector pool.
+ """
+ def __init__(self, out_dir):
+ """use hidden_states to store `(video, text)`."""
+ """use video_ids to store `(video_id, start, end)`."""
+ super().__init__(out_dir)
+
+ def __call__(self, sample, **kwargs):
+ pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy()
+ pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy()
+
+ self.hidden_states.append(
+ np.concatenate([pooled_video, pooled_text], axis=1)
+ )
+
+ video_starts = sample["video_start"].cpu()
+ video_ends = sample["video_end"].cpu()
+ assert torch.all(torch.le(video_starts, video_ends))
+
+ text_starts = sample["text_start"].cpu()
+ text_ends = sample["text_end"].cpu()
+ assert torch.all(torch.le(text_starts, text_ends))
+ subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"])
+ video_ids = [video_id for video_id in sample["video_id"]
+ for _ in range(subsample_size)
+ ]
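+        # each video contributes `subsample_size` clips, so its id is repeated
+        # to line up one (video_id, clip span, text span) record per vector.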
+ for video_id, video_start, video_end, text_start, text_end in zip(
+ video_ids, video_starts, video_ends, text_starts, text_ends):
+ self.video_ids.append((
+ video_id,
+ (int(video_start), int(video_end)),
+ (int(text_start), int(text_end))
+ ))
diff --git a/examples/MMPT/mmpt/processors/__init__.py b/examples/MMPT/mmpt/processors/__init__.py
new file mode 100644
index 0000000000..434d1d92b9
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .processor import *
+
+from .how2processor import *
+from .how2retriprocessor import *
+
+from .dsprocessor import *
+
+try:
+ from .rawvideoprocessor import *
+ from .codecprocessor import *
+ from .webvidprocessor import *
+ from .expprocessor import *
+ from .exphow2processor import *
+ from .exphow2retriprocessor import *
+ from .expcodecprocessor import *
+ from .expfeatureencoder import *
+ from .expdsprocessor import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/processors/dedupprocessor.py b/examples/MMPT/mmpt/processors/dedupprocessor.py
new file mode 100644
index 0000000000..8a1ad402cd
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/dedupprocessor.py
@@ -0,0 +1,242 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import json
+import pickle
+from tqdm import tqdm
+import os
+import numpy as np
+
+
+class CaptionDedupProcessor(object):
+ """remove overlapping of caption sentences(clip).
+ Some statistics:
+ caption:
+ {'t_clip_len': 246.6448431320854,
+ 'video_len': 281.09174795676245,
+ 'clip_tps': 0.8841283727427481,
+ 'video_tps': 0.7821156477732097,
+ 'min_clip_len': 0.0,
+ 'max_clip_len': 398.3,
+ 'mean_clip_len': 3.196580003006861,
+ 'num_clip': 77.15897706301081}
+
+ raw_caption:
+ {'t_clip_len': 238.95908778424115,
+ 'video_len': 267.5914859862507,
+ 'clip_tps': 2.4941363624267963,
+ 'video_tps': 2.258989769647173,
+ 'min_clip_len': 0.0,
+ 'max_clip_len': 398.3,
+ 'mean_clip_len': 3.0537954186814265,
+ 'num_clip': 78.24986779481756}
+ """
+
+ def __init__(self, pkl_file):
+ with open(pkl_file, "rb") as fd:
+ self.data = pickle.load(fd)
+ self.stat = {
+ "t_clip_len": [],
+ "video_len": [],
+ "clip_tps": [],
+ "video_tps": [],
+ "clip_len": [],
+ }
+
+ def __call__(self):
+ for idx, video_id in enumerate(tqdm(self.data)):
+ caption = json.loads(self.data[video_id])
+ caption = self._dedup(caption)
+ if idx < 4096: # for the first 4096 examples, compute the statistics.
+ self.save_stat(video_id, caption)
+ self.data[video_id] = json.dumps(caption)
+ self.print_stat()
+
+ def single(self, video_id):
+ caption = json.loads(self.data[video_id])
+ for clip_idx, (start, end, text) in enumerate(
+ zip(caption["start"], caption["end"], caption["text"])
+ ):
+ print(start, end, text)
+ print("@" * 100)
+ caption = self._dedup(caption)
+ for clip_idx, (start, end, text) in enumerate(
+ zip(caption["start"], caption["end"], caption["text"])
+ ):
+ print(start, end, text)
+ print("#" * 100)
+ self.save_stat(video_id, caption)
+ self.print_stat()
+
+ def finalize(self, tgt_fn):
+ with open(tgt_fn, "wb") as fw:
+ pickle.dump(self.data, fw, pickle.HIGHEST_PROTOCOL)
+
+ def save_stat(self, video_id, caption):
+ video_fn = os.path.join(
+ "data/feat/feat_how2_s3d", video_id + ".npy"
+ )
+ if os.path.isfile(video_fn):
+ with open(video_fn, "rb", 1) as fr: # 24 is the buffer size. buffered
+ version = np.lib.format.read_magic(fr)
+ shape, fortran, dtype = np.lib.format._read_array_header(fr, version)
+ video_len = shape[0]
+
+ t_clip_len = 0.0
+ t_tokens = 0
+ for idx, (start, end, text) in enumerate(
+ zip(caption["start"], caption["end"], caption["text"])
+ ):
+ clip_len = (
+ (end - max(caption["end"][idx - 1], start))
+ if idx > 0
+ else end - start
+ )
+ t_clip_len += clip_len
+ t_tokens += len(text.split(" "))
+ self.stat["clip_len"].append(clip_len)
+ self.stat["t_clip_len"].append(t_clip_len)
+ self.stat["video_len"].append(video_len)
+ self.stat["clip_tps"].append(t_tokens / t_clip_len)
+ self.stat["video_tps"].append(t_tokens / video_len)
+
+ def print_stat(self):
+ result = {
+ "t_clip_len": np.mean(self.stat["t_clip_len"]),
+ "video_len": np.mean(self.stat["video_len"]),
+ "clip_tps": np.mean(self.stat["clip_tps"]),
+ "video_tps": np.mean(self.stat["video_tps"]),
+ "min_clip_len": min(self.stat["clip_len"]),
+ "max_clip_len": max(self.stat["clip_len"]),
+ "mean_clip_len": np.mean(self.stat["clip_len"]),
+ "num_clip": len(self.stat["clip_len"]) / len(self.stat["video_tps"]),
+ }
+ print(result)
+
+ def _dedup(self, caption):
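+        # Dedup sketch (illustrative): with clips ("how to make a", 0-3s) and
+        # ("make a cake", 2-5s), the shared span "make a" is randomly kept in
+        # the previous clip (leaving "cake" for the current one) or moved to
+        # the current clip (trimming the previous one down to "how to").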
+ def random_merge(end_idx, start, end, text, starts, ends, texts):
+ if random.random() > 0.5:
+ # print(clip_idx, "[PARTIAL INTO PREV]", end_idx)
+ # overlapped part goes to the end of previous.
+ ends[-1] = max(ends[-1], start) # ?
+ rest_text = text[end_idx:].strip()
+ if rest_text:
+ starts.append(max(ends[-1], start))
+ ends.append(max(end, starts[-1]))
+ texts.append(rest_text)
+ else: # goes to the beginning of the current.
+ # strip the previous.
+ left_text = texts[-1][:-end_idx].strip()
+ if left_text:
+ # print(clip_idx, "[PREV PARTIAL INTO CUR]", end_idx)
+ ends[-1] = min(ends[-1], start)
+ texts[-1] = left_text
+ else:
+ # print(clip_idx, "[PREV LEFT NOTHING ALL INTO CUR]", end_idx)
+ starts.pop(-1)
+ ends.pop(-1)
+ texts.pop(-1)
+ starts.append(start)
+ ends.append(end)
+ texts.append(text)
+
+ starts, ends, texts = [], [], []
+ for clip_idx, (start, end, text) in enumerate(
+ zip(caption["start"], caption["end"], caption["text"])
+ ):
+ if not isinstance(text, str):
+ continue
+ text = text.replace("\n", " ").strip()
+ if len(text) == 0:
+ continue
+ starts.append(start)
+ ends.append(end)
+ texts.append(text)
+ break
+
+ for clip_idx, (start, end, text) in enumerate(
+ zip(
+ caption["start"][clip_idx + 1:],
+ caption["end"][clip_idx + 1:],
+ caption["text"][clip_idx + 1:],
+ )
+ ):
+ if not isinstance(text, str):
+ continue
+ text = text.replace("\n", " ").strip()
+ if len(text) == 0:
+ continue
+
+ # print(clip_idx, texts[-5:])
+ # print(clip_idx, start, end, text)
+ if texts[-1].endswith(text): # subset of prev caption -> merge
+ # print(clip_idx, "[MERGE INTO PREV]")
+ ends[-1] = max(ends[-1], end)
+ elif text.startswith(texts[-1]): # superset of prev caption -> merge
+ # print(clip_idx, "[PREV MERGE INTO CUR]")
+ texts[-1] = text
+ starts[-1] = min(starts[-1], start)
+ ends[-1] = max(ends[-1], end)
+ else: # overlapping or non-overlapping.
+ for end_idx in range(1, len(text) + 1):
+ if texts[-1].endswith(text[:end_idx]):
+ random_merge(end_idx, start, end, text, starts, ends, texts)
+ break
+ else:
+ starts.append(start)
+ ends.append(end)
+ texts.append(text)
+
+ assert (ends[-1] + 0.001) >= starts[-1] and len(
+ texts[-1]
+ ) > 0, "{} {} {} <- {} {} {}, {} {} {}".format(
+ str(starts[-1]),
+ str(ends[-1]),
+ texts[-1],
+ caption["start"][clip_idx - 1],
+ caption["end"][clip_idx - 1],
+ caption["text"][clip_idx - 1],
+ str(start),
+ str(end),
+ text,
+ )
+
+ return {"start": starts, "end": ends, "text": texts}
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="dedup how2 caption")
+ parser.add_argument('--how2dir', default="data/how2")
+ args = parser.parse_args()
+
+ raw_caption_json = os.path.join(args.how2dir, "raw_caption.json")
+ raw_caption_pickle = os.path.join(args.how2dir, "raw_caption.pkl")
+ raw_caption_dedup_pickle = os.path.join(args.how2dir, "raw_caption_dedup.pkl")
+
+ def convert_to_pickle(src_fn, tgt_fn):
+ with open(src_fn) as fd:
+ captions = json.load(fd)
+
+ for video_id in captions:
+ captions[video_id] = json.dumps(captions[video_id])
+
+ with open(tgt_fn, "wb") as fw:
+ pickle.dump(captions, fw, pickle.HIGHEST_PROTOCOL)
+
+ if not os.path.isfile(raw_caption_pickle):
+ convert_to_pickle(raw_caption_json, raw_caption_pickle)
+
+ deduper = CaptionDedupProcessor(raw_caption_pickle)
+ deduper()
+ deduper.finalize(raw_caption_dedup_pickle)
+
+ """
+ # demo
+ deduper = CaptionDedupProcessor("data/how2/raw_caption.pkl")
+ deduper.single("HfIeQ9pzL5U")
+ """
diff --git a/examples/MMPT/mmpt/processors/dsprocessor.py b/examples/MMPT/mmpt/processors/dsprocessor.py
new file mode 100644
index 0000000000..ecebf0eea5
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/dsprocessor.py
@@ -0,0 +1,848 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+"""
+Processors for all downstream (ds) tasks.
+"""
+
+import json
+import os
+import pickle
+import random
+import math
+import numpy as np
+import torch
+
+from collections import defaultdict
+
+from .processor import (
+ MetaProcessor,
+ VideoProcessor,
+ TextProcessor,
+ Aligner,
+ MMAttentionMask2DProcessor,
+)
+
+from .how2processor import TextGenerationProcessor
+
+
+# ------------- A General Aligner for all downstream tasks-----------------
+
+
+class DSAligner(Aligner):
+ """
+ Downstream (DS) aligner shared by all datasets.
+ """
+
+ def __call__(self, video_id, video_feature, text_feature, wps=0.7):
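+        # `wps` (presumably words per second) is only used to synthesize an
+        # end timestamp for the caption clip from its token count.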
+        # use the whole video starting from sec 0 (no random start).
+ video_start = 0
+ video_end = min(len(video_feature), self.max_video_len)
+ # the whole sequence is a single clip.
+ video_clips = {"start": [video_start], "end": [video_end]}
+
+ text_feature = {
+ "cap": [text_feature],
+ "start": [video_start],
+ "end": [len(text_feature) / wps],
+ }
+ text_clip_indexs = [0]
+
+ vfeats, vmasks = self._build_video_seq(
+ video_feature, video_clips
+ )
+ caps, cmasks = self._build_text_seq(
+ text_feature, text_clip_indexs
+ )
+
+ return {
+ "caps": caps,
+ "cmasks": cmasks,
+ "vfeats": vfeats,
+ "vmasks": vmasks,
+ "video_id": video_id,
+ }
+
+
+class NLGTextProcessor(TextProcessor):
+ """
+ Also return the original text as ref.
+ """
+ def __call__(self, text_id):
+ return super().__call__(text_id), text_id
+
+
+class DSNLGAligner(DSAligner):
+ """extend with the capability of 2d mask for generation."""
+ def __init__(self, config):
+ super().__init__(config)
+ self.attnmasker = MMAttentionMask2DProcessor()
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.bert_name, use_fast=self.use_fast,
+ bos_token="[CLS]", eos_token="[SEP]"
+ )
+ self.tokenizer = tokenizer
+ self.bos_token_id = tokenizer.bos_token_id
+ self.eos_token_id = tokenizer.eos_token_id
+ self.textgen = TextGenerationProcessor(tokenizer)
+
+ def __call__(self, video_id, video_feature, text_feature):
+ output = super().__call__(video_id, video_feature, text_feature[0])
+ if self.split == "test":
+ # output.update({"ref": text_feature[1]})
+ output.update({"ref": self.tokenizer.decode(
+ output["caps"], skip_special_tokens=True)})
+ text_label = output["caps"]
+ cmasks = torch.BoolTensor([1] * text_label.size(0))
+ caps = torch.LongTensor([
+ self.cls_token_id,
+ self.sep_token_id,
+ self.bos_token_id])
+ else:
+ caps, text_label = self.textgen(output["caps"])
+ cmasks = output["cmasks"]
+
+ attention_mask = self.attnmasker(
+ output["vmasks"], cmasks, "textgen")
+
+ output.update({
+ "caps": caps,
+ "cmasks": cmasks,
+ "text_label": text_label,
+ "attention_mask": attention_mask,
+ })
+ return output
+
+
+# -------------------- MSRVTT ------------------------
+
+
+class MSRVTTMetaProcessor(MetaProcessor):
+ """MSRVTT dataset.
+ reference: `howto100m/msrvtt_dataloader.py`
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ import pandas as pd
+ data = pd.read_csv(self._get_split_path(config))
+ # TODO: add a text1ka flag.
+ if config.split == "train" \
+ and config.full_test_path is not None \
+ and config.jsfusion_path is not None:
+            # add test videos from full_test_path that are not used by jsfusion.
+ additional_data = pd.read_csv(config.full_test_path)
+ jsfusion_data = pd.read_csv(config.jsfusion_path)
+
+ for video_id in additional_data["video_id"]:
+ if video_id not in jsfusion_data["video_id"].values:
+ data = data.append(
+ {"video_id": video_id}, ignore_index=True)
+
+ if config.dup is not None and config.split == "train":
+ data = data.append([data] * (config.dup - 1), ignore_index=True)
+ self.data = data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ """slightly modify with if condition to combine train/test."""
+ vid, sentence = None, None
+ vid = self.data["video_id"].values[idx]
+ if "sentence" in self.data: # for testing.
+ sentence = self.data["sentence"].values[idx]
+ else: # for training.
+ sentence = vid
+ return vid, sentence
+
+
+class MSRVTTTextProcessor(TextProcessor):
+ """MSRVTT dataset.
+ reference: `msrvtt_dataloader.py` `MSRVTT_TrainDataLoader`.
+ TODO (huxu): add max_words.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.sentences = None
+ if config.json_path is not None and config.split == "train":
+ with open(config.json_path) as fd:
+ self.data = json.load(fd)
+ self.sentences = defaultdict(list)
+ for s in self.data["sentences"]:
+ self.sentences[s["video_id"]].append(s["caption"])
+
+ def __call__(self, text_id):
+ if self.sentences is not None:
+ rind = random.randint(0, len(self.sentences[text_id]) - 1)
+ sentence = self.sentences[text_id][rind]
+ else:
+ sentence = text_id
+ caption = self.tokenizer(sentence, add_special_tokens=False)
+ return caption["input_ids"]
+
+
+class MSRVTTNLGTextProcessor(MSRVTTTextProcessor):
+ """TODO: change dsaligner and merge to avoid any NLG text processor."""
+ def __call__(self, text_id):
+ if self.sentences is not None:
+ rind = random.randint(0, len(self.sentences[text_id]) - 1)
+ sentence = self.sentences[text_id][rind]
+ else:
+ sentence = text_id
+ caption = self.tokenizer(sentence, add_special_tokens=False)
+ return caption["input_ids"], sentence
+
+
+class MSRVTTQAMetaProcessor(MetaProcessor):
+ """MSRVTT-QA: retrieval-based multi-choice QA from JSFusion dataset.
+ For simplicity, we use the train retrieval model.
+ reference: `https://github.com/yj-yu/lsmdc`
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ import pandas as pd
+ csv_data = pd.read_csv(self._get_split_path(config), sep="\t")
+ data = []
+ for video_id, a1, a2, a3, a4, a5, answer in zip(
+ csv_data["vid_key"].values,
+ csv_data["a1"].values,
+ csv_data["a2"].values,
+ csv_data["a3"].values,
+ csv_data["a4"].values,
+ csv_data["a5"].values,
+ csv_data["answer"].values):
+ video_id = video_id.replace("msr", "video")
+ data.append((video_id, (answer, [a1, a2, a3, a4, a5])))
+ self.data = data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+class MSRVTTQATextProcessor(TextProcessor):
+ """MSRVTT-QA dataset.
+ text_ans is of format `(answer, [a1, a2, a3, a4, a5])`.
+ """
+
+ def __call__(self, text_ans):
+ for ans_idx, ans in enumerate(text_ans[1]):
+ if isinstance(ans, str):
+ text_ans[1][ans_idx] = self.tokenizer(ans, add_special_tokens=False)["input_ids"]
+ return text_ans
+
+
+class MSRVTTQAAligner(DSAligner):
+ """MSRVTT dataset.
+    similar to sampling in how2:
+    the parent __call__ is invoked once per answer candidate.
+ """
+
+ def __call__(self, video_id, video_feature, text_feature, wps=0.7):
+ caps = []
+ cmasks = []
+ answer = text_feature[0]
+ for ans_idx, _text_feature in enumerate(text_feature[1]):
+ output = super().__call__(
+ video_id, video_feature, _text_feature, wps)
+ caps.append(output["caps"])
+ cmasks.append(output["cmasks"])
+ output.update({
+ "caps": torch.stack(caps),
+ "cmasks": torch.stack(cmasks),
+ "answers": torch.LongTensor([answer]),
+ })
+ return output
+
+
+# -------------------- Youcook -----------------------
+
+
+class YoucookMetaProcessor(MetaProcessor):
+ """Youcook dataset.
+ reference: `howto100m/youcook_dataloader.py`
+    note that the data can differ from the original because
+    (1) some videos already in HowTo100M are removed, and
+    (2) stop words are removed from the captions.
+ TODO (huxu): make a flag to load the original caption.
+ (see youcookii_annotations_trainval.json).
+
+    The max_video_len can be 264 and the text can be 64 tokens.
+    In practice we may not need sequences that long; see projects/task/youcook.yaml.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ vfeat_dir = config.vfeat_dir
+ print(self._get_split_path(config))
+ with open(self._get_split_path(config), "rb") as fd:
+ data = pickle.load(fd)
+ all_valid_video_ids = set(
+ [os.path.splitext(fn)[0] for fn in os.listdir(vfeat_dir)]
+ )
+ recs = []
+ video_ids = set()
+ valid_video_ids = set()
+ for rec in data: # filter videos not available.
+ udl_idx = rec["id"].rindex("_")
+ video_id = rec["id"][:udl_idx]
+ video_ids.add(video_id)
+ if video_id in all_valid_video_ids:
+ valid_video_ids.add(video_id)
+ recs.append(rec)
+ print("total video_ids in .pkl", len(video_ids))
+ print("valid video_ids in .pkl", len(valid_video_ids))
+ print("please verify {train,val}_list.txt")
+ data = recs
+ self.data = data
+
+ with open(config.trainval_annotation) as fd:
+ self.youcook_annotation = json.load(fd)["database"]
+ if config.use_annotation_text is True:
+ print("using text in annotation.")
+ self.use_annotation_caption = True
+ else:
+ self.use_annotation_caption = False
+
+ def __getitem__(self, idx):
+ def _get_video_and_caption(rec):
+ vid = rec["id"]
+ udl_idx = vid.rindex("_")
+ video_id, clip_id = vid[:udl_idx], int(vid[udl_idx + 1:])
+ clip = self.youcook_annotation[video_id]["annotations"][clip_id]
+ start, end = clip["segment"]
+ if self.use_annotation_caption:
+ caption = clip["sentence"]
+ else:
+ caption = rec["caption"]
+ return (video_id, start, end), caption
+
+ rec = self.data[idx]
+ video_info, text_info = _get_video_and_caption(rec)
+ return video_info, text_info
+
+
+class YoucookVideoProcessor(VideoProcessor):
+ """video_fn is a tuple of (video_id, start, end) now."""
+
+ def __call__(self, video_fn):
+ video_id, start, end = video_fn
+ feat = np.load(os.path.join(self.vfeat_dir, video_id + ".npy"))
+ return feat[start:end]
+
+
+class YoucookNLGMetaProcessor(MetaProcessor):
+ """NLG uses the original split:
+ `train_list.txt` and `val_list.txt`
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ vfeat_dir = config.vfeat_dir
+ print(self._get_split_path(config))
+ with open(self._get_split_path(config)) as fd:
+ video_ids = [
+ line.strip().split("/")[1] for line in fd.readlines()]
+ print("total video_ids in train/val_list.txt", len(video_ids))
+
+ all_valid_video_ids = set(
+ [os.path.splitext(fn)[0] for fn in os.listdir(vfeat_dir)]
+ )
+ video_ids = [
+ video_id for video_id in video_ids
+ if video_id in all_valid_video_ids]
+
+ print("valid video_ids in train/val_list.txt", len(video_ids))
+ with open(config.trainval_annotation) as fd:
+ self.youcook_annotation = json.load(fd)["database"]
+
+ data = []
+ for video_id in video_ids:
+ for clip in self.youcook_annotation[video_id]["annotations"]:
+ start, end = clip["segment"]
+ caption = clip["sentence"]
+ data.append(((video_id, start, end), caption))
+ self.data = data
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+# --------------------- CrossTask -------------------------
+
+class CrossTaskMetaProcessor(MetaProcessor):
+ def __init__(self, config):
+ super().__init__(config)
+ np.random.seed(0) # deterministic random split.
+ task_vids = self._get_vids(
+ config.train_csv_path,
+ config.vfeat_dir,
+ config.annotation_path)
+
+ val_vids = self._get_vids(
+ config.val_csv_path,
+ config.vfeat_dir,
+ config.annotation_path)
+
+        # filter out tasks and vids that appear in val_vids.
+ task_vids = {
+ task: [
+ vid for vid in vids
+ if task not in val_vids or vid not in val_vids[task]]
+ for task, vids in task_vids.items()}
+
+ primary_info = self._read_task_info(config.primary_path)
+ test_tasks = set(primary_info['steps'].keys())
+
+ # if args.use_related:
+ related_info = self._read_task_info(config.related_path)
+ task_steps = {**primary_info['steps'], **related_info['steps']}
+ n_steps = {**primary_info['n_steps'], **related_info['n_steps']}
+ # else:
+ # task_steps = primary_info['steps']
+ # n_steps = primary_info['n_steps']
+ all_tasks = set(n_steps.keys())
+ # filter and keep task in primary or related.
+ task_vids = {
+ task: vids for task, vids in task_vids.items()
+ if task in all_tasks}
+ # vocab-by-step matrix (A) and vocab (M)
+ # (huxu): we do not use BoW.
+ # A, M = self._get_A(task_steps, share="words")
+
+ train_vids, test_vids = self._random_split(
+ task_vids, test_tasks, config.n_train)
+ print("train_num_videos", sum(len(vids) for vids in train_vids.values()))
+ print("test_num_videos", sum(len(vids) for vids in test_vids.values()))
+ # added by huxu to automatically determine the split.
+ split_map = {
+ "train": train_vids,
+ "valid": test_vids,
+ "test": test_vids
+ }
+ task_vids = split_map[config.split]
+
+ self.vids = []
+ for task, vids in task_vids.items():
+ self.vids.extend([(task, vid) for vid in vids])
+ self.task_steps = task_steps
+ self.n_steps = n_steps
+
+ def __getitem__(self, idx):
+ task, vid = self.vids[idx]
+ n_steps = self.n_steps[task]
+ steps = self.task_steps[task]
+ assert len(steps) == n_steps
+ return (task, vid, steps, n_steps), (task, vid, steps, n_steps)
+
+ def __len__(self):
+ return len(self.vids)
+
+ def _random_split(self, task_vids, test_tasks, n_train):
+ train_vids = {}
+ test_vids = {}
+ for task, vids in task_vids.items():
+ if task in test_tasks and len(vids) > n_train:
+ train_vids[task] = np.random.choice(
+ vids, n_train, replace=False).tolist()
+ test_vids[task] = [
+ vid for vid in vids if vid not in train_vids[task]]
+ else:
+ train_vids[task] = vids
+ return train_vids, test_vids
+
+ def _get_vids(self, path, vfeat_dir, annotation_path):
+ """refactored from
+ https://github.com/DmZhukov/CrossTask/blob/master/data.py
+ changes: add `vfeat_dir` to check if the video is available.
+        add `annotation_path` to check if the annotation is available.
+ """
+
+ task_vids = {}
+ with open(path, 'r') as f:
+ for line in f:
+ task, vid, url = line.strip().split(',')
+ # double check the video is available.
+ if not os.path.exists(
+ os.path.join(vfeat_dir, vid + ".npy")):
+ continue
+ # double check the annotation is available.
+ if not os.path.exists(os.path.join(
+ annotation_path,
+ task + "_" + vid + ".csv")):
+ continue
+ if task not in task_vids:
+ task_vids[task] = []
+ task_vids[task].append(vid)
+ return task_vids
+
+ def _read_task_info(self, path):
+ titles = {}
+ urls = {}
+ n_steps = {}
+ steps = {}
+ with open(path, 'r') as f:
+ idx = f.readline()
+ while idx != '':
+ idx = idx.strip()
+ titles[idx] = f.readline().strip()
+ urls[idx] = f.readline().strip()
+ n_steps[idx] = int(f.readline().strip())
+ steps[idx] = f.readline().strip().split(',')
+ next(f)
+ idx = f.readline()
+ return {
+ 'title': titles,
+ 'url': urls,
+ 'n_steps': n_steps,
+ 'steps': steps
+ }
+
+ def _get_A(self, task_steps, share="words"):
+ raise ValueError("running get_A is not allowed for BERT.")
+ """Step-to-component matrices."""
+ if share == 'words':
+ # share words
+ task_step_comps = {
+ task: [step.split(' ') for step in steps]
+ for task, steps in task_steps.items()}
+ elif share == 'task_words':
+ # share words within same task
+ task_step_comps = {
+ task: [[task+'_'+tok for tok in step.split(' ')] for step in steps]
+ for task, steps in task_steps.items()}
+ elif share == 'steps':
+ # share whole step descriptions
+ task_step_comps = {
+ task: [[step] for step in steps] for task, steps in task_steps.items()}
+ else:
+ # no sharing
+ task_step_comps = {
+ task: [[task+'_'+step] for step in steps]
+ for task, steps in task_steps.items()}
+ # BERT tokenizer here?
+ vocab = []
+ for task, steps in task_step_comps.items():
+ for step in steps:
+ vocab.extend(step)
+ vocab = {comp: m for m, comp in enumerate(set(vocab))}
+ M = len(vocab)
+ A = {}
+ for task, steps in task_step_comps.items():
+ K = len(steps)
+ a = torch.zeros(M, K)
+ for k, step in enumerate(steps):
+ a[[vocab[comp] for comp in step], k] = 1
+ a /= a.sum(dim=0)
+ A[task] = a
+ return A, M
+
+
+class CrossTaskVideoProcessor(VideoProcessor):
+ def __call__(self, video_fn):
+ task, vid, steps, n_steps = video_fn
+ video_fn = os.path.join(self.vfeat_dir, vid + ".npy")
+ feat = np.load(video_fn)
+ return feat
+
+
+class CrossTaskTextProcessor(TextProcessor):
+ def __call__(self, text_id):
+ task, vid, steps, n_steps = text_id
+ step_ids = []
+ for step_str in steps:
+ step_ids.append(
+ self.tokenizer(step_str, add_special_tokens=False)["input_ids"]
+ )
+ return step_ids
+
+
+class CrossTaskAligner(Aligner):
+ """
+    TODO: the formulation of the task is not yet clear; finish this later.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ self.annotation_path = config.annotation_path
+ self.sliding_window = config.sliding_window
+ self.sliding_window_size = config.sliding_window_size
+
+ def __call__(self, video_id, video_feature, text_feature):
+ task, vid, steps, n_steps = video_id
+ annot_path = os.path.join(
+ self.annotation_path, task + '_' + vid + '.csv')
+ video_len = len(video_feature)
+
+ labels = torch.from_numpy(self._read_assignment(
+ video_len, n_steps, annot_path)).float()
+
+ vfeats, vmasks, targets = [], [], []
+ # sliding window on video features and targets.
+ for window_start in range(0, video_len, self.sliding_window):
+ video_start = 0
+ video_end = min(video_len - window_start, self.sliding_window_size)
+ video_clip = {"start": [video_start], "end": [video_end]}
+
+ vfeat, vmask = self._build_video_seq(
+ video_feature[window_start: window_start + video_end],
+ video_clip
+ )
+
+ target = labels[window_start: window_start + video_end]
+ assert len(vfeat) >= len(target), "{},{}".format(len(vfeat), len(target))
+ # TODO: randomly drop all zero targets for training ?
+ # if self.split == "train" and target.sum() == 0:
+ # continue
+ vfeats.append(vfeat)
+ vmasks.append(vmask)
+ targets.append(target)
+
+ if (video_len - window_start) <= self.sliding_window_size:
+ break
+
+ vfeats = torch.stack(vfeats)
+ vmasks = torch.stack(vmasks)
+ targets = torch.cat(targets, dim=0)
+
+ caps, cmasks = [], []
+ for step in text_feature:
+ step_text_feature = {"start": [0], "end": [1], "cap": [step]}
+ step_text_clip_index = [0]
+ cap, cmask = self._build_text_seq(
+ step_text_feature, step_text_clip_index
+ )
+ caps.append(cap)
+ cmasks.append(cmask)
+ caps = torch.stack(caps)
+ cmasks = torch.stack(cmasks)
+
+ return {
+ "caps": caps,
+ "cmasks": cmasks,
+ "vfeats": vfeats, # X for original code.
+ "vmasks": vmasks,
+ "targets": targets,
+ "video_id": vid,
+ "task": task,
+ "video_len": video_len # for later checking.
+ }
+
+ def _read_assignment(self, T, K, path):
+ """
+ refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
+        How to interpret the constraints on the loss that is being minimized:
+        lambd is a big number;
+        self.lambd * C is a big number for every valid position (the csv stores the invalid ones).
+
+ def forward(self, O, Y, C):
+ return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
+
+        This loads the csv file and fills in the step column from the start row to the end row.
+ """
+
+ Y = np.zeros([T, K], dtype=np.uint8)
+ with open(path, 'r') as f:
+ for line in f:
+ step, start, end = line.strip().split(',')
+ start = int(math.floor(float(start)))
+ end = int(math.ceil(float(end)))
+ step = int(step) - 1
+ Y[start:end, step] = 1
+ return Y
+
+
+# --------------------- COIN -------------------------
+
+class MetaTextBinarizer(Aligner):
+ def __call__(self, text_feature):
+ text_feature = {
+ "cap": [text_feature],
+ "start": [0.],
+ "end": [100.],
+ }
+ text_clip_indexs = [0]
+
+ caps, cmasks = self._build_text_seq(
+ text_feature, text_clip_indexs
+ )
+ return {"caps": caps, "cmasks": cmasks}
+
+
+class COINActionSegmentationMetaProcessor(MetaProcessor):
+ split_map = {
+ "train": "training",
+ "valid": "testing",
+ "test": "testing",
+ }
+
+ def __init__(self, config):
+ super().__init__(config)
+ with open(self._get_split_path(config)) as fr:
+ database = json.load(fr)["database"]
+ id2label = {}
+ data = []
+ # filter the data by split.
+ for video_id, rec in database.items():
+ # always use testing to determine label_set
+ if rec["subset"] == "testing":
+ for segment in rec["annotation"]:
+ id2label[int(segment["id"])] = segment["label"]
+ # text_labels is used for ZS setting
+ self.text_labels = ["none"] * len(id2label)
+ for label_id in id2label:
+ self.text_labels[label_id-1] = id2label[label_id]
+
+ id2label[0] = "O"
+ print("num of labels", len(id2label))
+
+ for video_id, rec in database.items():
+ if not os.path.isfile(os.path.join(config.vfeat_dir, video_id + ".npy")):
+ continue
+ if rec["subset"] == COINActionSegmentationMetaProcessor.split_map[self.split]:
+ starts, ends, labels = [], [], []
+ for segment in rec["annotation"]:
+ start, end = segment["segment"]
+ label = int(segment["id"])
+ starts.append(start)
+ ends.append(end)
+ labels.append(label)
+ data.append(
+ (video_id, {"start": starts, "end": ends, "label": labels}))
+ self.data = data
+
+ def meta_text_labels(self, config):
+ from transformers import default_data_collator
+ from ..utils import get_local_rank
+
+ text_processor = TextProcessor(config)
+ binarizer = MetaTextBinarizer(config)
+ # TODO: add prompts to .yaml.
+ text_labels = [label for label in self.text_labels]
+
+ if get_local_rank() == 0:
+ print(text_labels)
+
+ outputs = []
+ for text_label in text_labels:
+ text_feature = text_processor(text_label)
+ outputs.append(binarizer(text_feature))
+ return default_data_collator(outputs)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+class COINActionSegmentationTextProcessor(TextProcessor):
+ def __call__(self, text_label):
+ return text_label
+
+
+class COINActionSegmentationAligner(Aligner):
+ def __init__(self, config):
+ super().__init__(config)
+ self.sliding_window = config.sliding_window
+ self.sliding_window_size = config.sliding_window_size
+
+ def __call__(self, video_id, video_feature, text_feature):
+ starts, ends, label_ids = text_feature["start"], text_feature["end"], text_feature["label"]
+ # sliding window.
+ video_len = len(video_feature)
+
+ vfeats, vmasks, targets = [], [], []
+ # sliding window on video features and targets.
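+        # e.g., with hypothetical config values sliding_window=16 and
+        # sliding_window_size=32, windows start at offsets 0, 16, 32, ... and
+        # each window covers at most 32 video feature frames.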
+ for window_start in range(0, video_len, self.sliding_window):
+ video_start = 0
+ video_end = min(video_len - window_start, self.sliding_window_size)
+ video_clip = {"start": [video_start], "end": [video_end]}
+ vfeat, vmask = self._build_video_seq(
+ video_feature[window_start: window_start + video_end],
+ video_clip
+ )
+ # covers video length only.
+ target = torch.full_like(vmask, -100, dtype=torch.long)
+ target[vmask] = 0
+ for start, end, label_id in zip(starts, ends, label_ids):
+ if (window_start < end) and (start < (window_start + video_end)):
+ start_offset = max(0, math.floor(start) - window_start)
+ end_offset = min(video_end, math.ceil(end) - window_start)
+ target[start_offset:end_offset] = label_id
+ vfeats.append(vfeat)
+ vmasks.append(vmask)
+ targets.append(target)
+ if (video_len - window_start) <= self.sliding_window_size:
+ break
+
+ vfeats = torch.stack(vfeats)
+ vmasks = torch.stack(vmasks)
+ targets = torch.stack(targets)
+ video_targets = torch.full((video_len,), 0)
+ for start, end, label_id in zip(starts, ends, label_ids):
+ start_offset = max(0, math.floor(start))
+ end_offset = min(video_len, math.ceil(end))
+ video_targets[start_offset:end_offset] = label_id
+
+ caps = torch.LongTensor(
+ [[self.cls_token_id, self.sep_token_id,
+ self.pad_token_id, self.sep_token_id]],
+ ).repeat(vfeats.size(0), 1)
+ cmasks = torch.BoolTensor(
+ [[0, 1, 0, 1]] # pad are valid for attention.
+ ).repeat(vfeats.size(0), 1)
+ return {
+ "caps": caps,
+ "cmasks": cmasks,
+ "vfeats": vfeats, # X for original code.
+ "vmasks": vmasks,
+ "targets": targets,
+ "video_id": video_id,
+ "video_len": video_len, # for later checking.
+ "video_targets": video_targets
+ }
+
+
+class DiDeMoMetaProcessor(MetaProcessor):
+ """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+ https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
+ """
+ def __init__(self, config):
+ super().__init__(config)
+
+ assert "test" in self._get_split_path(config), "DiDeMo only supports zero-shot testing for now."
+
+ with open(self._get_split_path(config)) as data_file:
+ json_data = json.load(data_file)
+
+ data = []
+ for record in json_data:
+ data.append((record["video"], record["description"]))
+ self.data = data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+class DiDeMoTextProcessor(TextProcessor):
+ """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+ https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
+ """
+
+ def __call__(self, text):
+ return self.tokenizer(text, add_special_tokens=False)["input_ids"]
+
+
+class DiDeMoAligner(DSAligner):
+ """
+ check video length.
+ """
+
+ def __call__(self, video_id, video_feature, text_feature):
+ # print(video_feature.shape[0])
+ return super().__call__(video_id, video_feature, text_feature)
diff --git a/examples/MMPT/mmpt/processors/how2processor.py b/examples/MMPT/mmpt/processors/how2processor.py
new file mode 100644
index 0000000000..bed2168b1d
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/how2processor.py
@@ -0,0 +1,887 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+import math
+import pickle
+import random
+import os
+import numpy as np
+
+from collections import deque
+from typing import Optional, Tuple, List
+from .processor import (
+ Processor,
+ MetaProcessor,
+ TextProcessor,
+ Aligner,
+ MMAttentionMask2DProcessor
+)
+
+from ..utils import ShardedTensor
+
+
+class How2MetaProcessor(MetaProcessor):
+ def __init__(self, config):
+ super().__init__(config)
+ path = self._get_split_path(config)
+ with open(path) as fd:
+ self.data = [line.strip() for line in fd]
+
+ def __getitem__(self, idx):
+ video_id = self.data[idx]
+ return video_id, video_id
+
+
+class ShardedHow2MetaProcessor(How2MetaProcessor):
+ def __init__(self, config):
+ super().__init__(config)
+ self.split = str(config.split)
+ self.vfeat_dir = config.vfeat_dir
+ self._init_shard()
+
+ def _init_shard(self):
+ if self.split == "train":
+ meta_fn = os.path.join(self.vfeat_dir, "train" + "_meta.pkl")
+ with open(meta_fn, "rb") as fr:
+ meta = pickle.load(fr)
+ elif self.split == "valid":
+ meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
+ with open(meta_fn, "rb") as fr:
+ meta = pickle.load(fr)
+ elif self.split == "test":
+ print("use how2 val as test.")
+ meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
+ with open(meta_fn, "rb") as fr:
+ meta = pickle.load(fr)
+ else:
+ raise ValueError("unsupported for MetaProcessor:", self.split)
+ video_id_to_shard = {}
+ for shard_id in meta:
+ for video_idx, video_id in enumerate(meta[shard_id]):
+ video_id_to_shard[video_id] = (shard_id, video_idx)
+ self.video_id_to_shard = video_id_to_shard
+
+ def __getitem__(self, idx):
+ video_id, video_id = super().__getitem__(idx)
+ shard_id, shard_idx = self.video_id_to_shard[video_id]
+ meta = (video_id, idx, shard_id, shard_idx)
+ return meta, meta
+
+
+class ShardedVideoProcessor(Processor):
+ """
+    memory-mapped shards of numpy video features.
+ """
+
+ def __init__(self, config):
+ self.split = str(config.split)
+ self.vfeat_dir = config.vfeat_dir
+
+ def __call__(self, video_id):
+ _, _, shard_id, video_idx = video_id
+ if self.split == "train":
+ shard = ShardedTensor.load(
+ os.path.join(self.vfeat_dir, "train" + "_" + str(shard_id)),
+ "r"
+ )
+ elif self.split == "valid":
+ shard = ShardedTensor.load(
+ os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
+ "r"
+ )
+ elif self.split == "test":
+ shard = ShardedTensor.load(
+ os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
+ "r"
+ )
+ else:
+ raise ValueError("unknown split", self.split)
+ feat = shard[video_idx]
+ return feat
+
+
+class ShardedTextProcessor(Processor):
+ def __init__(self, config):
+ self.tfeat_dir = str(config.tfeat_dir)
+ self.split = str(config.split)
+
+ def __call__(self, video_id):
+ _, _, shard_id, shard_idx = video_id
+ if self.split == "train":
+ target_path = self.tfeat_dir + "train" + "_" + str(shard_id)
+ elif self.split == "valid":
+ target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
+ elif self.split == "test":
+ target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
+ else:
+ raise ValueError("unknown split", self.split)
+
+ startend = ShardedTensor.load(
+ target_path + ".startends", "r")[shard_idx]
+ cap_ids = ShardedTensor.load(
+ target_path + ".caps_ids", "r")[shard_idx]
+ cap = []
+ for clip_idx in range(len(cap_ids)):
+ clip = cap_ids[clip_idx]
+ cap.append(clip[clip != -1].tolist())
+ start, end = startend[:, 0].tolist(), startend[:, 1].tolist()
+ return {"start": start, "end": end, "cap": cap}
+
+
+class FixedLenAligner(Aligner):
+ """
+ In the model we assume text is on the left (closer to BERT formulation)
+ and video is on the right.
+ We fix the total length of text + video.
+ max_video_len is in number of secs.
+ max_text_len is in number of tokens.
+
+    special token format:
+    we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
+    [CLS] will be split out into:
+    [CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
+ token_type_ids will be generated by the model (for now).
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ so each sequence owns a [SEP] token for no-ops.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.text_clip_sampler = TextClipSamplingProcessor(
+ self.max_len - self.max_video_len - 3
+ )
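+        # e.g., with hypothetical config values max_len=128 and max_video_len=32,
+        # the text token budget is 128 - 32 - 3 = 93; the 3 slots are taken by
+        # [CLS] and the two [SEP] tokens described in the class docstring.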
+ """
+ decide subsampling:
+ `config.subsampling` will change batch_size in trainer.
+ `config.clip_per_video` (used by RetriTask) doesn't
+ change batch_size in trainer.
+ """
+ subsampling = config.subsampling \
+ if config.subsampling is not None else None
+ if config.clip_per_video is not None:
+ subsampling = config.clip_per_video
+ self.subsampling = subsampling
+
+ def _get_text_maxlen(self):
+ # use max text len
+ return self.text_clip_sampler.max_text_len
+
+ def __call__(self, video_id, video_feature, text_feature):
+ from transformers import default_data_collator
+ video_idx = video_id[1]
+ if self.subsampling is not None and self.subsampling >= 1:
+ batch = []
+ for _ in range(self.subsampling):
+ centerclip_idx = random.randint(
+ 0, len(text_feature["start"]) - 1)
+ batch.append(
+ self.sampling(
+ video_idx,
+ video_feature,
+ text_feature,
+ centerclip_idx,
+ self._get_text_maxlen()
+ ))
+ batch = self.batch_post_processing(batch, video_feature)
+ batch = default_data_collator(batch)
+ else:
+ raise ValueError(
+ "dataset.subsampling must be >= 1 for efficient video loading.")
+
+ batch["video_id"] = video_id if isinstance(video_id, str) \
+ else video_id[0]
+        # e2e: make sure the frame features are already a tensor.
+ assert torch.is_tensor(batch["vfeats"])
+ return batch
+
+ def sampling(
+ self,
+ video_idx,
+ video_feature,
+ text_feature,
+ centerclip_idx=None,
+ sampled_max_text_len=None,
+ ):
+ text_clip_indexs = self.text_clip_sampler(
+ text_feature, centerclip_idx,
+ sampled_max_text_len
+ )
+ if isinstance(video_feature, np.ndarray):
+ video_len = len(video_feature)
+ else:
+ video_len = math.ceil(text_feature["end"][-1])
+
+ video_end = min(
+ math.ceil(text_feature["end"][text_clip_indexs[-1]]),
+ video_len
+ )
+ video_start = max(
+ min(
+ math.floor(text_feature["start"][text_clip_indexs[0]]),
+ video_end),
+ 0
+ )
+
+ video_clips = {"start": [video_start], "end": [video_end]}
+
+ # tensorize.
+ vfeats, vmasks = self._build_video_seq(
+ video_feature, video_clips
+ )
+ caps, cmasks = self._build_text_seq(
+ text_feature, text_clip_indexs
+ )
+
+ text_start = text_clip_indexs[0]
+ text_end = text_clip_indexs[-1] + 1
+
+ return {
+ "caps": caps,
+ "cmasks": cmasks,
+ "vfeats": vfeats,
+ "vmasks": vmasks,
+ "video_start": video_start,
+ "video_end": video_end,
+ "text_start": text_start,
+ "text_end": text_end,
+ }
+
+
+class VariedLenAligner(FixedLenAligner):
+ def __init__(self, config):
+ super().__init__(config)
+ self.sampled_min_len = config.sampled_min_len
+ self.sampled_max_len = config.sampled_max_len
+
+ def _get_text_maxlen(self):
+ return random.randint(self.sampled_min_len, self.sampled_max_len)
+
+
+class StartClipAligner(VariedLenAligner):
+ def sampling(
+ self,
+ video_idx,
+ video_feature,
+ text_feature,
+ centerclip_idx=None,
+ sampled_max_text_len=None,
+ ):
+ return super().sampling(
+ video_idx, video_feature, text_feature, 0)
+
+
+class OverlappedAligner(VariedLenAligner):
+ """video clip and text clip has overlappings
+ but may not be the same start/end."""
+ def __init__(self, config):
+ super().__init__(config)
+ self.sampled_video_min_len = config.sampled_video_min_len
+ self.sampled_video_max_len = config.sampled_video_max_len
+
+ self.video_clip_sampler = VideoClipSamplingProcessor()
+
+ def _get_video_maxlen(self):
+ return random.randint(
+ self.sampled_video_min_len, self.sampled_video_max_len)
+
+ def sampling(
+ self,
+ video_idx,
+ video_feature,
+ text_feature,
+ centerclip_idx=None,
+ sampled_max_text_len=None,
+ ):
+ text_clip_indexs = self.text_clip_sampler(
+ text_feature, centerclip_idx,
+ sampled_max_text_len
+ )
+ if isinstance(video_feature, np.ndarray):
+ video_len = len(video_feature)
+ else:
+ video_len = math.ceil(text_feature["end"][-1])
+ low = math.floor(text_feature["start"][text_clip_indexs[0]])
+ high = math.ceil(text_feature["end"][text_clip_indexs[-1]])
+ if low < high:
+ center = random.randint(low, high)
+ else:
+ center = int((low + high) // 2)
+ center = max(0, min(video_feature.shape[0] - 1, center))
+
+ assert 0 <= center < video_feature.shape[0]
+
+ video_clips = self.video_clip_sampler(
+ video_len, self._get_video_maxlen(), center
+ )
+ video_start = video_clips["start"][0]
+ video_end = video_clips["end"][0]
+
+ # tensorize.
+ vfeats, vmasks = self._build_video_seq(
+ video_feature, video_clips
+ )
+ caps, cmasks = self._build_text_seq(
+ text_feature, text_clip_indexs
+ )
+
+ text_start = text_clip_indexs[0]
+ text_end = text_clip_indexs[-1] + 1
+
+ return {
+ "caps": caps,
+ "cmasks": cmasks,
+ "vfeats": vfeats,
+ "vmasks": vmasks,
+ "video_start": video_start,
+ "video_end": video_end,
+ "text_start": text_start,
+ "text_end": text_end,
+ }
+
+
+class MFMMLMAligner(FixedLenAligner):
+ """
+ `FixedLenAligner` with Masked Language Model and Masked Frame Model.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ keep_prob = config.keep_prob if config.keep_prob is not None else 1.0
+ self.text_clip_sampler = TextClipSamplingProcessor(
+ self.max_len - self.max_video_len - 3, keep_prob
+ )
+ self.sampled_min_len = config.sampled_min_len
+ self.sampled_max_len = config.sampled_max_len
+ self.masked_token_sampler = TextMaskingProcessor(config)
+ self.mm_type = config.mm_type \
+ if config.mm_type is not None else "full"
+ self.attnmasker = MMAttentionMask2DProcessor() \
+ if self.mm_type == "textgen" else None
+ self.masked_frame_sampler = FrameMaskingProcessor(config)
+ self.lazy_vfeat_mask = (
+ False if config.lazy_vfeat_mask is None else config.lazy_vfeat_mask
+ )
+ self.mm_prob = config.mm_prob if config.mm_prob is not None else 0.
+
+ def __call__(self, video_id, video_feature, text_feature):
+ from transformers import default_data_collator
+ if self.subsampling is not None and self.subsampling > 1:
+ batch = []
+ for _ in range(self.subsampling):
+ centerclip_idx = random.randint(
+ 0, len(text_feature["start"]) - 1)
+ sampled_max_text_len = random.randint(
+ self.sampled_min_len, self.sampled_max_len
+ )
+ batch.append(
+ self.sampling(
+ video_id,
+ video_feature,
+ text_feature,
+ centerclip_idx,
+ sampled_max_text_len,
+ )
+ )
+ batch = self.batch_post_processing(batch, video_feature)
+ batch = default_data_collator(batch)
+ else:
+ batch = self.sampling(video_id, video_feature, text_feature)
+ batch = self.batch_post_processing(batch, video_feature)
+ batch["video_id"] = video_id if isinstance(video_id, str) \
+ else video_id[0]
+ return batch
+
+ def sampling(
+ self,
+ video_id,
+ video_feature,
+ text_feature,
+ centerclip_idx=None,
+ sampled_max_text_len=None,
+ ):
+ output = FixedLenAligner.sampling(self,
+ video_id, video_feature, text_feature,
+ centerclip_idx, sampled_max_text_len)
+
+ masking_text, masking_video = None, None
+ if random.random() < self.mm_prob:
+ if random.random() > 0.5:
+ masking_text, masking_video = self.mm_type, "no"
+ else:
+ masking_text, masking_video = "no", "full"
+ video_feats = output["vfeats"] if not self.lazy_vfeat_mask else None
+ video_label = self.masked_frame_sampler(
+ output["vmasks"], masking_video, vfeats=video_feats)
+ caps, text_label = self.masked_token_sampler(
+ output["caps"], masking_text)
+
+ output.update({
+ "caps": caps,
+ "video_label": video_label,
+ "text_label": text_label,
+ })
+
+ if self.attnmasker is not None:
+ attention_mask = self.attnmasker(
+ output["vmasks"], output["cmasks"], masking_text)
+ output.update({
+ "attention_mask": attention_mask
+ })
+ return output
+
+
+class FrameMaskingProcessor(Processor):
+ def __init__(self, config):
+ self.mfm_probability = 0.15
+ if config.mfm_probability is not None:
+ self.mfm_probability = config.mfm_probability
+
+ def __call__(self, vmasks, modality_masking=None, vfeats=None):
+ """
+ We perform lazy masking to save data transfer time.
+        It only generates video_label by default and the MFM model
+        will do the actual masking.
+ Return: `video_label` is a binary mask.
+ """
+ video_label = vmasks.clone()
+ if modality_masking is not None:
+ if modality_masking == "full":
+ probability_matrix = torch.full(video_label.shape, 1.)
+ elif modality_masking == "no":
+ probability_matrix = torch.full(video_label.shape, 0.)
+ elif modality_masking == "inverse":
+ probability_matrix = torch.full(
+ video_label.shape, 1. - self.mfm_probability)
+ else:
+ raise ValueError("unknown modality masking.", modality_masking)
+ else:
+ probability_matrix = torch.full(
+ video_label.shape, self.mfm_probability)
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+ # We only compute loss on masked tokens
+ video_label[~masked_indices] = 0
+ if vfeats is not None:
+ vfeats[video_label, :] = 0.0
+ return video_label
+
+
+class TextGenerationProcessor(Processor):
+ def __init__(self, tokenizer):
+ self.bos_token_id = tokenizer.bos_token_id
+ self.pad_token_id = tokenizer.pad_token_id
+
+ def __call__(self, inputs):
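+        # Build (inputs, labels) for autoregressive text generation:
+        # labels keep the original text tokens ([CLS]/[SEP]-for-video and
+        # padding positions are set to -100), while the text part of inputs is
+        # shifted right by one position with the bos token ([CLS]) prepended.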
+ labels = inputs.clone()
+ # [CLS] [SEP] for video
+ labels[:2] = -100
+ # keep [SEP] for text.
+ pad_mask = labels == self.pad_token_id
+ labels[pad_mask] = -100
+ inputs[2:] = torch.cat([
+ torch.LongTensor([self.bos_token_id]),
+ inputs[2:-1]])
+ inputs[pad_mask] = self.pad_token_id
+ assert len(inputs) == len(labels)
+ return inputs, labels
+
+
+class TextMaskingProcessor(Processor):
+ def __init__(self, config):
+ """this function is borrowed from
+ `transformers/data/data_collator.DataCollatorForLanguageModeling`"""
+ self.mlm_probability = 0.15
+ if config.mlm_probability is not None:
+ self.mlm_probability = config.mlm_probability
+ self.bert_name = config.bert_name
+ # [CLS] is used as bos_token and [SEP] is used as eos_token.
+ # https://huggingface.co/transformers/master/model_doc/bertgeneration.html
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
+ self.textgen = TextGenerationProcessor(self.tokenizer)
+
+ def __call__(
+ self, inputs: torch.Tensor,
+ modality_masking=None,
+ special_tokens_mask: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ expand modality_masking into
+ None: traditional bert masking.
+ "no": no masking.
+ "full": all [MASK] token for generation.
+ "gen": autoregressive generation.
+ """
+ """
+ Prepare masked tokens inputs/labels for masked language modeling:
+ 80% MASK, 10% random, 10% original.
+ """
+ labels = inputs.clone()
+ # We sample a few tokens in each sequence for MLM training
+ # (with probability `self.mlm_probability`)
+ if modality_masking is not None:
+ if modality_masking == "full":
+ probability_matrix = torch.full(labels.shape, 1.)
+ elif modality_masking == "no":
+ probability_matrix = torch.full(labels.shape, 0.)
+ elif modality_masking.startswith("textgen"):
+ # [CLS] [SEP] ...
+ inputs, labels = self.textgen(inputs)
+ if "mask" not in modality_masking:
+ return inputs, labels
+ inputs = self.mask_input(inputs, special_tokens_mask)
+ return inputs, labels
+ elif modality_masking == "mask":
+ inputs = self.mask_input(inputs, special_tokens_mask)
+ labels = torch.full(inputs.shape, -100)
+ return inputs, labels
+ elif modality_masking == "inverse":
+ probability_matrix = torch.full(labels.shape, 1. - self.mlm_probability)
+ else:
+ raise ValueError("unknown modality masking.", modality_masking)
+ else:
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
+
+ if special_tokens_mask is None:
+ special_tokens_mask = self.get_special_tokens_mask(
+ labels.tolist(), already_has_special_tokens=True
+ )
+ special_tokens_mask = torch.tensor(
+ special_tokens_mask, dtype=torch.bool)
+ else:
+ special_tokens_mask = special_tokens_mask.bool()
+
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
+
+ # 80% of the time,
+ # we replace masked input tokens with tokenizer.mask_token ([MASK])
+ indices_replaced = (
+ torch.bernoulli(
+ torch.full(labels.shape, 0.8)).bool() & masked_indices
+ )
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
+ self.tokenizer.mask_token
+ )
+
+ # 10% of the time, we replace masked input tokens with random word
+ indices_random = (
+ torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
+ & masked_indices
+ & ~indices_replaced
+ )
+ random_words = torch.randint(
+ len(self.tokenizer), labels.shape, dtype=torch.long
+ )
+ inputs[indices_random] = random_words[indices_random]
+
+ # The rest of the time (10% of the time) we keep the masked input
+ # tokens unchanged
+ return inputs, labels
+
+ def mask_input(self, inputs, special_tokens_mask=None):
+ # the following is new with masked autoregressive.
+ probability_matrix = torch.full(
+ inputs.shape, self.mlm_probability)
+ if special_tokens_mask is None:
+ special_tokens_mask = self.get_special_tokens_mask(
+ inputs.tolist(), already_has_special_tokens=True
+ )
+ special_tokens_mask = torch.tensor(
+ special_tokens_mask, dtype=torch.bool)
+ else:
+ special_tokens_mask = special_tokens_mask.bool()
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+ indices_replaced = (
+ torch.bernoulli(
+ torch.full(inputs.shape, 0.8)).bool() & masked_indices
+ )
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
+ self.tokenizer.mask_token
+ )
+
+ # 10% of the time, we replace masked input tokens with random word
+ indices_random = (
+ torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
+ & masked_indices
+ & ~indices_replaced
+ )
+ random_words = torch.randint(
+ len(self.tokenizer), inputs.shape, dtype=torch.long
+ )
+ inputs[indices_random] = random_words[indices_random]
+ return inputs
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None,
+ already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+        Note: the version from transformers does not consider pad
+        as a special token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if"
+ "the provided sequence of "
+ "ids is already formated with special tokens "
+ "for the model."
+ )
+ return list(map(lambda x: 1 if x in [
+ self.tokenizer.sep_token_id,
+ self.tokenizer.cls_token_id,
+ self.tokenizer.pad_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+
+class TextClipSamplingProcessor(Processor):
+ def __init__(self, max_text_len, keep_prob=1.0):
+ self.max_text_len = max_text_len
+ self.max_video_len = 256 # always hold.
+ self.keep_prob = keep_prob
+
+ def __call__(
+ self,
+ text_feature,
+ centerclip_idx=None,
+ sampled_max_text_len=None,
+ sampled_max_video_len=None,
+ ):
+ # Let's use all caps for now and see if 256 can cover all of them.
+ if sampled_max_text_len is not None:
+ max_text_len = sampled_max_text_len
+ else:
+ max_text_len = self.max_text_len
+ if sampled_max_video_len is not None:
+ max_video_len = sampled_max_video_len
+ else:
+ max_video_len = self.max_video_len
+
+ t_num_clips = len(text_feature["start"])
+
+ if centerclip_idx is None:
+ centerclip_idx = random.randint(0, t_num_clips - 1)
+
+ start_idx, end_idx = centerclip_idx, centerclip_idx + 1
+ text_clip_indexs = deque()
+ text_clip_indexs.append(start_idx)
+ text_len = len(text_feature["cap"][start_idx])
+
+ video_len = max(
+ 0,
+ text_feature["end"][start_idx]
+ - text_feature["start"][start_idx],
+ )
+
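+        # Grow the clip window around the center clip, randomly extending to
+        # the left or right until the text-token or video-length budget is
+        # exhausted; with keep_prob < 1, a neighboring clip is occasionally
+        # skipped over.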
+ while (
+ (start_idx > 0 or end_idx < t_num_clips)
+ and text_len < max_text_len
+ and video_len < max_video_len
+ ):
+ if random.random() > 0.5 and end_idx < t_num_clips:
+ # skip the next one?
+ if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
+ end_idx = end_idx + 1
+ text_clip_indexs.append(end_idx)
+ text_len += len(text_feature["cap"][end_idx])
+ end_idx += 1
+ elif start_idx > 0:
+ if random.random() > self.keep_prob and (start_idx - 1) > 0:
+ start_idx = start_idx - 1
+ start_idx -= 1
+ text_clip_indexs.insert(0, start_idx)
+ text_len += len(text_feature["cap"][start_idx])
+ else:
+ if end_idx < t_num_clips:
+ if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
+ end_idx = end_idx + 1
+ text_clip_indexs.append(end_idx)
+ text_len += len(text_feature["cap"][end_idx])
+ end_idx += 1
+ else:
+ return text_clip_indexs
+ video_len = max(
+ 0,
+ text_feature["end"][text_clip_indexs[-1]]
+ - text_feature["start"][text_clip_indexs[0]],
+ )
+ return text_clip_indexs
+
+
+class VideoClipSamplingProcessor(Processor):
+ def __call__(self, video_len, max_video_len, center):
+ """
+ `video_len`: length of the video.
+        `max_video_len`: maximum number of video tokens allowed in a sequence.
+ `center`: initial starting index.
+ """
+ assert center >= 0 and center < video_len
+ t_clip_len = 0
+ start, end = center, center
+ while (start > 0 or end < video_len) and t_clip_len < max_video_len:
+ # decide the direction to grow.
+ if start <= 0:
+ end += 1
+ elif end >= video_len:
+ start -= 1
+ elif random.random() > 0.5:
+ end += 1
+ else:
+ start -= 1
+ t_clip_len += 1
+ return {"start": [start], "end": [end]}
+
+
+class How2MILNCEAligner(FixedLenAligner):
+ """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`"""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_candidates = 4
+ self.min_time = 5.0
+ self.num_sec = 3.2
+ # self.num_sec = self.num_frames / float(self.fps) num_frames=16 / fps = 5
+ # self.num_frames = 16
+
+ def sampling(
+ self,
+ video_id,
+ video_feature,
+ text_feature,
+ centerclip_idx=None, # will be ignored.
+ sampled_max_text_len=None # will be ignored.
+ ):
+ text, start, end = self._get_text(text_feature)
+ video = self._get_video(video_feature, start, end)
+
+ vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
+ vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
+ vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
+ vmasks[: video.shape[0]] = 1
+
+ caps, cmasks = [], []
+ for words in text:
+ cap, cmask = self._build_text_seq(text_feature, words)
+ caps.append(cap)
+ cmasks.append(cmask)
+ caps = torch.stack(caps)
+ cmasks = torch.stack(cmasks)
+ # video of shape: (video_len)
+ # text of shape (num_candidates, max_text_len)
+
+ return {
+ "caps": caps,
+ "cmasks": cmasks,
+ "vfeats": vfeats,
+ "vmasks": vmasks,
+ # "video_id": video_id,
+ }
+
+ def _get_video(self, video_feature, start, end):
+ start_seek = random.randint(start, int(max(start, end - self.num_sec)))
+ # duration = self.num_sec + 0.1
+ return video_feature[start_seek : int(start_seek + self.num_sec)]
+
+ def _get_text(self, cap):
+ ind = random.randint(0, len(cap["start"]) - 1)
+ if self.num_candidates == 1:
+ words = [ind]
+ else:
+ words = []
+ cap_start = self._find_nearest_candidates(cap, ind)
+ for i in range(self.num_candidates):
+ words.append([max(0, min(len(cap["cap"]) - 1, cap_start + i))])
+
+ start, end = cap["start"][ind], cap["end"][ind]
+ # TODO: May need to be improved for edge cases.
+ # expand the min time.
+ if end - start < self.min_time:
+ diff = self.min_time - end + start
+ start = max(0, start - diff / 2)
+ end = start + self.min_time
+ return words, int(start), int(end)
+
+ def _find_nearest_candidates(self, caption, ind):
+ """find the range of the clips."""
+ start, end = ind, ind
+ #diff = caption["end"][end] - caption["start"][start]
+ n_candidate = 1
+ while n_candidate < self.num_candidates:
+ # the first clip
+ if start == 0:
+ return 0
+ # we add () in the following condition to fix the bug.
+ elif end == (len(caption["start"]) - 1):
+ return start - (self.num_candidates - n_candidate)
+ elif (caption["end"][end] - caption["start"][start - 1]) < (
+ caption["end"][end + 1] - caption["start"][start]
+ ):
+ start -= 1
+ else:
+ end += 1
+ n_candidate += 1
+ return start
+
+
+class PKLJSONStrTextProcessor(TextProcessor):
+ """`caption.json` from howto100m are preprocessed as a
+ dict `[video_id, json_str]`.
+ Json parsing tokenization are conducted on-the-fly and cached into dict.
+ """
+
+ def __init__(self, config, max_clip_text_len=96):
+ print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
+ self.caption_pkl_path = str(config.caption_pkl_path)
+ with open(self.caption_pkl_path, "rb") as fd:
+ self.data = pickle.load(fd)
+ self.max_clip_text_len = max_clip_text_len
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ str(config.bert_name), use_fast=config.use_fast
+ )
+
+ def __call__(self, video_id):
+ caption = self.data[video_id]
+ if isinstance(caption, str):
+ import json
+ caption = json.loads(caption)
+ cap = []
+ for clip_idx, text_clip in enumerate(caption["text"]):
+ clip_ids = []
+ if isinstance(text_clip, str):
+ clip_ids = self.tokenizer(
+ text_clip[: self.max_clip_text_len],
+ add_special_tokens=False
+ )["input_ids"]
+ cap.append(clip_ids)
+ caption["cap"] = cap
+ caption.pop("text") # save space.
+ self.data[video_id] = caption
+ return caption
diff --git a/examples/MMPT/mmpt/processors/how2retriprocessor.py b/examples/MMPT/mmpt/processors/how2retriprocessor.py
new file mode 100644
index 0000000000..b5a7730ec0
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/how2retriprocessor.py
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .how2processor import (
+ ShardedHow2MetaProcessor,
+ ShardedVideoProcessor,
+ ShardedTextProcessor,
+ VariedLenAligner,
+ OverlappedAligner
+)
+
+
+class ShardedHow2VideoRetriMetaProcessor(ShardedHow2MetaProcessor):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_video_per_batch = config.num_video_per_batch
+ self.cands = [
+ self.data[batch_offset:batch_offset + self.num_video_per_batch]
+ for batch_offset in
+ range(0, (len(self.data) // (8 * self.num_video_per_batch)) * 8 * self.num_video_per_batch, self.num_video_per_batch)]
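+        # Group video ids into candidate batches of `num_video_per_batch`,
+        # dropping the tail so that the number of batches is a multiple of 8
+        # (presumably the number of distributed workers).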
+
+ def __len__(self):
+ return len(self.cands)
+
+ def set_candidates(self, cands):
+        # no change in the number of batches.
+ print(len(self.cands), "->", len(cands))
+ # assert len(self.cands) == len(cands)
+ self.cands = cands
+
+ def __getitem__(self, idx):
+ video_ids = self.cands[idx]
+ assert isinstance(video_ids, list)
+ sharded_video_idxs = []
+ for video_id in video_ids:
+ shard_id, video_idx = self.video_id_to_shard[video_id]
+ sharded_video_idxs.append((video_id, -1, shard_id, video_idx))
+ return sharded_video_idxs, sharded_video_idxs
+
+
+class ShardedVideoRetriVideoProcessor(ShardedVideoProcessor):
+ """In retrival case the video_id
+ is a list of tuples: `(shard_id, video_idx)` ."""
+
+ def __call__(self, sharded_video_idxs):
+ assert isinstance(sharded_video_idxs, list)
+ cand_feats = []
+ for shared_video_idx in sharded_video_idxs:
+ feat = super().__call__(shared_video_idx)
+ cand_feats.append(feat)
+ return cand_feats
+
+
+class ShardedVideoRetriTextProcessor(ShardedTextProcessor):
+ """In retrival case the video_id
+ is a list of tuples: `(shard_id, video_idx)` ."""
+
+ def __call__(self, sharded_video_idxs):
+ assert isinstance(sharded_video_idxs, list)
+ cand_caps = []
+ for shared_video_idx in sharded_video_idxs:
+ caps = super().__call__(shared_video_idx)
+ cand_caps.append(caps)
+ return cand_caps
+
+
+class VideoRetriAligner(VariedLenAligner):
+ # Retritask will trim dim-0.
+ def __call__(self, sharded_video_idxs, video_features, text_features):
+ from transformers import default_data_collator
+ batch, video_ids = [], []
+ for video_id, video_feature, text_feature in \
+ zip(sharded_video_idxs, video_features, text_features):
+ sub_batch = super().__call__(video_id, video_feature, text_feature)
+ batch.append(sub_batch)
+ if isinstance(video_id, tuple):
+ video_id = video_id[0]
+ video_ids.append(video_id)
+ batch = default_data_collator(batch)
+ batch["video_id"] = video_ids
+ return batch
+
+
+class VideoRetriOverlappedAligner(OverlappedAligner):
+ # Retritask will trim dim-0.
+ def __call__(self, sharded_video_idxs, video_features, text_features):
+ from transformers import default_data_collator
+ batch, video_ids = [], []
+ for video_id, video_feature, text_feature in \
+ zip(sharded_video_idxs, video_features, text_features):
+ sub_batch = super().__call__(video_id, video_feature, text_feature)
+ batch.append(sub_batch)
+ if isinstance(video_id, tuple):
+ video_id = video_id[0]
+ video_ids.append(video_id)
+ batch = default_data_collator(batch)
+ batch["video_id"] = video_ids
+ return batch
diff --git a/examples/MMPT/mmpt/processors/models/s3dg.py b/examples/MMPT/mmpt/processors/models/s3dg.py
new file mode 100644
index 0000000000..6c7a691e33
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/models/s3dg.py
@@ -0,0 +1,336 @@
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Contains a PyTorch definition for Gated Separable 3D network (S3D-G)
+with a text module for computing joint text-video embedding from raw text
+and video input. The following code will enable you to load the HowTo100M
+pretrained S3D Text-Video model from:
+ A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman,
+ End-to-End Learning of Visual Representations from Uncurated Instructional Videos.
+ https://arxiv.org/abs/1912.06430.
+
+S3D-G was proposed by:
+ S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy,
+ Rethinking Spatiotemporal Feature Learning For Video Understanding.
+ https://arxiv.org/abs/1712.04851.
+ Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
+
+The S3D architecture was slightly modified with a space to depth trick for TPU
+optimization.
+"""
+
+import torch as th
+import torch.nn.functional as F
+import torch.nn as nn
+import os
+import numpy as np
+import re
+
+
+class InceptionBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ num_outputs_0_0a,
+ num_outputs_1_0a,
+ num_outputs_1_0b,
+ num_outputs_2_0a,
+ num_outputs_2_0b,
+ num_outputs_3_0b,
+ gating=True,
+ ):
+ super(InceptionBlock, self).__init__()
+ self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1])
+ self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1])
+ self.conv_b1_b = STConv3D(
+ num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True
+ )
+ self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1])
+ self.conv_b2_b = STConv3D(
+ num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True
+ )
+ self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1)
+ self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1])
+ self.gating = gating
+ self.output_dim = (
+ num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b
+ )
+ if gating:
+ self.gating_b0 = SelfGating(num_outputs_0_0a)
+ self.gating_b1 = SelfGating(num_outputs_1_0b)
+ self.gating_b2 = SelfGating(num_outputs_2_0b)
+ self.gating_b3 = SelfGating(num_outputs_3_0b)
+
+ def forward(self, input):
+ """Inception block
+ """
+ b0 = self.conv_b0(input)
+ b1 = self.conv_b1_a(input)
+ b1 = self.conv_b1_b(b1)
+ b2 = self.conv_b2_a(input)
+ b2 = self.conv_b2_b(b2)
+ b3 = self.maxpool_b3(input)
+ b3 = self.conv_b3_b(b3)
+ if self.gating:
+ b0 = self.gating_b0(b0)
+ b1 = self.gating_b1(b1)
+ b2 = self.gating_b2(b2)
+ b3 = self.gating_b3(b3)
+ return th.cat((b0, b1, b2, b3), dim=1)
+
+
+class SelfGating(nn.Module):
+ def __init__(self, input_dim):
+ super(SelfGating, self).__init__()
+ self.fc = nn.Linear(input_dim, input_dim)
+
+ def forward(self, input_tensor):
+ """Feature gating as used in S3D-G.
+ """
+ spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4])
+ weights = self.fc(spatiotemporal_average)
+ weights = th.sigmoid(weights)
+ return weights[:, :, None, None, None] * input_tensor
+
+
+class STConv3D(nn.Module):
+ def __init__(
+ self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False
+ ):
+ super(STConv3D, self).__init__()
+ self.separable = separable
+ self.relu = nn.ReLU(inplace=True)
+ assert len(kernel_size) == 3
+ if separable and kernel_size[0] != 1:
+ spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
+ temporal_kernel_size = [kernel_size[0], 1, 1]
+ if isinstance(stride, list) and len(stride) == 3:
+ spatial_stride = [1, stride[1], stride[2]]
+ temporal_stride = [stride[0], 1, 1]
+ else:
+ spatial_stride = [1, stride, stride]
+ temporal_stride = [stride, 1, 1]
+ if isinstance(padding, list) and len(padding) == 3:
+ spatial_padding = [0, padding[1], padding[2]]
+ temporal_padding = [padding[0], 0, 0]
+ else:
+ spatial_padding = [0, padding, padding]
+ temporal_padding = [padding, 0, 0]
+ if separable:
+ self.conv1 = nn.Conv3d(
+ input_dim,
+ output_dim,
+ kernel_size=spatial_kernel_size,
+ stride=spatial_stride,
+ padding=spatial_padding,
+ bias=False,
+ )
+ self.bn1 = nn.BatchNorm3d(output_dim)
+ self.conv2 = nn.Conv3d(
+ output_dim,
+ output_dim,
+ kernel_size=temporal_kernel_size,
+ stride=temporal_stride,
+ padding=temporal_padding,
+ bias=False,
+ )
+ self.bn2 = nn.BatchNorm3d(output_dim)
+ else:
+ self.conv1 = nn.Conv3d(
+ input_dim,
+ output_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn1 = nn.BatchNorm3d(output_dim)
+
+ def forward(self, input):
+ out = self.relu(self.bn1(self.conv1(input)))
+ if self.separable:
+ out = self.relu(self.bn2(self.conv2(out)))
+ return out
+
+
+class MaxPool3dTFPadding(th.nn.Module):
+ def __init__(self, kernel_size, stride=None, padding="SAME"):
+ super(MaxPool3dTFPadding, self).__init__()
+ if padding == "SAME":
+ padding_shape = self._get_padding_shape(kernel_size, stride)
+ self.padding_shape = padding_shape
+ self.pad = th.nn.ConstantPad3d(padding_shape, 0)
+ self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True)
+
+ def _get_padding_shape(self, filter_shape, stride):
+ def _pad_top_bottom(filter_dim, stride_val):
+ pad_along = max(filter_dim - stride_val, 0)
+ pad_top = pad_along // 2
+ pad_bottom = pad_along - pad_top
+ return pad_top, pad_bottom
+
+ padding_shape = []
+ for filter_dim, stride_val in zip(filter_shape, stride):
+ pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val)
+ padding_shape.append(pad_top)
+ padding_shape.append(pad_bottom)
+ depth_top = padding_shape.pop(0)
+ depth_bottom = padding_shape.pop(0)
+ padding_shape.append(depth_top)
+ padding_shape.append(depth_bottom)
+ return tuple(padding_shape)
+
+ def forward(self, inp):
+ inp = self.pad(inp)
+ out = self.pool(inp)
+ return out
+
+
+class Sentence_Embedding(nn.Module):
+ def __init__(
+ self,
+ embd_dim,
+ num_embeddings=66250,
+ word_embedding_dim=300,
+ token_to_word_path="dict.npy",
+ max_words=16,
+ output_dim=2048,
+ ):
+ super(Sentence_Embedding, self).__init__()
+ self.word_embd = nn.Embedding(num_embeddings, word_embedding_dim)
+ self.fc1 = nn.Linear(word_embedding_dim, output_dim)
+ self.fc2 = nn.Linear(output_dim, embd_dim)
+ self.word_to_token = {}
+ self.max_words = max_words
+ token_to_word = np.load(token_to_word_path)
+ for i, t in enumerate(token_to_word):
+ self.word_to_token[t] = i + 1
+
+ def _zero_pad_tensor_token(self, tensor, size):
+ if len(tensor) >= size:
+ return tensor[:size]
+ else:
+ zero = th.zeros(size - len(tensor)).long()
+ return th.cat((tensor, zero), dim=0)
+
+ def _split_text(self, sentence):
+ w = re.findall(r"[\w']+", str(sentence))
+ return w
+
+ def _words_to_token(self, words):
+ words = [
+ self.word_to_token[word] for word in words if word in self.word_to_token
+ ]
+ if words:
+ we = self._zero_pad_tensor_token(th.LongTensor(words), self.max_words)
+ return we
+ else:
+ return th.zeros(self.max_words).long()
+
+ def _words_to_ids(self, x):
+ split_x = [self._words_to_token(self._split_text(sent.lower())) for sent in x]
+ return th.stack(split_x, dim=0)
+
+ def forward(self, x):
+ x = self._words_to_ids(x)
+ x = self.word_embd(x)
+ x = F.relu(self.fc1(x))
+ x = th.max(x, dim=1)[0]
+ x = self.fc2(x)
+ return {'text_embedding': x}
+
+
+class S3D(nn.Module):
+ def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True):
+ super(S3D, self).__init__()
+ self.num_classes = num_classes
+ self.gating = gating
+ self.space_to_depth = space_to_depth
+ if space_to_depth:
+ self.conv1 = STConv3D(
+ 24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False
+ )
+ else:
+ self.conv1 = STConv3D(
+ 3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False
+ )
+ self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False)
+ self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True)
+ self.gating = SelfGating(192)
+ self.maxpool_2a = MaxPool3dTFPadding(
+ kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
+ )
+ self.maxpool_3a = MaxPool3dTFPadding(
+ kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
+ )
+ self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
+ self.mixed_3c = InceptionBlock(
+ self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64
+ )
+ self.maxpool_4a = MaxPool3dTFPadding(
+ kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME"
+ )
+ self.mixed_4b = InceptionBlock(
+ self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64
+ )
+ self.mixed_4c = InceptionBlock(
+ self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64
+ )
+ self.mixed_4d = InceptionBlock(
+ self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64
+ )
+ self.mixed_4e = InceptionBlock(
+ self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64
+ )
+ self.mixed_4f = InceptionBlock(
+ self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128
+ )
+ self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding(
+ kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME"
+ )
+ self.mixed_5b = InceptionBlock(
+ self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128
+ )
+ self.mixed_5c = InceptionBlock(
+ self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128
+ )
+ self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes)
+ self.text_module = Sentence_Embedding(num_classes,
+ token_to_word_path=dict_path)
+
+ def _space_to_depth(self, input):
+ """3D space to depth trick for TPU optimization.
+ """
+ B, C, T, H, W = input.shape
+ input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
+ input = input.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2)
+ return input
+
+ def forward(self, inputs):
+ """Defines the S3DG base architecture."""
+ if self.space_to_depth:
+ inputs = self._space_to_depth(inputs)
+ net = self.conv1(inputs)
+ if self.space_to_depth:
+ # we need to replicate 'SAME' tensorflow padding
+ net = net[:, :, 1:, 1:, 1:]
+ net = self.maxpool_2a(net)
+ net = self.conv_2b(net)
+ net = self.conv_2c(net)
+ if self.gating:
+ net = self.gating(net)
+ net = self.maxpool_3a(net)
+ net = self.mixed_3b(net)
+ net = self.mixed_3c(net)
+ net = self.maxpool_4a(net)
+ net = self.mixed_4b(net)
+ net = self.mixed_4c(net)
+ net = self.mixed_4d(net)
+ net = self.mixed_4e(net)
+ net = self.mixed_4f(net)
+ net = self.maxpool_5a(net)
+ net = self.mixed_5b(net)
+ net = self.mixed_5c(net)
+ net = th.mean(net, dim=[2, 3, 4])
+ return {'video_embedding': self.fc(net), 'mixed_5c': net}
diff --git a/examples/MMPT/mmpt/processors/processor.py b/examples/MMPT/mmpt/processors/processor.py
new file mode 100644
index 0000000000..98edb051f1
--- /dev/null
+++ b/examples/MMPT/mmpt/processors/processor.py
@@ -0,0 +1,274 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import numpy as np
+import os
+import torch
+
+
+class Processor(object):
+ """
+ A generic processor for video (codec, feature etc.) and text.
+ """
+
+ def __call__(self, **kwargs):
+ raise NotImplementedError
+
+
+class MetaProcessor(Processor):
+ """
+    A meta processor is expected to load the metadata of a dataset
+        (e.g., video_ids or captions).
+    You must implement `__getitem__` (meta datasets are rather diverse).
+ """
+
+ def __init__(self, config):
+ self.split = config.split
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ raise NotImplementedError
+
+ def _get_split_path(self, config):
+ splits = {
+ "train": config.train_path,
+ "valid": config.val_path,
+ "test": config.test_path,
+ }
+ if config.split is not None:
+ return splits[config.split]
+ return config.train_path
+
+
+class TextProcessor(Processor):
+ """
+ A generic Text processor: rename this as `withTokenizer`.
+ tokenize a string of text on-the-fly.
+ Warning: mostly used for end tasks.
+ (on-the-fly tokenization is slow for how2.)
+ TODO(huxu): move this class as a subclass.
+ """
+
+ def __init__(self, config):
+ self.bert_name = str(config.bert_name)
+ self.use_fast = config.use_fast
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.bert_name, use_fast=self.use_fast
+ )
+
+ def __call__(self, text_id):
+ caption = self.tokenizer(text_id, add_special_tokens=False)
+ return caption["input_ids"]
+
+
+class VideoProcessor(Processor):
+ """
+ A generic video processor: load a numpy video tokens by default.
+ """
+
+ def __init__(self, config):
+ self.vfeat_dir = config.vfeat_dir
+
+ def __call__(self, video_fn):
+ if isinstance(video_fn, tuple):
+ video_fn = video_fn[0]
+ assert isinstance(video_fn, str)
+ video_fn = os.path.join(self.vfeat_dir, video_fn + ".npy")
+ feat = np.load(video_fn)
+ return feat
+
+
+class Aligner(object):
+ """
+    An align processor aligns video and text and outputs a dict of tensors (for a model).
+ """
+ def __init__(self, config):
+ """__init__ needs to be light weight for more workers/threads."""
+ self.split = config.split
+ self.max_video_len = config.max_video_len
+ self.max_len = config.max_len
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ str(config.bert_name), use_fast=config.use_fast
+ )
+ self.cls_token_id = tokenizer.cls_token_id
+ self.sep_token_id = tokenizer.sep_token_id
+ self.pad_token_id = tokenizer.pad_token_id
+ self.mask_token_id = tokenizer.mask_token_id
+
+ def __call__(self, video_id, video_feature, text_feature):
+ raise NotImplementedError
+
+ def _build_video_seq(self, video_feature, video_clips=None):
+ """
+ `video_feature`: available video tokens.
+ `video_clips`: video clip sequence to build.
+ """
+ if not isinstance(video_feature, np.ndarray):
+ raise ValueError(
+ "unsupported type of video_feature", type(video_feature)
+ )
+
+ if video_clips is None:
+ # this is borrowed from DSAligner
+ video_start = 0
+ video_end = min(len(video_feature), self.max_video_len)
+ # the whole sequence is a single clip.
+ video_clips = {"start": [video_start], "end": [video_end]}
+
+ vfeats = np.zeros(
+ (self.max_video_len, video_feature.shape[1]), dtype=np.float32
+ )
+ vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
+ video_len = 0
+ for start, end in zip(video_clips["start"], video_clips["end"]):
+ clip_len = min(self.max_video_len - video_len, (end - start))
+ if clip_len > 0:
+ vfeats[video_len: video_len + clip_len] = video_feature[
+ start: start + clip_len
+ ]
+ vmasks[video_len: video_len + clip_len] = 1
+ video_len += clip_len
+ vfeats = torch.from_numpy(vfeats)
+
+ return vfeats, vmasks
+
+ def _build_text_seq(self, text_feature, text_clip_indexs=None):
+ """
+ `text_feature`: all available clips.
+        `text_clip_indexs`: clip sequence to build.
+ """
+ if text_clip_indexs is None:
+ text_clip_indexs = [0]
+
+ full_caps = []
+ if isinstance(text_feature, dict):
+ for clip_idx in text_clip_indexs:
+ full_caps.extend(text_feature["cap"][clip_idx])
+ else:
+ full_caps = text_feature
+ max_text_len = self.max_len - self.max_video_len - 3
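+        # 3 positions are reserved for [CLS] and the two [SEP] tokens added below.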
+ full_caps = full_caps[:max_text_len]
+ full_caps = (
+ [self.cls_token_id, self.sep_token_id] + full_caps + [self.sep_token_id]
+ )
+ text_pad_len = self.max_len - len(full_caps) - self.max_video_len
+ padded_full_caps = full_caps + [self.pad_token_id] * text_pad_len
+ caps = torch.LongTensor(padded_full_caps)
+ cmasks = torch.zeros((len(padded_full_caps),), dtype=torch.bool)
+ cmasks[: len(full_caps)] = 1
+
+ return caps, cmasks
+
+ def batch_post_processing(self, batch, video_feature):
+ return batch
+
+
+class MMAttentionMask2DProcessor(Processor):
+ """text generation requires 2d mask
+ that is harder to generate by GPU at this stage."""
+
+ def __call__(self, vmask, cmask, mtype):
+ if mtype == "textgen":
+ return self._build_textgeneration_mask(vmask, cmask)
+ elif mtype == "videogen":
+ return self._build_videogeneration_mask(vmask, cmask)
+ else:
+ return self._build_mm_mask(vmask, cmask)
+
+ def _build_mm_mask(self, vmask, cmask):
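+        # Fully bidirectional mask: every row attends to the same set of valid
+        # positions ([CLS], the video tokens, then the remaining text tokens).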
+ mask_1d = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)
+ return mask_1d[None, :].repeat(mask_1d.size(0), 1)
+
+ def _build_videogeneration_mask(self, vmask, cmask):
+        # cls_mask covers only the text side; otherwise it would leak the generation target.
+ cls_text_mask = torch.cat([
+ # [CLS]
+ torch.ones(
+ (1,), dtype=torch.bool, device=cmask.device),
+ # video tokens and [SEP] for video.
+ torch.zeros(
+ (vmask.size(0) + 1,), dtype=torch.bool, device=cmask.device),
+ cmask[2:]
+ ], dim=0)
+
+        # concat horizontally.
+ video_len = int(vmask.sum())
+ video_masks = torch.cat([
+ # [CLS]
+ torch.ones(
+ (video_len, 1), dtype=torch.bool, device=cmask.device
+ ),
+ torch.tril(
+ torch.ones(
+ (video_len, video_len),
+ dtype=torch.bool, device=cmask.device)),
+ # video_padding
+ torch.zeros(
+ (video_len, vmask.size(0) - video_len),
+ dtype=torch.bool, device=cmask.device
+ ),
+ # [SEP] for video (unused).
+ torch.zeros(
+ (video_len, 1), dtype=torch.bool, device=cmask.device
+ ),
+ cmask[2:].unsqueeze(0).repeat(video_len, 1)
+ ], dim=1)
+
+ text_masks = cls_text_mask[None, :].repeat(
+ cmask.size(0) - 2, 1)
+ video_padding_masks = cls_text_mask[None, :].repeat(
+ vmask.size(0) - video_len, 1)
+
+ return torch.cat([
+ cls_text_mask[None, :],
+ video_masks,
+ video_padding_masks,
+ torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None,:],
+ text_masks
+ ], dim=0)
+
+ def _build_textgeneration_mask(self, vmask, cmask):
+        # cls_mask covers only the video side; otherwise it would leak the generation target.
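+        # Rows for [CLS], the video tokens, the video [SEP], and text padding
+        # see only the video side (cls_video_mask); text rows additionally
+        # attend causally (lower-triangular) to earlier text tokens.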
+ cls_video_mask = torch.cat([
+ # [CLS]
+ torch.ones(
+ (1,), dtype=torch.bool, device=cmask.device),
+ vmask,
+ # [SEP]
+ torch.ones((1,), dtype=torch.bool, device=cmask.device),
+ torch.zeros(
+ (cmask.size(0)-2,), dtype=torch.bool, device=cmask.device)
+ ], dim=0)
+
+        # concat horizontally.
+ text_len = int(cmask[2:].sum())
+ text_masks = torch.cat([
+ # [CLS]
+ torch.ones(
+ (text_len, 1), dtype=torch.bool, device=cmask.device
+ ),
+ vmask.unsqueeze(0).repeat(text_len, 1),
+ # [SEP] for video.
+ torch.ones(
+ (text_len, 1), dtype=torch.bool, device=cmask.device
+ ),
+ torch.tril(
+ torch.ones(
+ (text_len, text_len),
+ dtype=torch.bool, device=cmask.device)),
+ # padding.
+ torch.zeros(
+ (text_len, cmask.size(0) - text_len - 2),
+ dtype=torch.bool, device=cmask.device
+ )
+ ], dim=1)
+
+ cls_video_masks = cls_video_mask[None, :].repeat(
+ vmask.size(0) + 2, 1)
+ text_padding_masks = cls_video_mask[None, :].repeat(
+ cmask.size(0) - text_len - 2, 1)
+ return torch.cat([
+ cls_video_masks, text_masks, text_padding_masks], dim=0)
diff --git a/examples/MMPT/mmpt/tasks/__init__.py b/examples/MMPT/mmpt/tasks/__init__.py
new file mode 100644
index 0000000000..e2e9323a53
--- /dev/null
+++ b/examples/MMPT/mmpt/tasks/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .task import *
+from .vlmtask import *
+from .retritask import *
+
+try:
+ from .fairseqmmtask import *
+except ImportError:
+ pass
+
+try:
+ from .milncetask import *
+except ImportError:
+ pass
+
+try:
+ from .expretritask import *
+except ImportError:
+ pass
diff --git a/examples/MMPT/mmpt/tasks/fairseqmmtask.py b/examples/MMPT/mmpt/tasks/fairseqmmtask.py
new file mode 100644
index 0000000000..f6b6115a39
--- /dev/null
+++ b/examples/MMPT/mmpt/tasks/fairseqmmtask.py
@@ -0,0 +1,104 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+make a general fairseq task for MM pretraining.
+"""
+
+import random
+
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+from .task import Task
+from .retritask import RetriTask
+from ..datasets import FairseqMMDataset
+from .. import utils
+
+
+@register_task("mmtask")
+class FairseqMMTask(LegacyFairseqTask):
+ @staticmethod
+ def add_args(parser):
+ # Add some command-line arguments for specifying where the data is
+ # located and the maximum supported input length.
+ parser.add_argument(
+ "taskconfig",
+ metavar="FILE",
+ help=("taskconfig to load all configurations" "outside fairseq parser."),
+ )
+
+ @classmethod
+ def setup_task(cls, args, **kwargs):
+ return FairseqMMTask(args)
+
+ def __init__(self, args):
+ super().__init__(args)
+ config = utils.load_config(args)
+ self.mmtask = Task.config_task(config)
+ self.mmtask.build_dataset()
+ self.mmtask.build_model()
+ self.mmtask.build_loss()
+
+ def load_dataset(self, split, **kwargs):
+ split_map = {
+ "train": self.mmtask.train_data,
+ "valid": self.mmtask.val_data,
+ "test": self.mmtask.test_data,
+ }
+ if split not in split_map:
+ raise ValueError("unknown split type.")
+ if split_map[split] is not None:
+ self.datasets[split] = FairseqMMDataset(split_map[split])
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ skip_remainder_batch=False,
+ grouped_shuffling=False,
+ update_epoch_batch_itr=False,
+ ):
+ random.seed(epoch)
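+        # For retrieval-augmented training, lazily build the retrieval
+        # dataloader and refresh the retrieved candidate batches once the
+        # configured `retri_epoch` is reached.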
+ if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
+ if epoch >= self.mmtask.config.retri_epoch:
+ if not hasattr(self.mmtask, "retri_dataloader"):
+ self.mmtask.build_dataloader()
+ self.mmtask.retrive_candidates(epoch)
+
+ return super().get_batch_iterator(
+ dataset,
+ max_tokens,
+ max_sentences,
+ max_positions,
+ ignore_invalid_inputs,
+ required_batch_size_multiple,
+ seed,
+ num_shards,
+ shard_id,
+ num_workers,
+ epoch,
+ data_buffer_size,
+ disable_iterator_cache,
+ grouped_shuffling,
+ update_epoch_batch_itr,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
diff --git a/examples/MMPT/mmpt/tasks/milncetask.py b/examples/MMPT/mmpt/tasks/milncetask.py
new file mode 100644
index 0000000000..61b6ab0597
--- /dev/null
+++ b/examples/MMPT/mmpt/tasks/milncetask.py
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .task import Task
+
+
+class MILNCETask(Task):
+ def reshape_subsample(self, sample):
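+        # MIL-NCE pairs each video clip with several candidate captions;
+        # fold the (batch, num_candidates) leading dimensions of `caps`/`cmasks`
+        # into a single batch dimension after the generic flat_subsample.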
+ if (
+ hasattr(self.config.dataset, "subsampling")
+ and self.config.dataset.subsampling is not None
+ and self.config.dataset.subsampling > 1
+ ):
+ for key in sample:
+ if torch.is_tensor(sample[key]):
+ tensor = self.flat_subsample(sample[key])
+ if key in ["caps", "cmasks"]:
+ size = tensor.size()
+ batch_size = size[0] * size[1]
+ expanded_size = (batch_size,) + size[2:]
+ tensor = tensor.view(expanded_size)
+ sample[key] = tensor
+ return sample
diff --git a/examples/MMPT/mmpt/tasks/retritask.py b/examples/MMPT/mmpt/tasks/retritask.py
new file mode 100644
index 0000000000..b43f20fddb
--- /dev/null
+++ b/examples/MMPT/mmpt/tasks/retritask.py
@@ -0,0 +1,253 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import torch
+import pickle
+import random
+
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from ..processors import (
+ ShardedHow2MetaProcessor,
+ ShardedVideoProcessor,
+ ShardedTextProcessor,
+ VariedLenAligner,
+)
+
+from ..datasets import MMDataset
+from .task import Task
+from ..modules import vectorpool
+from ..evaluators.predictor import Predictor
+from ..utils import set_seed, get_local_rank, get_world_size
+
+
+class RetriTask(Task):
+ """abstract class for task with retrival."""
+
+ def reshape_subsample(self, sample):
+ for key in sample:
+ if torch.is_tensor(sample[key]):
+ sample[key] = self.flat_subsample(sample[key])
+ return sample
+
+ def flat_subsample(self, tensor):
+ if tensor.size(0) == 1:
+ tensor = tensor.squeeze(0)
+ return tensor
+
+ def build_dataloader(self):
+ """called by `get_batch_iterator` in fairseqmmtask. """
+ # TODO: hard-code dataloader for retri for now and configurable in .yaml.
+ # reuse the `train.lst`.
+ self.config.dataset.split = "train"
+ meta_processor = ShardedHow2MetaProcessor(self.config.dataset)
+ video_processor = ShardedVideoProcessor(self.config.dataset)
+ text_processor = ShardedTextProcessor(self.config.dataset)
+
+ aligner = VariedLenAligner(self.config.dataset)
+ aligner.subsampling = self.config.dataset.clip_per_video
+
+ self.retri_data = MMDataset(
+ meta_processor, video_processor, text_processor, aligner
+ )
+
+ retri_sampler = DistributedSampler(self.retri_data)
+ infer_scale = 16
+ batch_size = self.config.dataset.num_video_per_batch \
+ * infer_scale
+
+ self.retri_dataloader = DataLoader(
+ self.retri_data,
+ collate_fn=self.retri_data.collater,
+ batch_size=batch_size,
+ shuffle=False,
+ sampler=retri_sampler,
+ num_workers=self.config.fairseq.dataset.num_workers
+ )
+ return self.retri_dataloader
+
+ def retrive_candidates(self, epoch, dataloader=None):
+ if get_local_rank() == 0:
+ print("running retrieval model.")
+ out_dir = os.path.join(
+ self.config.fairseq.checkpoint.save_dir, "retri")
+ os.makedirs(out_dir, exist_ok=True)
+
+ if not os.path.isfile(
+ os.path.join(
+ out_dir, "batched_e" + str(epoch) + "_videos0.pkl")
+ ):
+ if dataloader is None:
+ dataloader = self.retri_dataloader
+
+ self.model.eval()
+ self.model.is_train = False
+
+ assert self.retri_data.meta_processor.data == \
+ self.train_data.meta_processor.data # video_ids not mutated.
+
+ self._retri_predict(epoch, dataloader)
+
+ self.model.train()
+ self.model.is_train = True
+
+ torch.distributed.barrier()
+ output = self._retri_sync(epoch, out_dir)
+ torch.distributed.barrier()
+ self.train_data.meta_processor.set_candidates(output)
+ return output
+
+
+class VideoRetriTask(RetriTask):
+ """RetriTask on video level."""
+
+ def reshape_subsample(self, sample):
+ if (
+ hasattr(self.config.dataset, "clip_per_video")
+ and self.config.dataset.clip_per_video is not None
+ and self.config.dataset.clip_per_video > 1
+ ):
+ for key in sample:
+ if torch.is_tensor(sample[key]):
+ sample[key] = self.flat_subsample(sample[key])
+ return sample
+
+ def flat_subsample(self, tensor):
+ if tensor.size(0) == 1:
+ tensor = tensor.squeeze(0)
+ return Task.flat_subsample(self, tensor)
+
+ def _retri_predict(self, epoch, dataloader):
+ set_seed(epoch)
+        # save vectors for retrieval.
+ predictor = VideoPredictor(self.config)
+ predictor.predict_loop(
+ self.model, dataloader)
+ set_seed(epoch) # get the same text clips.
+        # retrieval.
+ retri_predictor = VideoRetriPredictor(
+ self.config)
+ retri_predictor.predict_loop(
+ self.model, predictor.vecpool.retriver, epoch)
+ del predictor
+ del retri_predictor
+
+ def _retri_sync(self, epoch, out_dir):
+        # every GPU does the same merge.
+ batched_videos = []
+ for local_rank in range(get_world_size()):
+ fn = os.path.join(
+ out_dir,
+ "batched_e" + str(epoch) + "_videos" + str(local_rank) + ".pkl")
+ with open(fn, "rb") as fr:
+ batched_videos.extend(pickle.load(fr))
+ print(
+ "[INFO] batched_videos",
+ len(batched_videos), len(batched_videos[0]))
+ return batched_videos
+
+
+class VideoPredictor(Predictor):
+ def __init__(self, config):
+ vectorpool_cls = getattr(vectorpool, config.vectorpool_cls)
+ self.vecpool = vectorpool_cls(config)
+
+ def predict_loop(
+ self,
+ model,
+ dataloader,
+ early_stop=-1,
+ ):
+ with torch.no_grad():
+ if get_local_rank() == 0:
+ dataloader = tqdm(dataloader)
+ for batch_idx, batch in enumerate(dataloader):
+ if batch_idx == early_stop:
+ break
+ self(batch, model)
+ return self.finalize()
+
+ def __call__(self, sample, model, **kwargs):
+ param = next(model.parameters())
+ dtype = param.dtype
+ device = param.device
+ subsample = sample["vfeats"].size(1)
+ sample = self.to_ctx(sample, device, dtype)
+ for key in sample:
+ if torch.is_tensor(sample[key]):
+ size = sample[key].size()
+ if len(size) >= 2:
+ batch_size = size[0] * size[1]
+ expanded_size = (
+ (batch_size,) + size[2:] if len(size) > 2
+ else (batch_size,)
+ )
+ sample[key] = sample[key].view(expanded_size)
+
+ outputs = model(**sample)
+ sample.update(outputs)
+ self.vecpool(sample, subsample)
+
+ def finalize(self):
+ print("[INFO]", self.vecpool)
+ if not self.vecpool.retriver.db.is_trained:
+ self.vecpool.retriver.finalize_training()
+ return self.vecpool.retriver
+
+
+class VideoRetriPredictor(Predictor):
+ """
+ Online Retrieval Predictor for Clips (used by RetriTask).
+ TODO: merge this with VisPredictor?
+ """
+
+ def __init__(self, config):
+ self.pred_dir = os.path.join(
+ config.fairseq.checkpoint.save_dir,
+ "retri")
+ self.num_cands = config.num_cands
+ self.num_video_per_batch = config.dataset.num_video_per_batch
+
+ def predict_loop(
+ self,
+ model,
+ retriver,
+ epoch,
+ early_stop=-1
+ ):
+        # a fake loop that only tries to recover video vectors
+        # from video_ids.
+ batched_videos = []
+ # obtain available video_ids.
+ video_ids = list(retriver.videoid_to_vectoridx.keys())
+
+ dataloader = random.sample(
+ video_ids,
+ len(video_ids) // self.num_video_per_batch
+ )
+
+ if get_local_rank() == 0:
+ dataloader = tqdm(dataloader)
+ for batch_idx, batch in enumerate(dataloader):
+ # batch is one video id.
+ if batch_idx == early_stop:
+ break
+ video_ids = retriver.search_by_video_ids(
+ [batch], self.num_cands)[0]
+ if len(video_ids) > self.num_video_per_batch:
+                # we moved the center to make the cluster robust.
+ video_ids = random.sample(video_ids, self.num_video_per_batch)
+ batched_videos.append(video_ids)
+ return self.finalize(batched_videos, epoch)
+
+ def finalize(self, batched_videos, epoch):
+ fn = os.path.join(
+ self.pred_dir,
+ "batched_e" + str(epoch) + "_videos" + str(get_local_rank()) + ".pkl")
+ with open(fn, "wb") as fw:
+ pickle.dump(batched_videos, fw, pickle.HIGHEST_PROTOCOL)
+ return batched_videos
diff --git a/examples/MMPT/mmpt/tasks/task.py b/examples/MMPT/mmpt/tasks/task.py
new file mode 100644
index 0000000000..8bb50f24df
--- /dev/null
+++ b/examples/MMPT/mmpt/tasks/task.py
@@ -0,0 +1,184 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from .. import tasks
+from .. import models
+from .. import losses
+from ..datasets import MMDataset
+from .. import processors
+
+
+class Task(object):
+ """
+ A task refers to one generic training task (e.g., training one model).
+ """
+
+ @classmethod
+ def config_task(cls, config):
+ """
+        Determine whether to load a hard-coded task or configure a generic one,
+        depending on whether a task string is available in the config.
+ """
+ if config.task is not None:
+ # TODO (huxu): expand the search scope.
+ task_cls = getattr(tasks, config.task)
+ return task_cls(config)
+ else:
+ return Task(config)
+
+ def __init__(self, config):
+ self.config = config
+ self.train_data = None
+ self.val_data = None
+ self.test_data = None
+
+ self.model = None
+ self.loss_fn = None
+ self.eval_fn = None
+
+ def build_dataset(self):
+ """TODO (huxu): move processor breakdown to MMDataset."""
+ """fill-in `self.train_data`, `self.val_data` and `self.test_data`."""
+
+ meta_processor_cls = getattr(
+ processors, self.config.dataset.meta_processor)
+ video_processor_cls = getattr(
+ processors, self.config.dataset.video_processor)
+ text_processor_cls = getattr(
+ processors, self.config.dataset.text_processor)
+ aligner_cls = getattr(
+ processors, self.config.dataset.aligner)
+
+ if self.config.dataset.train_path is not None:
+ self.config.dataset.split = "train"
+ # may be used by meta processor.
+ # meta_processor controls different dataset.
+ meta_processor = meta_processor_cls(self.config.dataset)
+ video_processor = video_processor_cls(self.config.dataset)
+ text_processor = text_processor_cls(self.config.dataset)
+ aligner = aligner_cls(self.config.dataset)
+ self.train_data = MMDataset(
+ meta_processor, video_processor, text_processor, aligner
+ )
+ print("train_len", len(self.train_data))
+ output = self.train_data[0]
+ self.train_data.print_example(output)
+ if self.config.dataset.val_path is not None:
+ self.config.dataset.split = "valid"
+ # may be used by meta processor.
+ meta_processor = meta_processor_cls(self.config.dataset)
+ video_processor = video_processor_cls(self.config.dataset)
+ text_processor = text_processor_cls(self.config.dataset)
+ aligner = aligner_cls(self.config.dataset)
+ self.val_data = MMDataset(
+ meta_processor, video_processor, text_processor, aligner
+ )
+ print("val_len", len(self.val_data))
+ output = self.val_data[0]
+ self.val_data.print_example(output)
+
+ if self.config.dataset.split == "test":
+            # the following is run via launching fairseq-validate.
+ meta_processor = meta_processor_cls(self.config.dataset)
+ video_processor = video_processor_cls(self.config.dataset)
+ text_processor = text_processor_cls(self.config.dataset)
+
+ self.test_data = MMDataset(
+ meta_processor, video_processor, text_processor, aligner
+ )
+ print("test_len", len(self.test_data))
+ output = self.test_data[0]
+ self.test_data.print_example(output)
+
+ def build_model(self, checkpoint=None):
+ if self.model is None:
+ model_cls = getattr(models, self.config.model.model_cls)
+ self.model = model_cls(self.config)
+ if checkpoint is not None:
+ self.load_checkpoint(checkpoint)
+ return self.model
+
+ def load_checkpoint(self, checkpoint):
+ if self.model is None:
+ raise ValueError("model is not initialized.")
+ state_dict = torch.load(checkpoint)
+ state_dict = self._trim_state_dict(state_dict)
+ self.model.load_state_dict(state_dict, strict=False)
+ # if it's a fp16 model, turn it back.
+ if next(self.model.parameters()).dtype == torch.float16:
+ self.model = self.model.float()
+ return self.model
+
+ def _trim_state_dict(self, state_dict):
+ from collections import OrderedDict
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ if "model" in state_dict: # fairseq checkpoint format.
+ state_dict = state_dict["model"]
+ ret_state_dict = OrderedDict()
+ for (
+ key,
+ value,
+ ) in state_dict.items():
+ # remove fairseq wrapper since this is a task.
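+            # e.g., a wrapped key like "mmmodel.encoder.layer.0.weight"
+            # (hypothetical name) becomes "encoder.layer.0.weight" so the
+            # unwrapped task-level model can load it directly.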
+ if key.startswith("mmmodel"):
+ key = key[len("mmmodel."):]
+ ret_state_dict[key] = value
+ return ret_state_dict
+
+ def build_loss(self):
+ if self.loss_fn is None and self.config.loss is not None:
+ loss_cls = getattr(losses, self.config.loss.loss_cls)
+ self.loss_fn = loss_cls()
+ return self.loss_fn
+
+ def flat_subsample(self, tensor):
+ size = tensor.size()
+ if len(size) >= 2:
+ batch_size = size[0] * size[1]
+ expanded_size = (
+ (batch_size,) + size[2:] if len(size) > 2
+ else (batch_size,)
+ )
+ tensor = tensor.view(expanded_size)
+ return tensor
+
+ def reshape_subsample(self, sample):
+ if (
+ hasattr(self.config.dataset, "subsampling")
+ and self.config.dataset.subsampling is not None
+ and self.config.dataset.subsampling > 1
+ ):
+ for key in sample:
+ if torch.is_tensor(sample[key]):
+ sample[key] = self.flat_subsample(sample[key])
+ return sample
+
+ def __call__(self, model, sample):
+ loss = None
+ loss_scalar = float("inf")
+
+ sample = self.reshape_subsample(sample)
+ outputs = self.model(**sample)
+ sample.update(outputs)
+ if self.loss_fn is not None:
+ loss = self.loss_fn(**sample)
+ loss_scalar = loss.item()
+
+ batch_size = sample["caps"].size(0)
+ sample_size = 1
+ return {
+ "loss": loss,
+ "loss_scalar": loss_scalar,
+ "max_len": self.config.dataset.max_len,
+ "batch_size": batch_size,
+ "sample_size": sample_size,
+ }
+
+ def build_dataloader(self):
+ """only used for trainer that lacks building loaders."""
+ raise NotImplementedError
diff --git a/examples/MMPT/mmpt/tasks/vlmtask.py b/examples/MMPT/mmpt/tasks/vlmtask.py
new file mode 100644
index 0000000000..57dc4c9170
--- /dev/null
+++ b/examples/MMPT/mmpt/tasks/vlmtask.py
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from .task import Task
+
+
+class VLMTask(Task):
+ """A VLM task for reproducibility.
+    The collator splits subsamples into two sub-batches.
+    This should introduce no logic changes,
+    but it changes the randomness in frame masking.
+ """
+
+ def flat_subsample(self, tensor):
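+        # e.g., a (batch, subsample, ...) tensor is regrouped into
+        # (batch * subsample // 2, 2, ...) and its two halves are then
+        # concatenated along dim 0, producing the two sub-batches mentioned
+        # in the class docstring.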
+ size = tensor.size()
+ if len(size) >= 2:
+ batch_size = size[0] * (size[1] // 2)
+ expanded_size = (
+ (batch_size, 2) + size[2:] if len(size) > 2
+ else (batch_size, 2)
+ )
+ tensor = tensor.view(expanded_size)
+ tensor = torch.cat([tensor[:, 0], tensor[:, 1]], dim=0)
+ return tensor
diff --git a/examples/MMPT/mmpt/utils/__init__.py b/examples/MMPT/mmpt/utils/__init__.py
new file mode 100644
index 0000000000..2429ee3757
--- /dev/null
+++ b/examples/MMPT/mmpt/utils/__init__.py
@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import random
+import numpy as np
+import torch
+
+from .shardedtensor import *
+from .load_config import *
+
+
+def set_seed(seed=43211):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if torch.backends.cudnn.enabled:
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+
+
+def get_world_size():
+ if torch.distributed.is_initialized():
+ world_size = torch.distributed.get_world_size()
+ else:
+ world_size = 1
+ return world_size
+
+
+def get_local_rank():
+ return torch.distributed.get_rank() \
+ if torch.distributed.is_initialized() else 0
+
+
+def print_on_rank0(func):
+ local_rank = get_local_rank()
+ if local_rank == 0:
+ print("[INFO]", func)
+
+
+class RetriMeter(object):
+ """
+ Statistics on whether retrieval yields a better pair.
+ """
+ def __init__(self, freq=1024):
+ self.freq = freq
+ self.total = 0
+ self.replace = 0
+ self.updates = 0
+
+ def __call__(self, data):
+ if isinstance(data, np.ndarray):
+ self.replace += data.shape[0] - int((data[:, 0] == -1).sum())
+ self.total += data.shape[0]
+ elif torch.is_tensor(data):
+ self.replace += int(data.sum())
+ self.total += data.size(0)
+ else:
+ raise ValueError("unsupported RetriMeter data type.", type(data))
+
+ self.updates += 1
+ if get_local_rank() == 0 and self.updates % self.freq == 0:
+ print("[INFO]", self)
+
+ def __repr__(self):
+ return "RetriMeter (" + str(self.replace / self.total) \
+ + "/" + str(self.replace) + "/" + str(self.total) + ")"
diff --git a/examples/MMPT/mmpt/utils/load_config.py b/examples/MMPT/mmpt/utils/load_config.py
new file mode 100644
index 0000000000..ede4f94117
--- /dev/null
+++ b/examples/MMPT/mmpt/utils/load_config.py
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import omegaconf
+from omegaconf import OmegaConf
+
+
+def load_config(args=None, config_file=None, overwrite_fairseq=False):
+ """TODO (huxu): move fairseq overwrite to another function."""
+ if args is not None:
+ config_file = args.taskconfig
+ config = recursive_config(config_file)
+
+ if config.dataset.subsampling is not None:
+ batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling
+ print(
+ "adjusting batch_size to {} due to subsampling {}.".format(
+ batch_size, config.dataset.subsampling
+ )
+ )
+ config.fairseq.dataset.batch_size = batch_size
+
+ is_test = config.dataset.split is not None and config.dataset.split == "test"
+ if not is_test:
+ if (
+ config.fairseq.checkpoint is None
+ or config.fairseq.checkpoint.save_dir is None
+ ):
+ raise ValueError("fairseq save_dir or save_path must be specified.")
+
+ save_dir = config.fairseq.checkpoint.save_dir
+ os.makedirs(save_dir, exist_ok=True)
+ if config.fairseq.common.tensorboard_logdir is not None:
+ tb_run_dir = suffix_rundir(
+ save_dir, config.fairseq.common.tensorboard_logdir
+ )
+ config.fairseq.common.tensorboard_logdir = tb_run_dir
+ print(
+ "update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir
+ )
+ os.makedirs(save_dir, exist_ok=True)
+ OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))
+
+ if overwrite_fairseq and config.fairseq is not None and args is not None:
+ # flatten fields.
+ for group in config.fairseq:
+ for field in config.fairseq[group]:
+ print("overwrite args." + field, "as", config.fairseq[group][field])
+ setattr(args, field, config.fairseq[group][field])
+ return config
+
+
+def recursive_config(config_path):
+ """allows for stacking of configs in any depth."""
+ config = OmegaConf.load(config_path)
+ if config.includes is not None:
+ includes = config.includes
+ config.pop("includes")
+ base_config = recursive_config(includes)
+ config = OmegaConf.merge(base_config, config)
+ return config
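+
+
+# Illustration only (not part of the API): `recursive_config` lets one config
+# extend another through an `includes:` key. For example,
+# projects/mtm/mmfusionmtm.yaml starts with `includes: projects/mfmmlm.yaml`,
+# and OmegaConf.merge overlays its fields (e.g. `loss_cls: MTM`) on top of the
+# base config's (`loss_cls: MFMMLM`):
+#
+#     config = recursive_config("projects/mtm/mmfusionmtm.yaml")
+#     config.task_group.pretrain.loss.loss_cls  # -> "MTM"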
+
+
+def suffix_rundir(save_dir, run_dir):
+ max_id = -1
+ for search_dir in os.listdir(save_dir):
+ if search_dir.startswith(run_dir):
+ splits = search_dir.split("_")
+ cur_id = int(splits[1]) if len(splits) > 1 else 0
+ max_id = max(max_id, cur_id)
+ return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
+
+
+def overwrite_dir(config, replace, basedir):
+ for key in config:
+ if isinstance(config[key], str) and config[key].startswith(basedir):
+ config[key] = config[key].replace(basedir, replace)
+ if isinstance(config[key], omegaconf.dictconfig.DictConfig):
+ overwrite_dir(config[key], replace, basedir)
diff --git a/examples/MMPT/mmpt/utils/shardedtensor.py b/examples/MMPT/mmpt/utils/shardedtensor.py
new file mode 100644
index 0000000000..2424f360ef
--- /dev/null
+++ b/examples/MMPT/mmpt/utils/shardedtensor.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import pickle
+import numpy as np
+
+
+class ShardedTensor(object):
+ def __init__(self, data, starts):
+ self.data = data
+ self.starts = starts
+ assert self.starts[0] == 0
+ assert self.starts[-1] == len(self.data)
+ assert (self.starts[1:] >= self.starts[:-1]).all()
+ assert (self.starts > -1).all()
+
+ @staticmethod
+ def from_list(xs):
+        starts = np.full((len(xs) + 1,), -1, dtype=np.int64)
+ data = np.concatenate(xs, axis=0)
+ starts[0] = 0
+ for i, x in enumerate(xs):
+ starts[i + 1] = starts[i] + x.shape[0]
+ assert (starts > -1).all()
+ return ShardedTensor(data, starts)
+
+ def __getitem__(self, i):
+ return self.data[self.starts[i] : self.starts[i + 1]]
+
+ def __len__(self):
+ return len(self.starts) - 1
+
+ def lengths(self):
+ return self.starts[1:] - self.starts[:-1]
+
+ def save(self, path):
+ np.save(path + "_starts", self.starts)
+ np.save(path + "_data", self.data)
+
+ @staticmethod
+ def load(path, mmap_mode=None):
+ starts = np.load(path + "_starts.npy", mmap_mode)
+ data = np.load(path + "_data.npy", mmap_mode)
+ return ShardedTensor(data, starts)
diff --git a/examples/MMPT/mmpt_cli/localjob.py b/examples/MMPT/mmpt_cli/localjob.py
new file mode 100644
index 0000000000..2675d3511a
--- /dev/null
+++ b/examples/MMPT/mmpt_cli/localjob.py
@@ -0,0 +1,117 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+
+from mmpt.utils import recursive_config
+
+
+class BaseJob(object):
+ def __init__(self, yaml_file, dryrun=False):
+ self.yaml_file = yaml_file
+ self.config = recursive_config(yaml_file)
+ self.dryrun = dryrun
+
+ def submit(self, **kwargs):
+ raise NotImplementedError
+
+ def _normalize_cmd(self, cmd_list):
+ cmd_list = list(cmd_list)
+ yaml_index = cmd_list.index("[yaml]")
+ cmd_list[yaml_index] = self.yaml_file
+ return cmd_list
+
+
+class LocalJob(BaseJob):
+
+ CMD_CONFIG = {
+ "local_single": [
+ "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+ "--task", "mmtask", "--arch", "mmarch",
+ "--criterion", "mmloss",
+ ],
+ "local_small": [
+ "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+ "--task", "mmtask", "--arch", "mmarch",
+ "--criterion", "mmloss",
+ "--distributed-world-size", "2"
+ ],
+ "local_big": [
+ "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+ "--task", "mmtask", "--arch", "mmarch",
+ "--criterion", "mmloss",
+ "--distributed-world-size", "8"
+ ],
+ "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
+ }
+
+ def __init__(self, yaml_file, job_type=None, dryrun=False):
+ super().__init__(yaml_file, dryrun)
+ if job_type is None:
+ self.job_type = "local_single"
+ if self.config.task_type is not None:
+ self.job_type = self.config.task_type
+ else:
+ self.job_type = job_type
+ if self.job_type in ["local_single", "local_small"]:
+ if self.config.fairseq.dataset.batch_size > 32:
+ print("decreasing batch_size to 32 for local testing?")
+
+ def submit(self):
+ cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
+ if "predict" not in self.job_type:
+ # append fairseq args.
+ from mmpt.utils import load_config
+
+ config = load_config(config_file=self.yaml_file)
+ for field in config.fairseq:
+ for key in config.fairseq[field]:
+ if key in ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]: # a list of binary flag.
+ param = ["--" + key.replace("_", "-")]
+ else:
+ if key == "lr":
+ value = str(config.fairseq[field][key][0])
+ elif key == "adam_betas":
+ value = "'"+str(config.fairseq[field][key])+"'"
+ else:
+ value = str(config.fairseq[field][key])
+ param = [
+ "--" + key.replace("_", "-"),
+ value
+ ]
+ cmd_list.extend(param)
+
+ print("launching", " ".join(cmd_list))
+ if not self.dryrun:
+ os.system(" ".join(cmd_list))
+ return JobStatus("12345678")
+
+
+class JobStatus(object):
+ def __init__(self, job_id):
+ self.job_id = job_id
+
+ def __repr__(self):
+ return self.job_id
+
+ def __str__(self):
+ return self.job_id
+
+ def done(self):
+ return False
+
+ def running(self):
+ return False
+
+ def result(self):
+ if self.done():
+ return "{} is done.".format(self.job_id)
+ else:
+ return "{} is running.".format(self.job_id)
+
+ def stderr(self):
+ return self.result()
+
+ def stdout(self):
+ return self.result()
diff --git a/examples/MMPT/mmpt_cli/predict.py b/examples/MMPT/mmpt_cli/predict.py
new file mode 100644
index 0000000000..4071e196d2
--- /dev/null
+++ b/examples/MMPT/mmpt_cli/predict.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import glob
+import argparse
+import pprint
+import omegaconf
+
+from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
+
+from mmpt.utils import load_config, set_seed
+from mmpt.evaluators import Evaluator
+from mmpt.evaluators import predictor as predictor_path
+from mmpt.tasks import Task
+from mmpt import processors
+from mmpt.datasets import MMDataset
+
+
+def get_dataloader(config):
+ meta_processor_cls = getattr(processors, config.dataset.meta_processor)
+ video_processor_cls = getattr(processors, config.dataset.video_processor)
+ text_processor_cls = getattr(processors, config.dataset.text_processor)
+ aligner_cls = getattr(processors, config.dataset.aligner)
+
+ meta_processor = meta_processor_cls(config.dataset)
+ video_processor = video_processor_cls(config.dataset)
+ text_processor = text_processor_cls(config.dataset)
+ aligner = aligner_cls(config.dataset)
+
+ test_data = MMDataset(
+ meta_processor,
+ video_processor,
+ text_processor,
+ aligner,
+ )
+ print("test_len", len(test_data))
+ output = test_data[0]
+ test_data.print_example(output)
+
+ test_dataloader = DataLoader(
+ test_data,
+ batch_size=config.fairseq.dataset.batch_size,
+ shuffle=False,
+ num_workers=6,
+ collate_fn=test_data.collater,
+ )
+ return test_dataloader
+
+
+def main(args):
+ config = load_config(args)
+
+ if isinstance(config, omegaconf.dictconfig.DictConfig):
+ print(OmegaConf.to_yaml(config))
+ else:
+ pp = pprint.PrettyPrinter(indent=4)
+        pp.pprint(config)
+
+ mmtask = Task.config_task(config)
+ mmtask.build_model()
+
+ test_dataloader = get_dataloader(config)
+ checkpoint_search_path = os.path.dirname(config.eval.save_path)
+ results = []
+
+ prefix = os.path.basename(args.taskconfig)
+ if prefix.startswith("test"):
+        # loop over all checkpoints for datasets without a validation set.
+ if "best" not in config.fairseq.common_eval.path:
+ print("eval each epoch.")
+ for checkpoint in glob.glob(checkpoint_search_path + "/checkpoint*"):
+ model = mmtask.load_checkpoint(checkpoint)
+ ckpt = os.path.basename(checkpoint)
+ evaluator = Evaluator(config)
+ output = evaluator.evaluate(
+ model, test_dataloader, ckpt + "_merged")
+ results.append((checkpoint, output))
+        # finally, use the checkpoint specified by the config.
+ model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
+ evaluator = Evaluator(config)
+ output = evaluator.evaluate(model, test_dataloader)
+ results.append((config.fairseq.common_eval.path, output))
+
+ best_result = None
+ best_metric = 0.
+ for checkpoint, result in results:
+ print(checkpoint)
+ evaluator.metric.print_computed_metrics(result)
+ best_score = evaluator.metric.best_metric(result)
+ if best_score > best_metric:
+ best_result = (checkpoint, result)
+ best_metric = best_score
+ print("best results:")
+ print(best_result[0])
+ evaluator.metric.print_computed_metrics(best_result[1])
+
+ elif prefix.startswith("vis"):
+ model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
+ predictor_cls = getattr(predictor_path, config.predictor)
+ predictor = predictor_cls(config)
+ predictor.predict_loop(model, test_dataloader, mmtask, None)
+ else:
+ raise ValueError("unknown prefix of the config file", args.taskconfig)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("taskconfig", type=str)
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/MMPT/pretraining.md b/examples/MMPT/pretraining.md
new file mode 100644
index 0000000000..8f8e6d0fac
--- /dev/null
+++ b/examples/MMPT/pretraining.md
@@ -0,0 +1,29 @@
+# Pretraining
+
+(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
+We mostly use the [howto100M](https://github.com/antoine77340/howto100m) dataset for pretraining (other datasets are coming), so you are unlikely to need a new `MetaProcessor`, `VideoProcessor` or `TextProcessor`; most of the work goes into a new `Aligner`, a new model and a new loss.
+
+### Data Sharding
+Pretraining on Howto100M is heavy on IO since millions of videos and captions live on the hard disk and cannot fit into memory.
+It is desirable to have an optimized preprocessing step before the actual dataloading.
+
+We support data sharding to pack multiple videos and their captions into shards of training data (see [dataset](DATASET.md) for preprocessing).
+These shards are memory-mapped to reduce the frequency of IO access to millions of files (see the processors starting with `Sharded*`).
+This is the default config for the how2 dataset, `projects/task/how2.yaml`; a minimal sketch of building and loading a shard follows.
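+
+Below is a minimal sketch (not the official preprocessing script) of building a shard with `mmpt.utils.ShardedTensor` and memory-mapping it at load time; the feature shapes and the `data/feat/demo_shard` path are made up for illustration:
+
+```python
+import numpy as np
+from mmpt.utils import ShardedTensor
+
+# pack per-video feature arrays into a single shard; offsets go into `starts`.
+feats = [np.random.rand(n, 512).astype(np.float32) for n in (4, 7, 3)]
+shard = ShardedTensor.from_list(feats)
+shard.save("data/feat/demo_shard")  # writes *_starts.npy and *_data.npy
+
+# at training time, memory-map the shard instead of touching millions of files.
+loaded = ShardedTensor.load("data/feat/demo_shard", mmap_mode="r")
+print(len(loaded), loaded[1].shape)  # 3 videos; the second one is (7, 512)
+```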
+
+Many thanks to Dmytro Okhonko for sharing the code from the MARGE project.
+
+### Training
+Pretraining on Howto100M is expected to run on one or more nodes, each with 8 GPUs and 32 GB of memory per GPU.
+Launching MFM+MLM pretraining can be done via:
+```bash
+python locallaunch.py projects/mfmmlm/how2.yaml
+```
+
+### Pre-training with a Retrieval Model (VideoCLIP)
+This project now supports alternating between running a retrieval model and pre-training.
+We implement a basic retrieval model built on video hidden states and faiss.
+
+You may need to install faiss via `conda install faiss-cpu -c pytorch`.
+
+Right now, the hidden state of a video is computed as the average of the pooled visual/text hidden states of 8 clips.
+See `mmpt/tasks/retritask.py` for more details.
+The `.yaml` config for running pre-training with a retrieval model can be found at `projects/retri/videoretri.yaml`.
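+
+As a rough illustration (this is not the actual `VectorRetriever` implementation; the sizes below are made up), the retrieval step amounts to averaging clip-level hidden states into one vector per video, indexing those vectors with faiss, and querying nearest-neighbor videos to form retrieval-augmented batches:
+
+```python
+import faiss
+import numpy as np
+
+num_videos, num_clips, hidden = 1000, 8, 768  # made-up sizes
+clip_states = np.random.rand(num_videos, num_clips, hidden).astype(np.float32)
+
+video_vecs = clip_states.mean(axis=1)  # one pooled vector per video
+faiss.normalize_L2(video_vecs)         # cosine similarity via inner product
+
+index = faiss.IndexFlatIP(hidden)
+index.add(video_vecs)
+
+# 64 candidates per query video, cf. `num_cands: 64` in projects/retri/videoclip/how2.yaml.
+_, neighbors = index.search(video_vecs[:32], 64)
+print(neighbors.shape)  # (32, 64)
+```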
diff --git a/examples/MMPT/projects/mfmmlm.yaml b/examples/MMPT/projects/mfmmlm.yaml
new file mode 100644
index 0000000000..0f3450a1e0
--- /dev/null
+++ b/examples/MMPT/projects/mfmmlm.yaml
@@ -0,0 +1,59 @@
+project_dir: mfmmlm
+run_task:
+ - how2.yaml
+ - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml]
+base_dir: task
+task_group:
+ pretrain:
+ task_list:
+ - how2.yaml
+ dataset:
+ subsampling: 32
+ sampled_min_len: 10
+ sampled_max_len: 64
+ max_video_len: 32
+ max_len: 96
+ aligner: MFMMLMAligner
+ lazy_vfeat_mask: True
+ mfm_probability: 0.15
+ mlm_probability: 0.15
+ mm_prob: 0.5
+ model:
+ model_cls: MMFusionMFMMLM
+ mm_encoder_cls: MMFusionForMFMMLM
+ loss:
+ loss_cls: MFMMLM
+ fairseq:
+ common:
+ fp16: true
+ dataset:
+ batch_size: 256
+ optimization:
+ max_epoch: 15
+ finetune:
+ task_list:
+ - vtt.yaml
+ - vttqa.yaml
+ - youcook.yaml
+ - youcookcap.yaml
+ - crosstask.yaml
+ - coin.yaml
+ dataset:
+ max_video_len: 32
+ max_len: 96
+ fairseq:
+ common:
+ fp16: true
+ # do not write any model or loss here (they are expected to be fixed in mmfusion).
+ test:
+ task_list:
+ - test_vtt.yaml
+ - test_vttqa.yaml
+ - test_youcook.yaml
+ - test_youcookcap.yaml
+ - test_crosstask.yaml
+ - test_crosstask_zs.yaml
+ - test_coin.yaml
+ dataset:
+ max_video_len: 32
+ max_len: 96
diff --git a/examples/MMPT/projects/mtm/mmfusionmtm.yaml b/examples/MMPT/projects/mtm/mmfusionmtm.yaml
new file mode 100644
index 0000000000..337d66a2aa
--- /dev/null
+++ b/examples/MMPT/projects/mtm/mmfusionmtm.yaml
@@ -0,0 +1,19 @@
+includes: projects/mfmmlm.yaml
+project_dir: mtm/mmfusionmtm
+task_group:
+ pretrain:
+ task: VLMTask # reproducible
+ dataset:
+ aligner: MFMMLMAligner
+ model:
+ use_seg_emb: True # reproducible
+ model_cls: MMFusionMTM
+ mm_encoder_cls: MMBertForMFMMLM
+ loss:
+ loss_cls: MTM
+ finetune:
+ model:
+ use_seg_emb: True # reproducible
+ test:
+ model:
+ use_seg_emb: True # reproducible
diff --git a/examples/MMPT/projects/mtm/vlm.yaml b/examples/MMPT/projects/mtm/vlm.yaml
new file mode 100644
index 0000000000..022a2623c5
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm.yaml
@@ -0,0 +1,8 @@
+includes: projects/mtm/mmfusionmtm.yaml
+project_dir: mtm/vlm
+task_group:
+ pretrain:
+ dataset:
+ sampled_min_len: 8
+ loss:
+ loss_cls: MTM
diff --git a/examples/MMPT/projects/mtm/vlm/coin.yaml b/examples/MMPT/projects/mtm/vlm/coin.yaml
new file mode 100644
index 0000000000..48fd64a5f4
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/coin.yaml
@@ -0,0 +1,47 @@
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: COINActionSegmentationMetaProcessor
+ train_path: data/coin/COIN.json
+ val_path: data/coin/COIN.json
+ vfeat_dir: data/feat/feat_coin_s3d
+ text_processor: COINActionSegmentationTextProcessor
+ aligner: COINActionSegmentationAligner
+ num_iso_layer: 12
+ sliding_window: 8
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 1
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 8
+ checkpoint:
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/mtm/vlm/coin
+task_type: sweep_big
+model:
+ model_cls: MMFusionActionSegmentation
+ mm_encoder_cls: MMBertForTokenClassification
+ use_seg_emb: true
+loss:
+ loss_cls: CrossEntropy
diff --git a/examples/MMPT/projects/mtm/vlm/crosstask.yaml b/examples/MMPT/projects/mtm/vlm/crosstask.yaml
new file mode 100644
index 0000000000..4e706b549e
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/crosstask.yaml
@@ -0,0 +1,53 @@
+dataset:
+ video_processor: CrossTaskVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: CrossTaskMetaProcessor
+ train_path: data/crosstask/crosstask_release/videos.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ text_processor: CrossTaskTextProcessor
+ aligner: CrossTaskAligner
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 1
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 5
+ checkpoint:
+ restore_file: runs/mtm/vlm/checkpoint11.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/mtm/vlm/crosstask
+task_type: sweep_small
+model:
+ model_cls: MMFusionActionLocalization
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+loss:
+ loss_cls: BCE
diff --git a/examples/MMPT/projects/mtm/vlm/how2.yaml b/examples/MMPT/projects/mtm/vlm/how2.yaml
new file mode 100644
index 0000000000..7ca40ad815
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/how2.yaml
@@ -0,0 +1,55 @@
+dataset:
+ video_processor: ShardedVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: ShardedHow2MetaProcessor
+ train_path: data/how2/how2_s3d_train.lst
+ val_path: data/how2/how2_s3d_val.lst
+ vfeat_dir: data/feat/feat_how2_s3d_shard_small
+ text_processor: ShardedTextProcessor
+ tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
+ aligner: MFMMLMAligner
+ subsampling: 32
+ sampled_min_len: 8
+ sampled_max_len: 64
+ max_video_len: 32
+ max_len: 96
+ lazy_vfeat_mask: true
+ mfm_probability: 0.15
+ mlm_probability: 0.15
+ mm_prob: 0.5
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 256
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 1000
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 15
+ checkpoint:
+ save_dir: runs/mtm/vlm
+ save_interval_updates: 1024
+ keep_interval_updates: 2
+ keep_last_epochs: 30
+task_type: sweep_big
+slurm_config: big
+eval:
+ save_path: runs/mtm/vlm
+model:
+ model_cls: MMFusionMTM
+ mm_encoder_cls: MMBertForMFMMLM
+ use_seg_emb: true
+loss:
+ loss_cls: MTM
+task: VLMTask
diff --git a/examples/MMPT/projects/mtm/vlm/test_coin.yaml b/examples/MMPT/projects/mtm/vlm/test_coin.yaml
new file mode 100644
index 0000000000..8df2e66ad1
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_coin.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: COINActionSegmentationAligner
+ bert_name: bert-base-uncased
+ test_path: data/coin/COIN.json
+ meta_processor: COINActionSegmentationMetaProcessor
+ vfeat_dir: data/feat/feat_coin_s3d
+ text_processor: COINActionSegmentationTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/coin/checkpoint_best.pt
+model:
+ model_cls: MMFusionActionSegmentation
+ mm_encoder_cls: MMBertForTokenClassification
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/coin/eval
+metric: COINActionSegmentationMetric
+predictor: COINPredictor
diff --git a/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml b/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml
new file mode 100644
index 0000000000..d159847875
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml
@@ -0,0 +1,38 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: CrossTaskVideoProcessor
+ aligner: CrossTaskAligner
+ bert_name: bert-base-uncased
+ meta_processor: CrossTaskMetaProcessor
+ test_path: data/crosstask/crosstask_release/videos_val.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ text_processor: CrossTaskTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/crosstask/checkpoint_best.pt
+model:
+ model_cls: MMFusionActionLocalization
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/crosstask/eval
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor
diff --git a/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml b/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml
new file mode 100644
index 0000000000..59833c5540
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml
@@ -0,0 +1,38 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: CrossTaskVideoProcessor
+ aligner: CrossTaskAligner
+ bert_name: bert-base-uncased
+ meta_processor: CrossTaskMetaProcessor
+ test_path: data/crosstask/crosstask_release/videos_val.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ text_processor: CrossTaskTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/checkpoint_best.pt
+model:
+ model_cls: MMFusionActionLocalization
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/crosstask_zs/eval
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor
diff --git a/examples/MMPT/projects/mtm/vlm/test_vtt.yaml b/examples/MMPT/projects/mtm/vlm/test_vtt.yaml
new file mode 100644
index 0000000000..a41557df6a
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_vtt.yaml
@@ -0,0 +1,29 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/vtt/checkpoint_last.pt
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/vtt/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml b/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml
new file mode 100644
index 0000000000..abf3309f70
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml
@@ -0,0 +1,29 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: MSRVTTQAAligner
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTQAMetaProcessor
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTQATextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/vttqa/checkpoint_last.pt
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/vttqa/eval
+metric: QAMetric
+predictor: QAPredictor
diff --git a/examples/MMPT/projects/mtm/vlm/test_youcook.yaml b/examples/MMPT/projects/mtm/vlm/test_youcook.yaml
new file mode 100644
index 0000000000..3a57d25c24
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_youcook.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: YoucookVideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+ meta_processor: YoucookMetaProcessor
+ test_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: true
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: TextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/youcook/checkpoint_last.pt
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/youcook/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml b/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml
new file mode 100644
index 0000000000..b2595d7c3c
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml
@@ -0,0 +1,32 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: YoucookVideoProcessor
+ aligner: DSNLGAligner
+ bert_name: bert-base-uncased
+ meta_processor: YoucookNLGMetaProcessor
+ test_path: data/youcook/val_list.txt
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: NLGTextProcessor
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/mtm/vlm/youcookcap/checkpoint_best.pt
+model:
+ model_cls: MMFusionNLG
+ mm_encoder_cls: MMBertForNLG
+ max_decode_length: 24
+ use_seg_emb: true
+eval:
+ save_path: runs/mtm/vlm/youcookcap/eval
+metric: NLGMetric
+predictor: NLGPredictor
+gen_param:
+ num_beams: 5
diff --git a/examples/MMPT/projects/mtm/vlm/vtt.yaml b/examples/MMPT/projects/mtm/vlm/vtt.yaml
new file mode 100644
index 0000000000..c6c5b1ab40
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/vtt.yaml
@@ -0,0 +1,49 @@
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ train_path: data/msrvtt/MSRVTT_train.csv
+ jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
+ dup: 20
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ json_path: data/msrvtt/MSRVTT_data.json
+ aligner: DSAligner
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 256
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 10
+ checkpoint:
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/mtm/vlm/vtt
+task_type: sweep_small
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+loss:
+ loss_cls: T2VContraLoss
diff --git a/examples/MMPT/projects/mtm/vlm/vttqa.yaml b/examples/MMPT/projects/mtm/vlm/vttqa.yaml
new file mode 100644
index 0000000000..0a440c7dd2
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/vttqa.yaml
@@ -0,0 +1,47 @@
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ train_path: data/msrvtt/MSRVTT_train.csv
+ dup: 20
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ json_path: data/msrvtt/MSRVTT_data.json
+ aligner: DSAligner
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 128
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 5
+ checkpoint:
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/mtm/vlm/vttqa
+task_type: sweep_small
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+loss:
+ loss_cls: V2TContraLoss
diff --git a/examples/MMPT/projects/mtm/vlm/youcook.yaml b/examples/MMPT/projects/mtm/vlm/youcook.yaml
new file mode 100644
index 0000000000..9ee82b81b8
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/youcook.yaml
@@ -0,0 +1,47 @@
+dataset:
+ video_processor: YoucookVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: YoucookMetaProcessor
+ train_path: data/youcook/youcook_train.pkl
+ val_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: true
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: TextProcessor
+ aligner: DSAligner
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 128
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 10
+ checkpoint:
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/mtm/vlm/youcook
+task_type: sweep_small
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+ use_seg_emb: true
+loss:
+ loss_cls: T2VContraLoss
diff --git a/examples/MMPT/projects/mtm/vlm/youcookcap.yaml b/examples/MMPT/projects/mtm/vlm/youcookcap.yaml
new file mode 100644
index 0000000000..d29dfad5cd
--- /dev/null
+++ b/examples/MMPT/projects/mtm/vlm/youcookcap.yaml
@@ -0,0 +1,45 @@
+dataset:
+ video_processor: YoucookVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: YoucookNLGMetaProcessor
+ train_path: data/youcook/train_list.txt
+ val_path: data/youcook/val_list.txt
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: NLGTextProcessor
+ aligner: DSNLGAligner
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 128
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 10
+ checkpoint:
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/mtm/vlm/youcookcap
+task_type: sweep_small
+model:
+ model_cls: MMFusionNLG
+ mm_encoder_cls: MMBertForNLG
+ use_seg_emb: true
+loss:
+ loss_cls: NLGLoss
diff --git a/examples/MMPT/projects/retri/videoclip.yaml b/examples/MMPT/projects/retri/videoclip.yaml
new file mode 100644
index 0000000000..afd040ab05
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip.yaml
@@ -0,0 +1,10 @@
+includes: projects/retri/videoretri.yaml
+project_dir: retri/videoclip
+task_group:
+ pretrain:
+ model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
diff --git a/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml
new file mode 100644
index 0000000000..aaed5e47f6
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml
@@ -0,0 +1,49 @@
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: COINActionSegmentationMetaProcessor
+ train_path: data/coin/COIN.json
+ val_path: data/coin/COIN.json
+ vfeat_dir: data/feat/feat_coin_s3d
+ text_processor: COINActionSegmentationTextProcessor
+ aligner: COINActionSegmentationAligner
+ num_iso_layer: 12
+ sliding_window: 8
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 1
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 8
+ checkpoint:
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/retri/videoclip/coin
+task_type: sweep_big
+model:
+ model_cls: MMFusionSeparateActionSegmentation
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForTokenClassification
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+loss:
+ loss_cls: CrossEntropy
diff --git a/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml
new file mode 100644
index 0000000000..758601e359
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml
@@ -0,0 +1,55 @@
+dataset:
+ video_processor: CrossTaskVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: CrossTaskMetaProcessor
+ train_path: data/crosstask/crosstask_release/videos.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ text_processor: CrossTaskTextProcessor
+ aligner: CrossTaskAligner
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 1
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 5
+ checkpoint:
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/retri/videoclip/crosstask
+task_type: sweep_small
+model:
+ model_cls: MMFusionSeparateActionLocalization
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+loss:
+ loss_cls: BCE
diff --git a/examples/MMPT/projects/retri/videoclip/how2.yaml b/examples/MMPT/projects/retri/videoclip/how2.yaml
new file mode 100644
index 0000000000..b49581e878
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/how2.yaml
@@ -0,0 +1,65 @@
+dataset:
+ video_processor: ShardedVideoRetriVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: ShardedHow2VideoRetriMetaProcessor
+ train_path: data/how2/how2_s3d_train.lst
+ val_path: data/how2/how2_s3d_val.lst
+ vfeat_dir: data/feat/feat_how2_s3d_shard_small
+ text_processor: ShardedVideoRetriTextProcessor
+ tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
+ aligner: VideoRetriOverlappedAligner
+ subsampling: 1
+ sampled_min_len: 8
+ sampled_max_len: 64
+ max_video_len: 32
+ max_len: 96
+ lazy_vfeat_mask: true
+ mfm_probability: 0.15
+ mlm_probability: 0.15
+ mm_prob: 0.5
+ sampled_video_min_len: 3
+ sampled_video_max_len: 32
+ num_video_per_batch: 32
+ clip_per_video: 16
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 1
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 1000
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 25
+ checkpoint:
+ save_dir: runs/retri/videoclip
+ save_interval_updates: 1024
+ keep_interval_updates: 2
+ keep_last_epochs: 30
+task_type: sweep_big
+slurm_config: big
+eval:
+ save_path: runs/retri/videoclip
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+loss:
+ loss_cls: MMContraLoss
+task: VideoRetriTask
+retri_epoch: 1
+vectorpool_cls: VideoVectorPool
+retriever_cls: VectorRetriever
+num_cands: 64
diff --git a/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml
new file mode 100644
index 0000000000..409906203c
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml
@@ -0,0 +1,33 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: COINActionSegmentationAligner
+ bert_name: bert-base-uncased
+ test_path: data/coin/COIN.json
+ meta_processor: COINActionSegmentationMetaProcessor
+ vfeat_dir: data/feat/feat_coin_s3d
+ text_processor: COINActionSegmentationTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/coin/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparateActionSegmentation
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForTokenClassification
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/coin/eval
+metric: COINActionSegmentationMetric
+predictor: COINPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml b/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml
new file mode 100644
index 0000000000..b33739c7b6
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml
@@ -0,0 +1,33 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: COINActionSegmentationAligner
+ bert_name: bert-base-uncased
+ test_path: data/coin/COIN.json
+ meta_processor: COINActionSegmentationMetaProcessor
+ vfeat_dir: data/feat/feat_coin_s3d
+ text_processor: COINActionSegmentationTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/coin_zs/eval
+metric: COINActionSegmentationMetric
+predictor: COINZSPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml
new file mode 100644
index 0000000000..e82f54fbe5
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml
@@ -0,0 +1,40 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: CrossTaskVideoProcessor
+ aligner: CrossTaskAligner
+ bert_name: bert-base-uncased
+ meta_processor: CrossTaskMetaProcessor
+ test_path: data/crosstask/crosstask_release/videos_val.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ text_processor: CrossTaskTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/crosstask/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparateActionLocalization
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/crosstask/eval
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml
new file mode 100644
index 0000000000..6fc357cc1f
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml
@@ -0,0 +1,40 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: CrossTaskVideoProcessor
+ aligner: CrossTaskAligner
+ bert_name: bert-base-uncased
+ meta_processor: CrossTaskMetaProcessor
+ test_path: data/crosstask/crosstask_release/videos_val.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ text_processor: CrossTaskTextProcessor
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 1
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparateActionLocalization
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/crosstask_zs/eval
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml b/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml
new file mode 100644
index 0000000000..8dc716815d
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: DiDeMoAligner
+ bert_name: bert-base-uncased
+ meta_processor: DiDeMoMetaProcessor
+ test_path: data/didemo/test_data.json
+ vfeat_dir: data/feat/feat_didemo_s3d
+ text_processor: DiDeMoTextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/didemo_zs/eval
+metric: DiDeMoMetric
+predictor: DiDeMoPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml
new file mode 100644
index 0000000000..19321ad5f4
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/vtt/checkpoint_last.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/vtt/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml b/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml
new file mode 100644
index 0000000000..d149fa3960
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/vtt_zs/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml
new file mode 100644
index 0000000000..295aeedbb0
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: MSRVTTQAAligner
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTQAMetaProcessor
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTQATextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/vttqa/checkpoint_last.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/vttqa/eval
+metric: QAMetric
+predictor: QAPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml b/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml
new file mode 100644
index 0000000000..7a876c822a
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml
@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: MSRVTTQAAligner
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTQAMetaProcessor
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTQATextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/vttqa_zs/eval
+metric: QAMetric
+predictor: QAPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml
new file mode 100644
index 0000000000..86a4ab203e
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml
@@ -0,0 +1,33 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: YoucookVideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+ meta_processor: YoucookMetaProcessor
+ test_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: true
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: TextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/youcook/checkpoint_last.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/youcook/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml b/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml
new file mode 100644
index 0000000000..fd2941708b
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml
@@ -0,0 +1,33 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: YoucookVideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+ meta_processor: YoucookMetaProcessor
+ test_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: true
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: TextProcessor
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
+ common_eval:
+ path: runs/retri/videoclip/checkpoint_best.pt
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/retri/videoclip/youcook_zs/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml
new file mode 100644
index 0000000000..d8b4079ac2
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml
@@ -0,0 +1,51 @@
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ train_path: data/msrvtt/MSRVTT_train.csv
+ jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
+ dup: 20
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ json_path: data/msrvtt/MSRVTT_data.json
+ aligner: DSAligner
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 224
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 10
+ checkpoint:
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/retri/videoclip/vtt
+task_type: sweep_small
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+loss:
+ loss_cls: T2VContraLoss
diff --git a/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml
new file mode 100644
index 0000000000..f0566d784a
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml
@@ -0,0 +1,49 @@
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: MSRVTTMetaProcessor
+ train_path: data/msrvtt/MSRVTT_train.csv
+ dup: 20
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ json_path: data/msrvtt/MSRVTT_data.json
+ aligner: DSAligner
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 128
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 5
+ checkpoint:
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/retri/videoclip/vttqa
+task_type: sweep_small
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+loss:
+ loss_cls: V2TContraLoss
diff --git a/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml b/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml
new file mode 100644
index 0000000000..c2b13e5519
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml
@@ -0,0 +1,49 @@
+dataset:
+ video_processor: YoucookVideoProcessor
+ bert_name: bert-base-uncased
+ meta_processor: YoucookMetaProcessor
+ train_path: data/youcook/youcook_train.pkl
+ val_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: true
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: TextProcessor
+ aligner: DSAligner
+ num_iso_layer: 12
+ max_video_len: 32
+ max_len: 96
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ fp16: true
+ dataset:
+ num_workers: 4
+ batch_size: 128
+ optimization:
+ lr:
+ - 5.0e-05
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000
+ warmup_updates: 122
+ weight_decay: 0.0
+ ddp_backend: no_c10d
+ max_epoch: 10
+ checkpoint:
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ save_dir: runs/retri/videoclip/youcook
+task_type: sweep_small
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls: null
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+loss:
+ loss_cls: T2VContraLoss
diff --git a/examples/MMPT/projects/retri/videoretri.yaml b/examples/MMPT/projects/retri/videoretri.yaml
new file mode 100644
index 0000000000..969e1fb279
--- /dev/null
+++ b/examples/MMPT/projects/retri/videoretri.yaml
@@ -0,0 +1,51 @@
+includes: projects/mfmmlm.yaml
+project_dir: retri/videoretri
+run_task:
+ - how2.yaml
+task_group:
+ pretrain:
+ task: VideoRetriTask
+ retri_epoch: 1
+ vectorpool_cls: VideoVectorPool
+ retriever_cls: VectorRetriever
+ num_cands: 64
+ dataset:
+ train_path: data/how2/how2_s3d_train.lst
+ meta_processor: ShardedHow2VideoRetriMetaProcessor
+ video_processor: ShardedVideoRetriVideoProcessor
+ text_processor: ShardedVideoRetriTextProcessor
+ aligner: VideoRetriOverlappedAligner
+ sampled_video_min_len: 3
+ sampled_video_max_len: 32
+ sampled_min_len: 8
+ sampled_max_len: 64
+ num_video_per_batch: 32
+ # do not use subsampling as it changes fairseq batch_size.
+ subsampling: 1 # disable subsampling
+ clip_per_video: 16
+ fairseq:
+ dataset:
+ batch_size: 1
+ optimization:
+ max_epoch: 25
+ model:
+ model_cls: MMFusionShare
+ mm_encoder_cls: MMBertForEncoder
+ loss:
+ loss_cls: MMContraLoss
+ finetune:
+ task_list: [vtt_videoclip.yaml, youcook_videoclip.yaml, vttqa_videoclip.yaml, crosstask_videoclip.yaml, coin_videoclip.yaml]
+ test:
+ task_list:
+ - test_youcook_zs.yaml
+ - test_vtt_zs.yaml
+ - test_vttqa_zs.yaml
+ - test_crosstask_zs_videoclip.yaml
+ - test_coin_zs.yaml
+ - test_didemo_zs.yaml
+ - test_youcook_videoclip.yaml
+ - test_vtt_videoclip.yaml
+ - test_vttqa_videoclip.yaml
+ - test_crosstask_videoclip.yaml
+ - test_coin_videoclip.yaml
+
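The `includes` and `task_group` keys above layer pretrain/finetune/test overrides on top of projects/mfmmlm.yaml. A rough sketch of that kind of recursive override against a hand-written, hypothetical base dict; mmpt's real loader (`recursive_config` in mmpt.utils) is the source of truth and may merge differently:

    # Hypothetical illustration of an `includes`-style recursive override.
    import yaml

    def merge(base, override):
        """Recursively overlay `override` onto `base`, returning a new dict."""
        out = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(out.get(key), dict):
                out[key] = merge(out[key], value)
            else:
                out[key] = value
        return out

    # Hypothetical base values standing in for projects/mfmmlm.yaml.
    base = {"dataset": {"aligner": "FixedLenAligner", "subsampling": 1},
            "fairseq": {"dataset": {"batch_size": 256}}}

    with open("projects/retri/videoretri.yaml") as f:
        override = yaml.safe_load(f)["task_group"]["pretrain"]

    merged = merge(base, override)
    print(merged["dataset"]["aligner"])                # VideoRetriOverlappedAligner
    print(merged["fairseq"]["dataset"]["batch_size"])  # 1
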
diff --git a/examples/MMPT/projects/task/coin.yaml b/examples/MMPT/projects/task/coin.yaml
new file mode 100644
index 0000000000..e7772486e1
--- /dev/null
+++ b/examples/MMPT/projects/task/coin.yaml
@@ -0,0 +1,25 @@
+includes: projects/task/ft.yaml
+task_type: sweep_big
+dataset:
+ meta_processor: COINActionSegmentationMetaProcessor
+ train_path: data/coin/COIN.json
+ val_path: data/coin/COIN.json
+ vfeat_dir: data/feat/feat_coin_s3d
+ video_processor: VideoProcessor
+ text_processor: COINActionSegmentationTextProcessor
+ aligner: COINActionSegmentationAligner
+ num_iso_layer: 12
+ sliding_window: 8
+ sliding_window_size: 32
+model:
+ model_cls: MMFusionActionSegmentation
+ mm_encoder_cls: MMBertForTokenClassification
+loss:
+ loss_cls: CrossEntropy
+fairseq:
+ dataset:
+ batch_size: 1
+ optimization:
+ max_epoch: 8
+ checkpoint:
+ save_dir: runs/task/coin
diff --git a/examples/MMPT/projects/task/coin_videoclip.yaml b/examples/MMPT/projects/task/coin_videoclip.yaml
new file mode 100644
index 0000000000..69988bc18a
--- /dev/null
+++ b/examples/MMPT/projects/task/coin_videoclip.yaml
@@ -0,0 +1,7 @@
+includes: projects/task/coin.yaml
+model:
+ model_cls: MMFusionSeparateActionSegmentation
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForTokenClassification
+ text_encoder_cls: BertModel # dummy, not used.
+ num_hidden_video_layers: 6
diff --git a/examples/MMPT/projects/task/crosstask.yaml b/examples/MMPT/projects/task/crosstask.yaml
new file mode 100644
index 0000000000..cb4dbb0cb4
--- /dev/null
+++ b/examples/MMPT/projects/task/crosstask.yaml
@@ -0,0 +1,31 @@
+includes: projects/task/ft.yaml
+dataset:
+ meta_processor: CrossTaskMetaProcessor
+ train_path: data/crosstask/crosstask_release/videos.csv # dummy
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv # dummy
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ video_processor: CrossTaskVideoProcessor
+ text_processor: CrossTaskTextProcessor
+ aligner: CrossTaskAligner
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+model:
+ model_cls: MMFusionActionLocalization
+ mm_encoder_cls: MMBertForJoint
+loss:
+ loss_cls: BCE
+fairseq:
+ dataset:
+ batch_size: 1
+ optimization:
+ max_epoch: 5
+ checkpoint:
+ save_dir: runs/task/crosstask
+ restore_file: runs/task/checkpoint11.pt # for VLM
diff --git a/examples/MMPT/projects/task/crosstask_videoclip.yaml b/examples/MMPT/projects/task/crosstask_videoclip.yaml
new file mode 100644
index 0000000000..6ec613c07f
--- /dev/null
+++ b/examples/MMPT/projects/task/crosstask_videoclip.yaml
@@ -0,0 +1,10 @@
+includes: projects/task/crosstask.yaml
+model:
+ model_cls: MMFusionSeparateActionLocalization
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel # dummy, not used.
+ num_hidden_video_layers: 6
+fairseq:
+ checkpoint:
+ restore_file: runs/task/checkpoint_best.pt # overrides the VLM default from crosstask.yaml.
diff --git a/examples/MMPT/projects/task/default.yaml b/examples/MMPT/projects/task/default.yaml
new file mode 100644
index 0000000000..087fef71a4
--- /dev/null
+++ b/examples/MMPT/projects/task/default.yaml
@@ -0,0 +1,20 @@
+# this yaml cannot be run alone; use `how2.yaml`, `vtt.yaml`, etc. for training.
+dataset:
+ video_processor: VideoProcessor
+ bert_name: bert-base-uncased
+fairseq:
+ common:
+ tensorboard_logdir: run
+ log_interval: 1000
+ dataset:
+ num_workers: 4
+ optimization:
+ lr: [ 0.00005 ]
+ clip_norm: 2.0
+ optimizer: adam
+ adam_betas: (0.9, 0.98)
+ lr_scheduler: polynomial_decay
+ total_num_update: 1000000 # backward compatible on fairseq 1.0.0a0+af0389f for reproducibility.
+ warmup_updates: 1000
+ weight_decay: 0.0
+ ddp_backend: no_c10d
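
The optimization block above pairs a linear warmup over `warmup_updates` with fairseq's `polynomial_decay` scheduler stretched out to `total_num_update`. A rough sketch of the resulting learning-rate shape, assuming decay power 1.0 and an end learning rate of 0; check the scheduler implementation for the exact formula:

    # Rough sketch of linear warmup followed by polynomial (here linear) decay.
    # Assumes power=1.0 and end_lr=0.0, which may differ from fairseq's defaults.
    def lr_at(step, peak_lr=5e-5, warmup_updates=1000,
              total_num_update=1_000_000, power=1.0, end_lr=0.0):
        if step < warmup_updates:
            return peak_lr * step / max(1, warmup_updates)
        remaining = max(0.0, 1.0 - (step - warmup_updates)
                        / max(1, total_num_update - warmup_updates))
        return end_lr + (peak_lr - end_lr) * remaining ** power

    for step in (0, 500, 1000, 500_000, 1_000_000):
        print(step, lr_at(step))
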
diff --git a/examples/MMPT/projects/task/ft.yaml b/examples/MMPT/projects/task/ft.yaml
new file mode 100644
index 0000000000..c93b8a73ea
--- /dev/null
+++ b/examples/MMPT/projects/task/ft.yaml
@@ -0,0 +1,13 @@
+includes: projects/task/default.yaml
+# all derived config will be run by fairseq-train.
+task_type: sweep_small
+fairseq:
+ optimization:
+ warmup_updates: 122 # copied from the RoBERTa GLUE recipe: https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.glue.md
+ checkpoint:
+ # save_interval_updates: 512
+ # borrowed from Roberta script.
+ restore_file: runs/task/checkpoint_best.pt
+ reset_optimizer: True
+ reset_dataloader: True
+ reset_meters: True
diff --git a/examples/MMPT/projects/task/how2.yaml b/examples/MMPT/projects/task/how2.yaml
new file mode 100644
index 0000000000..094dd04bfc
--- /dev/null
+++ b/examples/MMPT/projects/task/how2.yaml
@@ -0,0 +1,22 @@
+includes: projects/task/default.yaml
+task_type: sweep_big
+slurm_config: big
+dataset:
+ meta_processor: ShardedHow2MetaProcessor
+ train_path: data/how2/how2_s3d_train.lst
+ val_path: data/how2/how2_s3d_val.lst
+ video_processor: ShardedVideoProcessor
+ vfeat_dir: data/feat/feat_how2_s3d_shard_small
+ text_processor: ShardedTextProcessor
+ tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
+ aligner: FixedLenAligner
+# disable direct running of this yaml
+eval:
+ save_path: runs/task
+fairseq:
+ checkpoint:
+ save_dir: runs/task
+ save_interval_updates: 1024
+ keep_interval_updates: 2
+ keep_last_epochs: 30
+
diff --git a/examples/MMPT/projects/task/test.yaml b/examples/MMPT/projects/task/test.yaml
new file mode 100644
index 0000000000..0a98445241
--- /dev/null
+++ b/examples/MMPT/projects/task/test.yaml
@@ -0,0 +1,13 @@
+# this yaml cannot be run alone: implement a test_${dataset}.yaml
+slurm_config: big
+task_type: local_predict
+dataset:
+ split: test
+ video_processor: VideoProcessor
+ aligner: DSAligner
+ bert_name: bert-base-uncased
+fairseq:
+ dataset:
+ batch_size: 256
+ valid_subset: test
+ num_workers: 2
diff --git a/examples/MMPT/projects/task/test_coin.yaml b/examples/MMPT/projects/task/test_coin.yaml
new file mode 100644
index 0000000000..6d919df7c2
--- /dev/null
+++ b/examples/MMPT/projects/task/test_coin.yaml
@@ -0,0 +1,24 @@
+includes: projects/task/test.yaml
+dataset:
+ split: test
+ test_path: data/coin/COIN.json
+ meta_processor: COINActionSegmentationMetaProcessor
+ vfeat_dir: data/feat/feat_coin_s3d
+ video_processor: VideoProcessor
+ text_processor: COINActionSegmentationTextProcessor
+ aligner: COINActionSegmentationAligner
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+model:
+ model_cls: MMFusionActionSegmentation
+ mm_encoder_cls: MMBertForTokenClassification
+eval:
+ save_path: runs/task/coin/eval
+fairseq:
+ dataset:
+ batch_size: 1
+ common_eval:
+ path: runs/task/coin/checkpoint_best.pt
+metric: COINActionSegmentationMetric
+predictor: COINPredictor
diff --git a/examples/MMPT/projects/task/test_coin_videoclip.yaml b/examples/MMPT/projects/task/test_coin_videoclip.yaml
new file mode 100644
index 0000000000..b41f5bc489
--- /dev/null
+++ b/examples/MMPT/projects/task/test_coin_videoclip.yaml
@@ -0,0 +1,7 @@
+includes: projects/task/test_coin.yaml
+model:
+ model_cls: MMFusionSeparateActionSegmentation
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForTokenClassification
+ text_encoder_cls: BertModel # dummy, not used.
+ num_hidden_video_layers: 6
diff --git a/examples/MMPT/projects/task/test_coin_zs.yaml b/examples/MMPT/projects/task/test_coin_zs.yaml
new file mode 100644
index 0000000000..5d19b09f1d
--- /dev/null
+++ b/examples/MMPT/projects/task/test_coin_zs.yaml
@@ -0,0 +1,13 @@
+includes: projects/task/test_coin.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/task/coin_zs/eval
+fairseq:
+ common_eval:
+ path: runs/task/checkpoint_best.pt
+predictor: COINZSPredictor
diff --git a/examples/MMPT/projects/task/test_crosstask.yaml b/examples/MMPT/projects/task/test_crosstask.yaml
new file mode 100644
index 0000000000..6dd778e30b
--- /dev/null
+++ b/examples/MMPT/projects/task/test_crosstask.yaml
@@ -0,0 +1,32 @@
+includes: projects/task/test.yaml
+dataset:
+ split: test
+ meta_processor: CrossTaskMetaProcessor
+ test_path: data/crosstask/crosstask_release/videos_val.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv # dummy
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ video_processor: CrossTaskVideoProcessor
+ text_processor: CrossTaskTextProcessor
+ aligner: CrossTaskAligner
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+model:
+ model_cls: MMFusionActionLocalization
+ mm_encoder_cls: MMBertForJoint
+eval:
+ save_path: runs/task/crosstask/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ dataset:
+ batch_size: 1
+ common_eval:
+ path: runs/task/crosstask/checkpoint_best.pt
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor
diff --git a/examples/MMPT/projects/task/test_crosstask_videoclip.yaml b/examples/MMPT/projects/task/test_crosstask_videoclip.yaml
new file mode 100644
index 0000000000..df12535d23
--- /dev/null
+++ b/examples/MMPT/projects/task/test_crosstask_videoclip.yaml
@@ -0,0 +1,7 @@
+includes: projects/task/test_crosstask.yaml
+model:
+ model_cls: MMFusionSeparateActionLocalization
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel # dummy, not used.
+ num_hidden_video_layers: 6
diff --git a/examples/MMPT/projects/task/test_crosstask_zs.yaml b/examples/MMPT/projects/task/test_crosstask_zs.yaml
new file mode 100644
index 0000000000..19386e495b
--- /dev/null
+++ b/examples/MMPT/projects/task/test_crosstask_zs.yaml
@@ -0,0 +1,32 @@
+includes: projects/task/test.yaml
+dataset:
+ split: test
+ meta_processor: CrossTaskMetaProcessor
+ test_path: data/crosstask/crosstask_release/videos_val.csv
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
+ val_path: data/crosstask/crosstask_release/videos_val.csv # dummy
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
+ vfeat_dir: data/feat/feat_crosstask_s3d
+ annotation_path: data/crosstask/crosstask_release/annotations
+ n_train: 30
+ video_processor: CrossTaskVideoProcessor
+ text_processor: CrossTaskTextProcessor
+ aligner: CrossTaskAligner
+ num_iso_layer: 12
+ sliding_window: 16
+ sliding_window_size: 32
+model:
+ model_cls: MMFusionActionLocalization
+ mm_encoder_cls: MMBertForJoint
+eval:
+ save_path: runs/task/crosstask_zs/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ dataset:
+ batch_size: 1
+ common_eval:
+ path: runs/task/checkpoint_best.pt # load the best how2 checkpoint; the ACL submission used runs/task/checkpoint11.pt
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor
diff --git a/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml b/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml
new file mode 100644
index 0000000000..7f0198276f
--- /dev/null
+++ b/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml
@@ -0,0 +1,7 @@
+includes: projects/task/test_crosstask_zs.yaml
+model:
+ model_cls: MMFusionSeparateActionLocalization
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel # dummy, not used.
+ num_hidden_video_layers: 6
diff --git a/examples/MMPT/projects/task/test_didemo_zs.yaml b/examples/MMPT/projects/task/test_didemo_zs.yaml
new file mode 100644
index 0000000000..4b53dca71e
--- /dev/null
+++ b/examples/MMPT/projects/task/test_didemo_zs.yaml
@@ -0,0 +1,23 @@
+includes: projects/task/test.yaml
+dataset:
+ meta_processor: DiDeMoMetaProcessor
+ test_path: data/didemo/test_data.json
+ video_processor: VideoProcessor
+ vfeat_dir: data/feat/feat_didemo_s3d
+ text_processor: DiDeMoTextProcessor
+ aligner: DiDeMoAligner
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/task/didemo_zs/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/checkpoint_best.pt
+metric: DiDeMoMetric
+predictor: DiDeMoPredictor
diff --git a/examples/MMPT/projects/task/test_vtt.yaml b/examples/MMPT/projects/task/test_vtt.yaml
new file mode 100644
index 0000000000..2f809b306d
--- /dev/null
+++ b/examples/MMPT/projects/task/test_vtt.yaml
@@ -0,0 +1,19 @@
+includes: projects/task/test.yaml
+dataset:
+ meta_processor: MSRVTTMetaProcessor
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ video_processor: VideoProcessor
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+eval:
+ save_path: runs/task/vtt/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/vtt/checkpoint_last.pt
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/task/test_vtt_videoclip.yaml b/examples/MMPT/projects/task/test_vtt_videoclip.yaml
new file mode 100644
index 0000000000..cb6564394c
--- /dev/null
+++ b/examples/MMPT/projects/task/test_vtt_videoclip.yaml
@@ -0,0 +1,8 @@
+includes: projects/task/test_vtt.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+
diff --git a/examples/MMPT/projects/task/test_vtt_zs.yaml b/examples/MMPT/projects/task/test_vtt_zs.yaml
new file mode 100644
index 0000000000..57340924b4
--- /dev/null
+++ b/examples/MMPT/projects/task/test_vtt_zs.yaml
@@ -0,0 +1,13 @@
+includes: projects/task/test_vtt.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/task/vtt_zs/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/checkpoint_best.pt
diff --git a/examples/MMPT/projects/task/test_vttqa.yaml b/examples/MMPT/projects/task/test_vttqa.yaml
new file mode 100644
index 0000000000..ddf813c535
--- /dev/null
+++ b/examples/MMPT/projects/task/test_vttqa.yaml
@@ -0,0 +1,20 @@
+includes: projects/task/test.yaml
+dataset:
+ meta_processor: MSRVTTQAMetaProcessor
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
+ video_processor: VideoProcessor
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTQATextProcessor
+ aligner: MSRVTTQAAligner
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+eval:
+ save_path: runs/task/vttqa/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/vttqa/checkpoint_last.pt
+metric: QAMetric
+predictor: QAPredictor
diff --git a/examples/MMPT/projects/task/test_vttqa_videoclip.yaml b/examples/MMPT/projects/task/test_vttqa_videoclip.yaml
new file mode 100644
index 0000000000..32a41e861c
--- /dev/null
+++ b/examples/MMPT/projects/task/test_vttqa_videoclip.yaml
@@ -0,0 +1,8 @@
+includes: projects/task/test_vttqa.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+
diff --git a/examples/MMPT/projects/task/test_vttqa_zs.yaml b/examples/MMPT/projects/task/test_vttqa_zs.yaml
new file mode 100644
index 0000000000..5e0e29d207
--- /dev/null
+++ b/examples/MMPT/projects/task/test_vttqa_zs.yaml
@@ -0,0 +1,13 @@
+includes: projects/task/test_vttqa.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/task/vttqa_zs/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/checkpoint_best.pt
diff --git a/examples/MMPT/projects/task/test_youcook.yaml b/examples/MMPT/projects/task/test_youcook.yaml
new file mode 100644
index 0000000000..092b680fa6
--- /dev/null
+++ b/examples/MMPT/projects/task/test_youcook.yaml
@@ -0,0 +1,22 @@
+includes: projects/task/test.yaml
+dataset:
+ meta_processor: YoucookMetaProcessor
+ test_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: True
+ video_processor: YoucookVideoProcessor
+ vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz
+ text_processor: TextProcessor
+ aligner: DSAligner
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+eval:
+ save_path: runs/task/youcook/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/youcook/checkpoint_last.pt
+metric: RetrievalMetric
+predictor: RetrievalPredictor
diff --git a/examples/MMPT/projects/task/test_youcook_videoclip.yaml b/examples/MMPT/projects/task/test_youcook_videoclip.yaml
new file mode 100644
index 0000000000..b85ea43474
--- /dev/null
+++ b/examples/MMPT/projects/task/test_youcook_videoclip.yaml
@@ -0,0 +1,8 @@
+includes: projects/task/test_youcook.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+
diff --git a/examples/MMPT/projects/task/test_youcook_zs.yaml b/examples/MMPT/projects/task/test_youcook_zs.yaml
new file mode 100644
index 0000000000..0a5875bea4
--- /dev/null
+++ b/examples/MMPT/projects/task/test_youcook_zs.yaml
@@ -0,0 +1,13 @@
+includes: projects/task/test_youcook.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+eval:
+ save_path: runs/task/youcook_zs/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/checkpoint_best.pt
diff --git a/examples/MMPT/projects/task/test_youcookcap.yaml b/examples/MMPT/projects/task/test_youcookcap.yaml
new file mode 100644
index 0000000000..24f6518b7b
--- /dev/null
+++ b/examples/MMPT/projects/task/test_youcookcap.yaml
@@ -0,0 +1,23 @@
+includes: projects/task/test.yaml
+dataset:
+ meta_processor: YoucookNLGMetaProcessor
+ test_path: data/youcook/val_list.txt
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ video_processor: YoucookVideoProcessor
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: NLGTextProcessor
+ aligner: DSNLGAligner
+model:
+ model_cls: MMFusionNLG
+ mm_encoder_cls: MMBertForNLG
+ max_decode_length: 24
+eval:
+ save_path: runs/task/youcookcap/eval
+fairseq:
+ # read the code to find the checkpoint argument.
+ common_eval:
+ path: runs/task/youcookcap/checkpoint_best.pt
+metric: NLGMetric
+predictor: NLGPredictor
+gen_param:
+ num_beams: 5
diff --git a/examples/MMPT/projects/task/vtt.yaml b/examples/MMPT/projects/task/vtt.yaml
new file mode 100644
index 0000000000..395e2ee6fe
--- /dev/null
+++ b/examples/MMPT/projects/task/vtt.yaml
@@ -0,0 +1,25 @@
+includes: projects/task/ft.yaml
+dataset:
+ meta_processor: MSRVTTMetaProcessor
+ train_path: data/msrvtt/MSRVTT_train.csv
+ jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
+ dup: 20
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ json_path: data/msrvtt/MSRVTT_data.json
+ aligner: DSAligner
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+loss:
+ loss_cls: T2VContraLoss
+fairseq:
+ dataset:
+ batch_size: 256
+ optimization:
+ max_epoch: 10
+ checkpoint:
+ save_dir: runs/task/vtt
diff --git a/examples/MMPT/projects/task/vtt_videoclip.yaml b/examples/MMPT/projects/task/vtt_videoclip.yaml
new file mode 100644
index 0000000000..a9892cab01
--- /dev/null
+++ b/examples/MMPT/projects/task/vtt_videoclip.yaml
@@ -0,0 +1,12 @@
+includes: projects/task/vtt.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+fairseq:
+ dataset:
+ batch_size: 224
+# model_cls: MMFusionShare
+# mm_encoder_cls: MMBertForEncoder
diff --git a/examples/MMPT/projects/task/vttqa.yaml b/examples/MMPT/projects/task/vttqa.yaml
new file mode 100644
index 0000000000..56d578eff0
--- /dev/null
+++ b/examples/MMPT/projects/task/vttqa.yaml
@@ -0,0 +1,23 @@
+includes: projects/task/ft.yaml
+dataset:
+ meta_processor: MSRVTTMetaProcessor
+ train_path: data/msrvtt/MSRVTT_train.csv
+ dup: 20
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+ vfeat_dir: data/feat/feat_vtt_s3d
+ text_processor: MSRVTTTextProcessor
+ json_path: data/msrvtt/MSRVTT_data.json
+ aligner: DSAligner
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+loss:
+ loss_cls: V2TContraLoss
+fairseq:
+ dataset:
+ batch_size: 128
+ optimization:
+ max_epoch: 5
+ checkpoint:
+ save_dir: runs/task/vttqa
diff --git a/examples/MMPT/projects/task/vttqa_videoclip.yaml b/examples/MMPT/projects/task/vttqa_videoclip.yaml
new file mode 100644
index 0000000000..2d484ca8a5
--- /dev/null
+++ b/examples/MMPT/projects/task/vttqa_videoclip.yaml
@@ -0,0 +1,10 @@
+includes: projects/task/vttqa.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+
+# model_cls: MMFusionShare
+# mm_encoder_cls: MMBertForEncoder
diff --git a/examples/MMPT/projects/task/youcook.yaml b/examples/MMPT/projects/task/youcook.yaml
new file mode 100644
index 0000000000..e0cd841747
--- /dev/null
+++ b/examples/MMPT/projects/task/youcook.yaml
@@ -0,0 +1,25 @@
+includes: projects/task/ft.yaml
+dataset:
+ meta_processor: YoucookMetaProcessor
+ train_path: data/youcook/youcook_train.pkl
+ val_path: data/youcook/youcook_val.pkl
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ use_annotation_text: True
+ video_processor: YoucookVideoProcessor
+ vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz
+ text_processor: TextProcessor
+ aligner: DSAligner
+ num_iso_layer: 12
+model:
+ model_cls: MMFusionJoint
+ mm_encoder_cls: MMBertForJoint
+loss:
+ loss_cls: T2VContraLoss
+fairseq:
+ dataset:
+ batch_size: 128
+ optimization:
+ max_epoch: 10
+ checkpoint:
+ save_dir: runs/task/youcook
+
diff --git a/examples/MMPT/projects/task/youcook_videoclip.yaml b/examples/MMPT/projects/task/youcook_videoclip.yaml
new file mode 100644
index 0000000000..e3e901c30c
--- /dev/null
+++ b/examples/MMPT/projects/task/youcook_videoclip.yaml
@@ -0,0 +1,9 @@
+includes: projects/task/youcook.yaml
+model:
+ model_cls: MMFusionSeparate
+ mm_encoder_cls:
+ video_encoder_cls: MMBertForEncoder
+ text_encoder_cls: BertModel
+ num_hidden_video_layers: 6
+ # model_cls: MMFusionShare
+ # mm_encoder_cls: MMBertForEncoder
diff --git a/examples/MMPT/projects/task/youcookcap.yaml b/examples/MMPT/projects/task/youcookcap.yaml
new file mode 100644
index 0000000000..047735f217
--- /dev/null
+++ b/examples/MMPT/projects/task/youcookcap.yaml
@@ -0,0 +1,23 @@
+# finetuning for youcook captioning.
+includes: projects/task/ft.yaml
+dataset:
+ meta_processor: YoucookNLGMetaProcessor
+ train_path: data/youcook/train_list.txt
+ val_path: data/youcook/val_list.txt
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+ video_processor: YoucookVideoProcessor
+ vfeat_dir: data/feat/feat_youcook_s3d
+ text_processor: NLGTextProcessor
+ aligner: DSNLGAligner
+model:
+ model_cls: MMFusionNLG
+ mm_encoder_cls: MMBertForNLG
+loss:
+ loss_cls: NLGLoss
+fairseq:
+ dataset:
+ batch_size: 128
+ optimization:
+ max_epoch: 10
+ checkpoint:
+ save_dir: runs/task/youcookcap
diff --git a/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml b/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml
new file mode 100644
index 0000000000..473dd9b45b
--- /dev/null
+++ b/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml
@@ -0,0 +1,5 @@
+dataset:
+ bert_name: bert-base-uncased
+ caption_pkl_path: data/how2/raw_caption_dedup.pkl
+ use_fast: true
+ target_dir: data/feat/feat_how2_s3d_shard_small
diff --git a/examples/MMPT/scripts/text_token_extractor/pretokenization.py b/examples/MMPT/scripts/text_token_extractor/pretokenization.py
new file mode 100644
index 0000000000..29ae5dc151
--- /dev/null
+++ b/examples/MMPT/scripts/text_token_extractor/pretokenization.py
@@ -0,0 +1,106 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pickle
+import os
+import argparse
+import numpy as np
+
+from torch.utils.data import Dataset, DataLoader
+from mmpt.processors import PKLJSONStrTextProcessor
+from mmpt.utils import ShardedTensor, recursive_config
+
+
+class TokenizerDataset(Dataset):
+ def __init__(self, config):
+ self.text_processor = PKLJSONStrTextProcessor(config)
+ self.video_ids = list(self.text_processor.data.keys())
+
+ def __getitem__(self, idx):
+ video_id = self.video_ids[idx]
+ return video_id, self.text_processor(video_id)
+
+ def __len__(self):
+ return len(self.video_ids)
+
+
+def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32):
+ startends = []
+ caps_ids = []
+ for video_id in video_ids:
+ caption = captions[video_id]
+ startend = []
+ cap_ids = []
+ for start, end, cap in zip(
+ caption["start"], caption["end"], caption["cap"]):
+ startend.append(np.array([start, end]).astype("float32"))
+ cap_id = np.full((max_cap_len,), -1, dtype=np.int32)
+ cap = cap[:max_cap_len]
+ cap_id[:len(cap)] = cap
+ cap_ids.append(cap_id)
+ startends.append(np.stack(startend))
+ caps_ids.append(np.stack(cap_ids))
+
+ startends = ShardedTensor.from_list(startends)
+ target_path = os.path.join(
+ target_dir,
+ prefix + split + "_" + str(shard_idx)
+ )
+ print("save to", target_path)
+ startends.save(target_path + ".startends")
+ caps_ids = ShardedTensor.from_list(caps_ids)
+ caps_ids.save(target_path + ".caps_ids")
+
+
+def sharding(config, out_file):
+ with open(out_file, "rb") as fr:
+ captions = pickle.load(fr)
+ target_dir = config.target_dir
+ prefix = os.path.basename(
+ os.path.splitext(config.caption_pkl_path)[0]
+ ) + "." + config.bert_name + "."
+ for split in ["train", "val"]:
+ target_path = os.path.join(target_dir, split + "_meta")
+ with open(target_path + ".pkl", "rb") as fr:
+ meta = pickle.load(fr)
+ print("load meta", target_path, len(meta))
+ for shard_id in meta:
+ numpify(
+ shard_id, meta[shard_id], captions,
+ target_dir, split, prefix
+ )
+
+
+def tokenize(config, out_file):
+ def collator(samples):
+ return samples
+ dataset = TokenizerDataset(config)
+ data = {}
+ for idx, batch in enumerate(
+ DataLoader(dataset, collate_fn=collator, num_workers=16)):
+ for video_id, caption in batch:
+ data[video_id] = caption
+ if idx % 5000 == 0:
+ print(idx)
+ with open(out_file, "wb") as fw:
+ pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL)
+
+
+def main(args):
+ config = recursive_config(args.config).dataset
+
+ out_file = os.path.splitext(config.caption_pkl_path)[0] \
+ + "." + config.bert_name + ".pkl"
+ if not os.path.isfile(out_file):
+ tokenize(config, out_file)
+ sharding(config, out_file)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="pretokenize (raw_)caption.json into pkl.")
+ parser.add_argument('config', type=str)
+ args = parser.parse_args()
+ main(args)
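
The `numpify` helper above stores each caption as a fixed-width row of token ids, truncated to `max_cap_len` and right-padded with -1. A small standalone illustration of that padding step, with made-up token ids:

    # Standalone illustration of the fixed-width caption padding in numpify().
    import numpy as np

    max_cap_len = 32
    cap = [101, 2023, 2003, 1037, 12345, 102]   # made-up BERT token ids

    cap_id = np.full((max_cap_len,), -1, dtype=np.int32)
    cap = cap[:max_cap_len]                     # truncate long captions
    cap_id[:len(cap)] = cap                     # right-pad the rest with -1

    print(cap_id[:8])     # [  101  2023  2003  1037 12345   102    -1    -1]
    print(cap_id.shape)   # (32,)
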
diff --git a/examples/MMPT/scripts/video_feature_extractor/extract.py b/examples/MMPT/scripts/video_feature_extractor/extract.py
new file mode 100755
index 0000000000..b5ee7b7788
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/extract.py
@@ -0,0 +1,157 @@
+# Copyright Howto100M authors.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch as th
+import torch.nn.functional as F
+import math
+import numpy as np
+import argparse
+
+from torch.utils.data import DataLoader
+from model import get_model
+from preprocessing import Preprocessing
+from random_sequence_shuffler import RandomSequenceSampler
+
+from tqdm import tqdm
+from pathbuilder import PathBuilder
+from videoreader import VideoLoader
+
+
+parser = argparse.ArgumentParser(description='Easy video feature extractor')
+
+parser.add_argument('--vdir', type=str)
+parser.add_argument('--fdir', type=str)
+parser.add_argument('--hflip', type=int, default=0)
+
+parser.add_argument('--batch_size', type=int, default=64,
+ help='batch size')
+parser.add_argument('--type', type=str, default='2d',
+ help='CNN type')
+parser.add_argument('--half_precision', type=int, default=0,
+ help='output half precision float')
+parser.add_argument('--num_decoding_thread', type=int, default=4,
+ help='Num parallel thread for video decoding')
+parser.add_argument('--l2_normalize', type=int, default=1,
+ help='l2 normalize feature')
+parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth',
+ help='Resnext model path')
+parser.add_argument('--vmz_model_path', type=str, default='model/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth',
+ help='vmz model path')
+
+args = parser.parse_args()
+
+
+# TODO: refactor all args into a config; the current code comes from multiple authors.
+CONFIGS = {
+ "2d": {
+ "fps": 1,
+ "size": 224,
+ "centercrop": False,
+ "shards": 0,
+ },
+ "3d": {
+ "fps": 24,
+ "size": 112,
+ "centercrop": True,
+ "shards": 0,
+ },
+ "s3d": {
+ "fps": 30,
+ "size": 224,
+ "centercrop": True,
+ "shards": 0,
+ },
+ "vmz": {
+ "fps": 24,
+ "size": 112,
+ "centercrop": True,
+ "shards": 0,
+ },
+ "vae": {
+ "fps": 2,
+ "size": 256,
+ "centercrop": True,
+ "shards": 100,
+ }
+}
+
+config = CONFIGS[args.type]
+
+
+video_dirs = args.vdir
+feature_dir = args.fdir
+
+video_dict = PathBuilder.build(video_dirs, feature_dir, ".npy", config["shards"])
+
+dataset = VideoLoader(
+ video_dict=video_dict,
+ framerate=config["fps"],
+ size=config["size"],
+ centercrop=config["centercrop"],
+ hflip=args.hflip
+)
+n_dataset = len(dataset)
+sampler = RandomSequenceSampler(n_dataset, 10)
+loader = DataLoader(
+ dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=args.num_decoding_thread,
+ sampler=sampler if n_dataset > 10 else None,
+)
+preprocess = Preprocessing(args.type)
+model = get_model(args)
+
+with th.no_grad():
+ for k, data in tqdm(enumerate(loader), total=loader.__len__(), ascii=True):
+ input_file = data['input'][0]
+ output_file = data['output'][0]
+ if len(data['video'].shape) > 3:
+ video = data['video'].squeeze()
+ if len(video.shape) == 4:
+ video = preprocess(video)
+ n_chunk = len(video)
+ if args.type == 'vmz':
+ n_chunk = math.ceil(n_chunk/float(3))
+ features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
+ elif args.type == 's3d':
+ features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
+ elif args.type == "vae":
+ features = th.cuda.LongTensor(n_chunk, 1024).fill_(0)
+ else:
+ features = th.cuda.FloatTensor(n_chunk, 2048).fill_(0)
+ n_iter = int(math.ceil(n_chunk / float(args.batch_size)))
+ for i in range(n_iter):
+ factor = 1
+ if args.type == 'vmz':
+ factor = 3
+ min_ind = factor * i * args.batch_size
+ max_ind = factor * (i + 1) * args.batch_size
+ video_batch = video[min_ind:max_ind:factor].cuda()
+ if args.type == '2d':
+ batch_features = model(video_batch) # (51, 487), (51, 512)
+ elif args.type == 's3d':
+ batch_features = model(video_batch)
+ batch_features = batch_features['video_embedding']
+ elif args.type == "vae":
+ # image_code.
+ batch_features = model(video_batch)
+ else:
+ batch_pred, batch_features = model(video_batch) # (51, 487), (51, 512)
+ if args.l2_normalize:
+ batch_features = F.normalize(batch_features, dim=1)
+ features[i*args.batch_size:(i+1)*args.batch_size] = batch_features
+ features = features.cpu().numpy()
+ if args.half_precision:
+ if args.type == "vae":
+ features = features.astype(np.int16)
+ else:
+ features = features.astype('float16')
+ else:
+ if args.type == "vae":
+ features = features.astype(np.int32)
+ else:
+ features = features.astype('float32')
+ np.save(output_file, features)
+ else:
+ print('Video {} error.'.format(input_file))
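
The least obvious part of extract.py is the batching arithmetic: for the `vmz` model only every third clip is kept, so the feature buffer has `ceil(n_chunk / 3)` rows and the slice bounds are scaled by `factor`. A standalone sketch of just that index bookkeeping, with made-up sizes:

    # Sketch of the clip/batch index arithmetic used in extract.py.
    import math

    n_chunk_raw = 100   # clips produced by Preprocessing for one video
    batch_size = 16
    factor = 3          # 3 for 'vmz' (keep every third clip), 1 otherwise

    n_chunk = math.ceil(n_chunk_raw / float(factor))
    n_iter = int(math.ceil(n_chunk / float(batch_size)))

    for i in range(n_iter):
        min_ind = factor * i * batch_size
        max_ind = factor * (i + 1) * batch_size
        kept = len(range(min_ind, min(max_ind, n_chunk_raw), factor))
        # rows [i*batch_size, i*batch_size + kept) of the feature buffer get filled
        print(i, (min_ind, max_ind), "->", kept, "clips")
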
diff --git a/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh b/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh
new file mode 100644
index 0000000000..90102c89fb
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+
+python scripts/video_feature_extractor/extract.py \
+ --vdir \
+ --fdir data/feat/feat_how2_s3d \
+ --type=s3d --num_decoding_thread=4 \
+ --batch_size 32 --half_precision 1
diff --git a/examples/MMPT/scripts/video_feature_extractor/model.py b/examples/MMPT/scripts/video_feature_extractor/model.py
new file mode 100755
index 0000000000..ac266e844c
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/model.py
@@ -0,0 +1,58 @@
+# Copyright (c) Howto100M authors and Facebook, Inc. All Rights Reserved
+
+import torch as th
+
+from torch import nn
+
+
+class GlobalAvgPool(nn.Module):
+ def __init__(self):
+ super(GlobalAvgPool, self).__init__()
+
+ def forward(self, x):
+ return th.mean(x, dim=[-2, -1])
+
+
+def get_model(args):
+ assert args.type in ['2d', '3d', 'vmz', 's3d', 'vae']
+ if args.type == '2d':
+ print('Loading 2D-ResNet-152 ...')
+ import torchvision.models as models
+ model = models.resnet152(pretrained=True)
+ model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool())
+ model = model.cuda()
+ elif args.type == 'vmz':
+ print('Loading VMZ ...')
+ from vmz34 import r2plus1d_34
+ model = r2plus1d_34(pretrained_path=args.vmz_model_path, pretrained_num_classes=487)
+ model = model.cuda()
+ elif args.type == 's3d':
+ # reuse the S3D implementation shipped with mmpt instead of duplicating it for feature extraction.
+ from mmpt.processors.models.s3dg import S3D
+ model = S3D('pretrained_models/s3d_dict.npy', 512)
+ model.load_state_dict(th.load('pretrained_models/s3d_howto100m.pth'))
+ model = model.cuda()
+
+ elif args.type == '3d':
+ print('Loading 3D-ResneXt-101 ...')
+ from videocnn.models import resnext
+ model = resnext.resnet101(
+ num_classes=400,
+ shortcut_type='B',
+ cardinality=32,
+ sample_size=112,
+ sample_duration=16,
+ last_fc=False)
+ model = model.cuda()
+ model_data = th.load(args.resnext101_model_path)
+ model.load_state_dict(model_data)
+ elif args.type == 'vae':
+ from openaivae import OpenAIParallelDiscreteVAE
+ model = OpenAIParallelDiscreteVAE()
+ model = model.cuda()
+ else:
+ raise ValueError("model not supported yet.")
+
+ model.eval()
+ print('loaded')
+ return model
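
`GlobalAvgPool` above simply averages over the last two (spatial) dimensions, turning the 2D ResNet-152 feature map of shape (B, 2048, H, W) into (B, 2048). A quick check of that behavior:

    # Mean over the last two dims collapses the spatial dimensions of a feature
    # map, matching adaptive average pooling to 1x1 followed by flatten.
    import torch as th
    import torch.nn.functional as F

    x = th.randn(4, 2048, 7, 7)           # (batch, channels, H, W)
    pooled = th.mean(x, dim=[-2, -1])     # what GlobalAvgPool.forward computes
    print(pooled.shape)                   # torch.Size([4, 2048])

    ref = F.adaptive_avg_pool2d(x, 1).flatten(1)
    print(th.allclose(pooled, ref))       # True
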
diff --git a/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py b/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py
new file mode 100644
index 0000000000..2392d6d63b
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import urllib.parse
+import json
+import pandas as pd
+
+from tqdm import tqdm
+
+
+# TODO: extend to other datasets.
+supported_formats = {}
+
+
+class PathBuilder(object):
+ @classmethod
+ def build(cls, video_dirs, feature_dir, ext, shards=0, split=None):
+ meta_fn = os.path.join(feature_dir, "meta_plan.json")
+ os.makedirs(feature_dir, exist_ok=True)
+ if os.path.isfile(meta_fn):
+ with open(meta_fn) as fr:
+ meta = json.load(fr)
+ return meta
+ print("searching videos...")
+
+ video_id_to_path = {}
+ for video_dir in video_dirs.split(","):
+ # TODO: add support for recursive directory listing.
+ if video_dir in supported_formats:
+ supported_formats[video_dir].load(video_dir, video_id_to_path)
+ else:
+ for idx, fn in enumerate(tqdm(os.listdir(video_dir))):
+ video_fn = os.path.join(video_dir, fn)
+ if os.path.isfile(video_fn):
+ video_id = os.path.splitext(fn)[0]
+ video_id_to_path[video_id] = video_fn
+ elif os.path.isdir(video_fn):
+ # shards of folders.
+ shard_dir = video_fn
+ for idx, fn in enumerate(os.listdir(shard_dir)):
+ video_fn = os.path.join(shard_dir, fn)
+ if os.path.isfile(video_fn):
+ video_id = os.path.splitext(fn)[0]
+ video_id_to_path[video_id] = video_fn
+
+ video_path, feature_path = [], []
+ valid_ext = set()
+ for idx, video_id in enumerate(video_id_to_path):
+ video_path.append(video_id_to_path[video_id])
+ if ext is None:
+ # use original file ext for format compatibility.
+ path = urllib.parse.urlparse(video_id_to_path[video_id]).path
+ ext = os.path.splitext(path)[1]
+ if ext not in valid_ext:
+ valid_ext.add(ext)
+ print("adding", ext)
+ if shards:
+ shard_id = str(idx % shards)
+ feature_fn = os.path.join(
+ feature_dir, shard_id, video_id + ext)
+ else:
+ feature_fn = os.path.join(
+ feature_dir, video_id + ext)
+ feature_path.append(feature_fn)
+
+ print("targeting", len(feature_path), "videos")
+ meta = {
+ "video_path": video_path, "feature_path": feature_path}
+ with open(meta_fn, "w") as fw:
+ json.dump(meta, fw)
+
+ if split is not None:
+ splits = split.split("/")
+ assert len(splits) == 2
+ cur, total = int(splits[0]), int(splits[1])
+ assert cur < total
+ import math
+ chunk = math.ceil(len(meta["video_path"]) / total)
+ start = cur * chunk
+ end = (cur + 1) * chunk
+ meta = {
+ "video_path": meta["video_path"][start:end],
+ "feature_path": meta["feature_path"][start:end]
+ }
+
+ return meta
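
`PathBuilder.build` optionally takes a `split` argument of the form "cur/total" and keeps only that slice of the extraction plan; the slicing is plain ceiling division. A standalone sketch with a hypothetical ten-video plan:

    # Sketch of the "cur/total" slicing at the end of PathBuilder.build.
    import math

    video_path = ["video_%d.mp4" % i for i in range(10)]   # hypothetical plan

    def slice_plan(paths, split):
        cur, total = (int(x) for x in split.split("/"))
        assert cur < total
        chunk = math.ceil(len(paths) / total)
        return paths[cur * chunk:(cur + 1) * chunk]

    for split in ("0/3", "1/3", "2/3"):
        print(split, slice_plan(video_path, split))
    # 0/3 -> videos 0-3, 1/3 -> videos 4-7, 2/3 -> videos 8-9
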
diff --git a/examples/MMPT/scripts/video_feature_extractor/preprocessing.py b/examples/MMPT/scripts/video_feature_extractor/preprocessing.py
new file mode 100755
index 0000000000..fa0cec3a76
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/preprocessing.py
@@ -0,0 +1,57 @@
+# Copyright Howto100m authors.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch as th
+
+class Normalize(object):
+
+ def __init__(self, mean, std):
+ self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
+ self.std = th.FloatTensor(std).view(1, 3, 1, 1)
+
+ def __call__(self, tensor):
+ tensor = (tensor - self.mean) / (self.std + 1e-8)
+ return tensor
+
+class Preprocessing(object):
+
+ def __init__(self, type):
+ self.type = type
+ if type == '2d':
+ self.norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ elif type == '3d':
+ self.norm = Normalize(mean=[110.6, 103.2, 96.3], std=[1.0, 1.0, 1.0])
+ elif type == 'vmz':
+ self.norm = Normalize(mean=[110.201, 100.64, 95.997], std=[58.1489, 56.4701, 55.3324])
+
+ def _zero_pad(self, tensor, size):
+ n = size - len(tensor) % size
+ if n == size:
+ return tensor
+ else:
+ z = th.zeros(n, tensor.shape[1], tensor.shape[2], tensor.shape[3])
+ return th.cat((tensor, z), 0)
+
+ def __call__(self, tensor):
+ if self.type == '2d':
+ tensor = tensor / 255.0
+ tensor = self.norm(tensor)
+ elif self.type == 'vmz':
+ #tensor = self._zero_pad(tensor, 8)
+ tensor = self._zero_pad(tensor, 10)
+ tensor = self.norm(tensor)
+ #tensor = tensor.view(-1, 8, 3, 112, 112)
+ tensor = tensor.view(-1, 10, 3, 112, 112)
+ tensor = tensor.transpose(1, 2)
+ elif self.type == '3d':
+ tensor = self._zero_pad(tensor, 16)
+ tensor = self.norm(tensor)
+ tensor = tensor.view(-1, 16, 3, 112, 112)
+ tensor = tensor.transpose(1, 2)
+ elif self.type == 's3d':
+ tensor = tensor / 255.0
+ tensor = self._zero_pad(tensor, 30)
+ tensor = tensor.view(-1, 30, 3, 224, 224) # N x 30 x 3 x H x W
+ tensor = tensor.transpose(1, 2) # N x 3 x 30 x H x W
+ # for vae do nothing
+ return tensor
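+
+
+# A minimal usage sketch (illustrative shapes, not part of the extractor pipeline):
+# frames = th.randint(0, 256, (45, 3, 224, 224)).float()  # 45 decoded RGB frames
+# clips = Preprocessing('s3d')(frames)                     # zero-pads to 60 frames
+# clips.shape == (2, 3, 30, 224, 224)                      # N x C x T x H x W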
diff --git a/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py b/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py
new file mode 100755
index 0000000000..1f3e4aceaa
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import numpy as np
+
+from torch.utils.data.sampler import Sampler
+
+
+class RandomSequenceSampler(Sampler):
+
+ def __init__(self, n_sample, seq_len):
+ self.n_sample = n_sample
+ self.seq_len = seq_len
+
+ def _pad_ind(self, ind):
+ zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len)
+ ind = np.concatenate((ind, zeros))
+ return ind
+
+ def __iter__(self):
+ idx = np.arange(self.n_sample)
+ if self.n_sample % self.seq_len != 0:
+ idx = self._pad_ind(idx)
+ idx = np.reshape(idx, (-1, self.seq_len))
+ np.random.shuffle(idx)
+ idx = np.reshape(idx, (-1))
+ return iter(idx.astype(int))
+
+ def __len__(self):
+        # pad up to a multiple of seq_len, matching the number of indices from __iter__
+        return self.n_sample + (self.seq_len - self.n_sample % self.seq_len) % self.seq_len
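+
+
+# Example: n_sample=10 and seq_len=4 pads the indices to [0..9, 0, 0], shuffles them
+# in blocks of 4, and yields 12 indices (index 0 is repeated as padding).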
diff --git a/examples/MMPT/scripts/video_feature_extractor/shard_feature.py b/examples/MMPT/scripts/video_feature_extractor/shard_feature.py
new file mode 100644
index 0000000000..f75e1dfae5
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/shard_feature.py
@@ -0,0 +1,64 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import numpy as np
+import os
+import pickle
+
+from mmpt.utils import ShardedTensor
+
+
+class Shard(object):
+ def __init__(
+ self,
+ vfeat_dir,
+ tfeat_dir,
+ target_dir,
+ file_paths,
+ shard_size=4096
+ ):
+ self.vfeat_dir = vfeat_dir
+ self.tfeat_dir = tfeat_dir
+ self.target_dir = target_dir
+ self.video_ids = {}
+ for split, file_path in zip(["train", "val"], file_paths):
+ with open(file_path) as fr:
+ self.video_ids[split] = [
+ line.strip() for line in fr.readlines()]
+ self.shard_size = shard_size
+
+ def __call__(self, split="train"):
+ for split in ["train", "val"]:
+ meta = {}
+ for shard_idx, shard_offset in enumerate(
+ range(0, len(self.video_ids[split]), self.shard_size)
+ ):
+ print(shard_idx)
+ meta_shard = []
+ video_shard = []
+ for video_id in self.video_ids[split][shard_offset:shard_offset+self.shard_size]:
+ meta_shard.append(video_id)
+ npy_file = os.path.join(self.vfeat_dir, video_id + ".npy")
+ video_shard.append(np.load(npy_file))
+
+ meta[shard_idx] = meta_shard
+ video_shard = ShardedTensor.from_list(video_shard)
+ target_path = os.path.join(
+ self.target_dir, split + "_" + str(shard_idx))
+ video_shard.save(target_path)
+
+ target_path = os.path.join(self.target_dir, split + "_meta")
+ with open(target_path + ".pkl", "wb") as fw:
+ pickle.dump(meta, fw, pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == "__main__":
+ shard = Shard(
+ "data/feat/feat_how2_s3d",
+ "data/how2/raw_caption_dedup.bert-base-uncased",
+ "data/feat/feat_how2_s3d_shard_small",
+ ["data/how2/how2_s3d_train.lst", "data/how2/how2_s3d_val.lst"]
+ )
+
+ shard()
diff --git a/examples/MMPT/scripts/video_feature_extractor/videoreader.py b/examples/MMPT/scripts/video_feature_extractor/videoreader.py
new file mode 100644
index 0000000000..429e05f8bc
--- /dev/null
+++ b/examples/MMPT/scripts/video_feature_extractor/videoreader.py
@@ -0,0 +1,242 @@
+# Copyright Howto100M authors.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch as th
+import pandas as pd
+import os
+import numpy as np
+import ffmpeg
+import random
+
+from torch.utils.data import Dataset
+
+
+class VideoLoader(Dataset):
+ """modified from how2's video_feature_extractor."""
+ def __init__(
+ self,
+ csv=None,
+ video_dict=None,
+ framerate=1,
+ size=112,
+ centercrop=False,
+ hflip=False,
+ **kwargs
+ ):
+ if csv is None and video_dict is None:
+            raise ValueError("csv and video_dict cannot both be None.")
+ if csv is not None:
+ self.csv = pd.read_csv(csv)
+ if video_dict is not None:
+ self.csv = pd.DataFrame.from_dict(video_dict)
+
+ self.centercrop = centercrop
+ self.size = size
+ self.framerate = framerate
+ self.hflip = hflip
+
+ def __len__(self):
+ return len(self.csv)
+
+ def _get_video_dim(self, video_path):
+ probe = ffmpeg.probe(video_path)
+ video_stream = next((stream for stream in probe['streams']
+ if stream['codec_type'] == 'video'), None)
+ width = int(video_stream['width'])
+ height = int(video_stream['height'])
+ return height, width
+
+ def _get_video_info(self, video_path):
+ probe = ffmpeg.probe(video_path)
+ video_stream = next((stream for stream in probe['streams']
+ if stream['codec_type'] == 'video'), None)
+ return video_stream
+
+ def _get_output_dim(self, h, w):
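+        # e.g. self.size=112 with a 480x640 (h x w) frame returns (112, 149): the
+        # shorter side is scaled to `size` and the aspect ratio is preserved.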
+ if isinstance(self.size, tuple) and len(self.size) == 2:
+ return self.size
+ elif h >= w:
+ return int(h * self.size / w), self.size
+ else:
+ return self.size, int(w * self.size / h)
+
+ def __getitem__(self, idx):
+ video_path = self.csv['video_path'].values[idx]
+ output_file = self.csv['feature_path'].values[idx]
+ return self._decode(output_file, video_path)
+
+ def _decode(self, output_file, video_path):
+ if not(os.path.isfile(output_file)) and os.path.isfile(video_path):
+ try:
+ h, w = self._get_video_dim(video_path)
+ except Exception:
+ print('ffprobe failed at: {}'.format(video_path))
+ return {'video': th.zeros(1), 'input': video_path,
+ 'output': output_file}
+ try:
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ height, width = self._get_output_dim(h, w)
+
+ cmd = (
+ ffmpeg
+ .input(video_path)
+ .filter('fps', fps=self.framerate)
+ .filter('scale', width, height)
+ )
+ if self.hflip:
+ cmd = cmd.filter('hflip')
+
+ if self.centercrop:
+ x = int((width - self.size) / 2.0)
+ y = int((height - self.size) / 2.0)
+ cmd = cmd.crop(x, y, self.size, self.size)
+ video = self._run(cmd, output_file)
+ except Exception:
+ video = th.zeros(1)
+ else:
+ video = th.zeros(1)
+
+ return {'video': video, 'input': video_path, 'output': output_file}
+
+ def _run(self, cmd, output_file):
+ out, _ = (
+ cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True)
+ )
+ if self.centercrop and isinstance(self.size, int):
+ height, width = self.size, self.size
+ video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
+ video = th.from_numpy(video.astype('float32'))
+ return video.permute(0, 3, 1, 2)
+
+
+class VideoVerifier(VideoLoader):
+ def __getitem__(self, idx):
+ video_path = self.csv['video_path'].values[idx]
+ try:
+ return self._get_video_info(video_path)
+ except Exception:
+ # print('ffprobe failed at: {}'.format(video_path))
+ return None
+
+
+class VideoCompressor(VideoLoader):
+ def __init__(
+ self,
+ csv=None,
+ video_dict=None,
+ framerate=1,
+ size=112,
+ centercrop=False,
+ hflip=False,
+ crf=32,
+ **kwargs
+ ):
+ super().__init__(
+ csv,
+ video_dict,
+ framerate,
+ size,
+ centercrop,
+ hflip
+ )
+ self.crf = crf
+
+ def _run(self, cmd, output_file):
+ out, _ = (
+ cmd.output(filename=output_file, crf=self.crf)
+ .run(quiet=True)
+ )
+ video = None
+ return video
+
+
+class VideoDownloader(VideoCompressor):
+ """download"""
+ def __getitem__(self, idx):
+ video_path = self.csv['video_path'].values[idx]
+ output_file = self.csv['feature_path'].values[idx]
+ if not(os.path.isfile(output_file)):
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ cmd = "wget -O" + output_file + " " + video_path
+ # import subprocess
+ # subprocess.check_output(
+ # cmd,
+ # stderr=subprocess.STDOUT, shell=True)
+ os.system(cmd)
+ return {'video': None, 'input': video_path, 'output': output_file}
+
+
+class AvKeyframeVideoCompressor(VideoLoader):
+    """extract keyframes from a video and save them as jpg files.
+    TODO: consider merging with `CodecProcessor`.
+ """
+ def __init__(
+ self,
+ csv=None,
+ video_dict=None,
+ framerate=1,
+ size=112,
+ centercrop=False,
+ max_num_frames=5,
+ **kwargs
+ ):
+ super().__init__(csv, video_dict, framerate, size, centercrop)
+ self.max_num_frames = max_num_frames
+
+ def _get_video_dim(self, video_fn):
+        """decord cannot probe the size of a video, so we use pyav instead."""
+ import av
+ with av.open(video_fn) as container:
+ height = container.streams.video[0].codec_context.height
+ width = container.streams.video[0].codec_context.width
+ return height, width
+
+ def _get_output_dim(self, height, width):
+ """
+        keep the shorter side at `self.size` and stretch the other.
+ """
+ if height >= width:
+ return int(height * self.size / width), self.size
+ else:
+ return self.size, int(width * self.size / height)
+
+ def __getitem__(self, idx):
+ import av
+ video_path = self.csv['video_path'].values[idx]
+ output_file = self.csv['feature_path'].values[idx]
+ if not(os.path.isdir(output_file)) and os.path.isfile(video_path):
+ try:
+ h, w = self._get_video_dim(video_path)
+ except Exception:
+ print('probe failed at: {}'.format(video_path))
+ return {'video': th.zeros(1), 'input': video_path,
+ 'output': output_file}
+
+ try:
+ height, width = self._get_output_dim(h, w)
+
+ # new for av.
+ with av.open(video_path) as container:
+ container.streams.video[0].thread_type = "AUTO"
+ container.streams.video[0].codec_context.height = height
+ container.streams.video[0].codec_context.width = width
+ if self.framerate == 0: # keyframe.
+ container.streams.video[0].codec_context.skip_frame = 'NONKEY'
+ frames = []
+ for frame in container.decode(video=0):
+ frames.append(frame)
+                    # sample at most max_num_frames keyframes (some clips have fewer)
+                    frames = random.sample(frames, min(self.max_num_frames, len(frames)))
+
+ os.makedirs(output_file, exist_ok=True)
+ for frame in frames:
+ frame.to_image().save(
+ os.path.join(
+ output_file,
+ "%04d.jpg" % frame.index))
+ except Exception:
+ print('extract failed at: {}'.format(video_path))
+ return {'video': th.zeros(1), 'input': video_path,
+ 'output': output_file}
+ video = th.zeros(1)
+ return {'video': video, 'input': video_path, 'output': output_file}
diff --git a/examples/MMPT/setup.py b/examples/MMPT/setup.py
new file mode 100644
index 0000000000..a9a82296ea
--- /dev/null
+++ b/examples/MMPT/setup.py
@@ -0,0 +1,24 @@
+import setuptools
+
+with open("README.md", "r") as fh:
+ long_description = fh.read()
+
+setuptools.setup(
+ name="mmpt",
+ version="0.0.1",
+ author="Hu Xu, Po-yao Huang",
+ author_email="huxu@fb.com",
+ description="A package for multimodal pretraining.",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url="https://github.com/pytorch/fairseq/examples/MMPT",
+ packages=setuptools.find_packages(),
+ install_requires=[
+ ],
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: CC-BY-NC",
+ "Operating System :: OS Independent",
+ ],
+ python_requires='>=3.6',
+)
diff --git a/examples/MMPT/videoclip.png b/examples/MMPT/videoclip.png
new file mode 100644
index 0000000000..50dd0abfe4
Binary files /dev/null and b/examples/MMPT/videoclip.png differ
diff --git a/examples/MMPT/vlm.png b/examples/MMPT/vlm.png
new file mode 100644
index 0000000000..55c97dbc9f
Binary files /dev/null and b/examples/MMPT/vlm.png differ
diff --git a/examples/__init__.py b/examples/__init__.py
index 80d95f5fe7..44bb24ae61 100644
--- a/examples/__init__.py
+++ b/examples/__init__.py
@@ -3,4 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq.version import __version__ # noqa
+try:
+ from fairseq.version import __version__ # noqa
+except ImportError:
+ pass
diff --git a/examples/adaptive_span/README.md b/examples/adaptive_span/README.md
new file mode 100644
index 0000000000..d5224fb289
--- /dev/null
+++ b/examples/adaptive_span/README.md
@@ -0,0 +1,90 @@
+# Adaptive Span
+
+Adaptive Span is a novel self-attention mechanism that can learn its optimal
+attention span. This makes it possible to significantly extend the maximum context
+size used in the Transformer while maintaining control over its memory footprint
+and computation time. It uses the Truncated BPTT technique for training,
+as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
+
+Adaptive Span was introduced in the paper
+[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
+which achieved state-of-the-art language modeling results at the time of publication.
+
+We reproduce their results in fairseq and keep most of the
+[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
+You can also refer to their sweep file if any hyperparameter combination is unclear.
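+
+As a rough sketch of the idea (illustrative code, not the module's actual API), each
+attention head keeps a learnable span fraction `z` in `[0, 1]`, and attention weights
+older than roughly `z * max_span` steps are softly masked with a linear ramp so that
+`z` still receives gradients:
+
+```python
+import torch
+
+def soft_span_mask(attn, z, max_span, ramp=32):
+    # attn: (..., max_span) attention weights over the last `max_span` steps,
+    # with index max_span - 1 being the most recent step.
+    template = torch.linspace(1 - max_span, 0, steps=max_span)  # -(max_span - 1), ..., 0
+    mask = ((template + z * max_span) / ramp + 1).clamp(0, 1)   # 0 = masked, 1 = kept
+    return attn * mask
+```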
+
+##### 0. Setup
+
+First you need to preprocess the Enwik8 dataset; we use the pre-tokenized dataset
+from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
+Download the dataset, then run:
+```bash
+fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
+ --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
+ --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
+```
+
+##### 1. Train an Adaptive Span model on Enwik8
+
+We will train a 12-layer Adaptive Span model following the [hyperparameters
+used in the original
+paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
+
+The following command assumes 4 GPUs, so that the total batch size is 64
+sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
+ --user-dir examples/adaptive_span \
+ --data ~/data/enwik8/data-bin/ \
+ --fp16 --fp16-no-flatten-grads --max-update 600000 \
+ --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
+ --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
+ --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
+ --validate-interval-updates 1000 \
+ --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
+ --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
+ --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
+```
+This should land at around 1.05 bpc on validation and 1.03 bpc on test. You can
+lower `--aux-loss-scaler` for better performance (longer span). It gives a ~0.03 bpc
+improvement over the transformerXL baseline here.
+If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
+and simulate training on 4 GPUs.
+You can also reproduce the transformerXL result on enwik8 using this code base.
+It should land at around 1.06 bpc on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
+You can try it with:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
+ --user-dir examples/truncated_bptt \
+ ~/data/enwik8/data-bin/ \
+ --task truncated_bptt_lm --fp16 --max-update 400000 \
+ --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
+ --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
+ --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
+ --lr-scheduler cosine --warmup-updates 0 \
+ --lr 0.0 --lr 0.00025 --batch-size 15 \
+ --update-freq 1 --seed 2 --log-format json --log-interval 25 \
+ --fp16
+```
+
+##### 2. Evaluate
+For Adaptive Span:
+```bash
+fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
+ --user-dir examples/adaptive_span \
+ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
+```
+For Transformer-XL evaluation:
+```bash
+fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
+ --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
+ --tokens-per-sample 80 \
+ --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
+ --gen-subset valid
+```
+
+*Note:* During training the model saw 512 tokens of context
+(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
+settings from [the original
+paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
diff --git a/examples/adaptive_span/__init__.py b/examples/adaptive_span/__init__.py
new file mode 100644
index 0000000000..e0a142a769
--- /dev/null
+++ b/examples/adaptive_span/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+import os
+
+# automatically import any Python files in the current directory
+cur_dir = os.path.dirname(__file__)
+for file in os.listdir(cur_dir):
+ path = os.path.join(cur_dir, file)
+ if (
+ not file.startswith("_")
+ and not file.startswith(".")
+ and (file.endswith(".py") or os.path.isdir(path))
+ ):
+ mod_name = file[: file.find(".py")] if file.endswith(".py") else file
+ module = importlib.import_module(__name__ + "." + mod_name)
diff --git a/examples/adaptive_span/adagrad_with_grad_clip.py b/examples/adaptive_span/adagrad_with_grad_clip.py
new file mode 100644
index 0000000000..585ce184ab
--- /dev/null
+++ b/examples/adaptive_span/adagrad_with_grad_clip.py
@@ -0,0 +1,128 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.optim import Adagrad
+
+from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
+
+
+@register_optimizer("adagrad_with_grad_clip")
+class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
+ def __init__(self, args, params):
+ super().__init__(args)
+ self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
+
+ @staticmethod
+ def add_args(parser):
+ """Add optimizer-specific arguments to the parser."""
+ # fmt: off
+ parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+ help='weight decay')
+ parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
+ help='internal grad clip')
+ # fmt: on
+
+ @property
+ def optimizer_config(self):
+ """
+ Return a kwarg dictionary that will be used to override optimizer
+ args stored in checkpoints. This allows us to load a checkpoint and
+ resume training using a different set of optimizer args, e.g., with a
+ different learning rate.
+ """
+ return {
+ "lr": self.args.lr[0],
+ "weight_decay": self.args.weight_decay,
+ "grad_clip": self.args.adagrad_clip,
+ }
+
+ @property
+ def supports_flat_params(self):
+ return False
+
+
+def _clip_grad(clr, grad, group_grad_clip):
+ if group_grad_clip > 0:
+ norm = grad.norm(2).item()
+ if norm > group_grad_clip:
+ clr *= group_grad_clip / (norm + 1e-10)
+ return clr
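+# Example: with grad_clip=0.03 and a gradient whose L2 norm is 0.06, the step size
+# clr is halved, so clr * ||grad|| never exceeds the unclipped clr * grad_clip.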
+
+
+class AdagradWithGradClip(Adagrad):
+ """Adagrad algorithm with custom gradient clipping"""
+
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ grad_clip=0,
+ ):
+ Adagrad.__init__(
+ self,
+ params,
+ lr=lr,
+ lr_decay=lr_decay,
+ weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value,
+ )
+ self.defaults["grad_clip"] = grad_clip
+ self.param_groups[0].setdefault("grad_clip", grad_clip)
+
+ def step(self, closure=None):
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ grad = p.grad.data
+ state = self.state[p]
+
+ state["step"] += 1
+
+ if group["weight_decay"] != 0:
+ if p.grad.data.is_sparse:
+ raise RuntimeError(
+ "weight_decay option is "
+ "not compatible with sparse "
+ "gradients"
+ )
+ grad = grad.add(group["weight_decay"], p.data)
+
+ clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
+
+ # clip
+ clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
+
+ if grad.is_sparse:
+ # the update is non-linear so indices must be unique
+ grad = grad.coalesce()
+ grad_indices = grad._indices()
+ grad_values = grad._values()
+ size = grad.size()
+
+ def make_sparse(values):
+ constructor = grad.new
+ if grad_indices.dim() == 0 or values.dim() == 0:
+ return constructor().resize_as_(grad)
+ return constructor(grad_indices, values, size)
+
+ state["sum"].add_(make_sparse(grad_values.pow(2)))
+ std = state["sum"]._sparse_mask(grad)
+ std_values = std._values().sqrt_().add_(1e-10)
+ p.data.add_(-clr, make_sparse(grad_values / std_values))
+ else:
+ state["sum"].addcmul_(1, grad, grad)
+ std = state["sum"].sqrt().add_(1e-10)
+ p.data.addcdiv_(-clr, grad, std)
+
+ return loss
diff --git a/examples/adaptive_span/adaptive_span_attention.py b/examples/adaptive_span/adaptive_span_attention.py
new file mode 100644
index 0000000000..07f757bb8e
--- /dev/null
+++ b/examples/adaptive_span/adaptive_span_attention.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AdaptiveMask(nn.Module):
+ """Soft masking function for adaptive size.
+ It masks out the last K values of an input. The masking value
+ goes from 1 to 0 gradually, so K can be learned with
+ back-propagation.
+ Args:
+ max_size: maximum size (i.e. input dimension)
+ ramp_size: size of the ramp going from 0 to 1
+ init_val: initial size proportion not to be masked out
+ shape: learn multiple sizes independent of each other
+ """
+
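+    # For position j in [0, max_size) (j = max_size - 1 is the most recent step), the mask
+    # value is clamp(((j - (max_size - 1)) + current_val * max_size) / ramp_size + 1, 0, 1):
+    # roughly the last current_val * max_size positions stay fully visible, preceded by a
+    # linear ramp of width ramp_size through which gradients reach current_val.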
+ def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
+ nn.Module.__init__(self)
+ self._max_size = max_size
+ self._ramp_size = ramp_size
+ self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
+ mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
+ self.register_buffer("mask_template", mask_template)
+
+ def forward(self, x):
+ mask = self.mask_template.float() + self.current_val.float() * self._max_size
+ mask = mask / self._ramp_size + 1
+ mask = mask.clamp(0, 1)
+ if x.size(-1) < self._max_size:
+ # the input could have been trimmed beforehand to save computation
+ mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
+ x = (x * mask).type_as(x)
+ return x
+
+ def get_current_max_size(self, include_ramp=True):
+ current_size = math.ceil(self.current_val.max().item() * self._max_size)
+ if include_ramp:
+ current_size += self._ramp_size
+ current_size = max(0, min(self._max_size, current_size))
+ return current_size
+
+ def get_current_avg_size(self, include_ramp=True):
+ current_size = math.ceil(
+ self.current_val.float().mean().item() * self._max_size
+ )
+ if include_ramp:
+ current_size += self._ramp_size
+ current_size = max(0, min(self._max_size, current_size))
+ return current_size
+
+ def clamp_param(self):
+        """this needs to be called after each update"""
+ self.current_val.data.clamp_(0, 1)
+
+
+class AdaptiveSpan(nn.Module):
+    """Adaptive attention span for the Transformer.
+    This module learns an attention span length from data for each
+    self-attention head.
+    Args:
+        attn_span: maximum attention span
+        adapt_span_ramp: length of the masking ramp
+        adapt_span_init: initial size ratio
+        n_head: number of attention heads
+        adapt_span_layer: learn a single span per layer instead of one per head
+    """
+
+ def __init__(
+ self,
+ attn_span,
+ adapt_span_ramp,
+ adapt_span_init,
+ n_head,
+ adapt_span_layer,
+ **kargs
+ ):
+ nn.Module.__init__(self)
+ self._max_span = attn_span
+ self._n_head = n_head
+ self._adapt_span_layer = adapt_span_layer
+ if self._adapt_span_layer:
+ self._mask = AdaptiveMask(
+ max_size=self._max_span,
+ ramp_size=adapt_span_ramp,
+ init_val=adapt_span_init,
+ )
+ else:
+ self._mask = AdaptiveMask(
+ max_size=self._max_span,
+ ramp_size=adapt_span_ramp,
+ init_val=adapt_span_init,
+ shape=(n_head, 1, 1),
+ )
+
+ def forward(self, attn, normalize=True):
+ """mask attention with the right span"""
+ # batch and head dimensions are merged together, so separate them first
+ self.clamp_param()
+ if self._adapt_span_layer:
+ attn = self._mask(attn)
+ else:
+ B = attn.size(0) # batch size
+ M = attn.size(1) # block size
+ attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
+ attn = self._mask(attn)
+ attn = attn.view(B, M, -1)
+ return attn
+
+ def get_trim_len(self):
+ """how much of memory can be trimmed to reduce computation"""
+ L = self._max_span
+ trim_len = min(L - 1, L - self._mask.get_current_max_size())
+ # too fine granularity might be bad for the memory management
+ trim_len = math.floor(trim_len / 64) * 64
+ return trim_len
+
+ def trim_memory(self, query, key, value, key_pe):
+ """trim out unnecessary memory beforehand to reduce computation"""
+ trim_len = self.get_trim_len()
+ cache_size = key.size(1) - query.size(1)
+ trim_len_cache = trim_len - (self._max_span - cache_size)
+ if trim_len_cache > 0:
+ key = key[:, trim_len_cache:, :]
+ value = value[:, trim_len_cache:, :]
+ elif trim_len_cache < 0:
+ # cache is too short! this happens when validation resumes
+ # after a lot of updates.
+ key = F.pad(key, [0, 0, -trim_len_cache, 0])
+ value = F.pad(value, [0, 0, -trim_len_cache, 0])
+ if trim_len > 0:
+ if key_pe is not None:
+ key_pe = key_pe[:, :, trim_len:]
+ return key, value, key_pe
+
+ def get_cache_size(self):
+ """determine how long the cache should be"""
+ trim_len = self.get_trim_len()
+ # give a buffer of 64 steps since a span might increase
+ # in future updates
+ return min(self._max_span, self._max_span - trim_len + 64)
+
+ def get_loss(self):
+ """a loss term for regularizing the span length"""
+ return self._max_span * self._mask.current_val.float().mean()
+
+ def get_current_max_span(self):
+ return self._mask.get_current_max_size()
+
+ def get_current_avg_span(self):
+ return self._mask.get_current_avg_size()
+
+ def clamp_param(self):
+ self._mask.clamp_param()
diff --git a/examples/adaptive_span/adaptive_span_loss.py b/examples/adaptive_span/adaptive_span_loss.py
new file mode 100644
index 0000000000..fe95b0d949
--- /dev/null
+++ b/examples/adaptive_span/adaptive_span_loss.py
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass
+
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.logging import metrics
+from fairseq.criterions import register_criterion
+from fairseq.criterions.cross_entropy import CrossEntropyCriterion
+from fairseq.dataclass import FairseqDataclass
+from omegaconf import II
+
+
+@dataclass
+class AdaptiveSpanCriterionConfig(FairseqDataclass):
+ sentence_avg: bool = II("optimization.sentence_avg")
+
+
+@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
+class AdaptiveSpanCriterion(CrossEntropyCriterion):
+ def __init__(self, task, sentence_avg):
+ super().__init__(task, sentence_avg)
+
+ def forward(self, model, sample, reduce=True):
+ """Compute the loss for the given sample.
+
+ Returns a tuple with three elements:
+ 1) the loss here is summed, different from the adaptive span code
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ net_output = model(**sample["net_input"])
+ loss, aux_loss, avg_span, max_span = self.compute_loss(
+ model, net_output, sample, reduce=reduce
+ )
+ sample_size = (
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+ )
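+        # the loss is normalized by the token (or sentence) count here; sample_size is
+        # then set to 1 so that it and the auxiliary span loss are not rescaled again.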
+ loss /= sample_size
+ total_loss = loss + aux_loss
+ sample_size = 1
+
+ logging_output = {
+ "loss": loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["target"].size(0),
+ "sample_size": sample_size,
+ "total_loss": total_loss.data,
+ "avg_span": avg_span * sample_size,
+ "max_span": max_span * sample_size,
+ }
+ return total_loss, sample_size, logging_output
+
+ def compute_loss(self, model, net_output, sample, reduce=True):
+ loss, _ = super().compute_loss(model, net_output, sample, reduce)
+ aux_loss = model.get_aux_loss()
+ avg_span = model.get_current_avg_span()
+ max_span = model.get_current_max_span()
+ return loss, aux_loss, avg_span, max_span
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+ total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
+ avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
+ max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
+
+ # we divide by log(2) to convert the loss from base e to base 2
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
+ metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
+ # total loss contains the L1 norm on adaptive-span
+ metrics.log_scalar(
+ "total_loss",
+ total_loss_sum / sample_size / math.log(2),
+ sample_size,
+ round=3,
+ )
+ if sample_size != ntokens:
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+ )
+ else:
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+ )
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return True
diff --git a/examples/adaptive_span/adaptive_span_model.py b/examples/adaptive_span/adaptive_span_model.py
new file mode 100644
index 0000000000..d96c95b85d
--- /dev/null
+++ b/examples/adaptive_span/adaptive_span_model.py
@@ -0,0 +1,263 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq.modules.layer_norm import LayerNorm
+
+from .adaptive_span_attention import AdaptiveSpan
+
+# Size notations:
+# B = batch_size, H = d_model, M = block_size, L = attn_span
+
+
+def _skew(X, pad_value):
+ """shift every row 1 step to right"""
+ # X = B x M x L
+ B, M, L = X.size()
+ X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
+ X = X.view(B, -1) # B x ML+MM+M
+ X = X[:, :-M] # B x ML+MM
+ X = X.view(B, M, M + L) # B x M x L+M
+ return X
+
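+# e.g. with M=2 and L=3, _skew maps [[a, b, c], [d, e, f]] to
+# [[a, b, c, p, p], [p, d, e, f, p]] (p = pad_value), i.e. row i is shifted right by i.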
+
+def _unskew(X):
+ """reverse _skew operation"""
+ # X = B x M x L+M
+ B, M, L = X.size()
+ L -= M
+ X = X.view(B, -1) # B x ML+MM
+ X = F.pad(X, (0, M)) # B x ML+MM+M
+ X = X.view(B, M, M + L + 1) # B x M x L+M+1
+ X = X[:, :, :L] # B x M x L
+ return X
+
+
+class SeqAttention(nn.Module):
+ """Sequential self-attention layer.
+ Each token will attend to its previous fixed number of steps.
+ Note that attention doesn't include the current step itself.
+ """
+
+ def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
+ nn.Module.__init__(self)
+ self.dropout = nn.Dropout(dropout)
+ self.d_model = d_model # size of a single head
+ self.attn_span = attn_span
+ self.adaptive_span = AdaptiveSpan(
+ attn_span=attn_span,
+ n_head=n_head,
+ adapt_span_layer=adapt_span_layer,
+ **kargs
+ )
+
+ def forward(self, query, key, value, key_pe):
+ # query size = B x M x H
+ # key, value sizes = B x (M+L) x H
+
+ key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
+
+ # compute attention from context
+ # B x M (dest) x (M+L) (src)
+ attn_cont = torch.matmul(query, key.transpose(-1, -2))
+ attn_cont = _unskew(attn_cont) # B x M x L
+
+ # compute the effect of position embedding
+ attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
+ attn = attn_cont + attn_pos
+
+ attn = attn / math.sqrt(self.d_model) # B x M X L_pos
+
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+
+ # trim attention lengths according to the learned span
+ attn = self.adaptive_span(attn)
+
+ attn = self.dropout(attn) # B x M X L_pos
+
+ attn_cont = _skew(attn, 0) # B x M X (L+M)
+ out = torch.matmul(attn_cont, value) # B x M x H
+ return out
+
+ def get_cache_size(self):
+ return self.adaptive_span.get_cache_size()
+
+
+class MultiHeadSeqAttention(nn.Module):
+ def __init__(self, d_model, n_head, **kargs):
+ nn.Module.__init__(self)
+ assert d_model % n_head == 0
+ self.n_head = n_head
+ self.head_dim = d_model // n_head
+ self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
+ self.proj_query = nn.Linear(d_model, d_model, bias=False)
+ nn.init.xavier_normal_(self.proj_query.weight)
+ self.proj_out = nn.Linear(d_model, d_model, bias=False)
+ nn.init.xavier_normal_(self.proj_out.weight)
+ self.proj_val = nn.Linear(d_model, d_model, bias=False)
+ nn.init.xavier_normal_(self.proj_val.weight)
+ self.proj_key = nn.Linear(d_model, d_model, bias=False)
+ nn.init.xavier_normal_(self.proj_key.weight)
+
+ def head_reshape(self, x):
+ K = self.n_head
+ D = self.head_dim
+ x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
+ x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
+ x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
+ return x
+
+ def forward(self, query, key, value, key_pe):
+ B = query.size(0)
+ K = self.n_head
+ D = self.head_dim
+ M = query.size(1)
+
+ query = self.proj_query(query)
+ query = self.head_reshape(query)
+ value = self.proj_val(value)
+ value = self.head_reshape(value)
+ key = self.proj_key(key)
+ key = self.head_reshape(key)
+
+ out = self.attn(query, key, value, key_pe) # B_K x M x D
+ out = out.view(B, K, M, D) # B x K x M x D
+ out = out.transpose(1, 2).contiguous() # B x M x K x D
+ out = out.view(B, M, -1) # B x M x K_D
+ out = self.proj_out(out)
+ return out
+
+
+class FeedForwardLayer(nn.Module):
+ def __init__(self, d_model, d_inner, dropout, **kargs):
+ nn.Module.__init__(self)
+ self.fc1 = nn.Linear(d_model, d_inner)
+ self.fc2 = nn.Linear(d_inner, d_model)
+ nn.init.xavier_uniform_(self.fc1.weight)
+ nn.init.xavier_uniform_(self.fc2.weight)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, h):
+ h1 = F.relu(self.fc1(h))
+ h1 = self.dropout(h1)
+ h2 = self.fc2(h1)
+ return h2
+
+
+class TransformerSeqLayer(nn.Module):
+ def __init__(self, d_model, **kargs):
+ nn.Module.__init__(self)
+ self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
+ self.norm1 = LayerNorm(d_model)
+ self.ff = FeedForwardLayer(d_model=d_model, **kargs)
+ self.norm2 = LayerNorm(d_model)
+
+ def forward(self, h, h_cache, key_pe):
+ # h = B x M x H
+ # h_cache = B x L x H
+ h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
+ attn_out = self.attn(h, h_all, h_all, key_pe)
+ h = self.norm1(h + attn_out) # B x M x H
+ if self.ff is not None:
+ ff_out = self.ff(h)
+ out = self.norm2(h + ff_out) # B x M x H
+ else:
+ out = h
+ return out
+
+ def get_cache_size(self):
+ return self.attn.attn.get_cache_size()
+
+
+class TransformerSeq(nn.Module):
+ def __init__(
+ self,
+ vocab_size,
+ d_model,
+ n_head,
+ n_layer,
+ attn_span,
+ emb_dropout,
+ aux_loss_scaler,
+ adapt_span_layer,
+ **kargs
+ ):
+ nn.Module.__init__(self)
+ # token embeddings
+ self.in_emb = nn.Embedding(vocab_size, d_model)
+ nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
+ self.out_emb = nn.Linear(d_model, vocab_size)
+ self.aux_loss_scaler = aux_loss_scaler
+ if emb_dropout > 0:
+ self.emb_dropout = nn.Dropout(emb_dropout)
+ else:
+ self.emb_dropout = None
+ # position embeddings
+ self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
+
+ self.layers = nn.ModuleList()
+ self.layers.extend(
+ TransformerSeqLayer(
+ d_model=d_model,
+ n_head=n_head,
+ attn_span=attn_span,
+ adapt_span_layer=adapt_span_layer,
+ **kargs
+ )
+ for _ in range(n_layer)
+ )
+
+ def forward(self, x, h_cache, target=None):
+ # x size = B x M
+ block_size = x.size(1)
+ h = self.in_emb(x) # B x M x H
+ if self.emb_dropout is not None:
+ h = self.emb_dropout(h)
+
+ h_cache_next = []
+ for l, layer in enumerate(self.layers):
+ cache_size = layer.attn.attn.get_cache_size()
+ if cache_size > block_size:
+ h_cache_next_l = torch.cat(
+ [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
+ ).detach()
+ else:
+ h_cache_next_l = h[:, -cache_size:, :].detach()
+ h_cache_next.append(h_cache_next_l)
+ h = layer(h, h_cache[l], self.key_pe) # B x M x H
+
+ if self.emb_dropout is not None:
+ h = self.emb_dropout(h)
+
+ out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
+ dummy_loss = None
+
+ return out, h_cache_next, dummy_loss
+
+ def get_aux_loss(self):
+ loss = 0.0
+ for layer in self.layers:
+ loss += layer.attn.attn.adaptive_span.get_loss()
+ return self.aux_loss_scaler * loss
+
+ def get_current_max_span(self):
+ max_span = 0.0
+ for layer in self.layers:
+ max_span = max(
+ max_span, layer.attn.attn.adaptive_span.get_current_max_span()
+ )
+ return max_span
+
+ def get_current_avg_span(self):
+ avg_span = 0.0
+ for layer in self.layers:
+ avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
+ return avg_span / len(self.layers)
diff --git a/examples/adaptive_span/adaptive_span_model_wrapper.py b/examples/adaptive_span/adaptive_span_model_wrapper.py
new file mode 100644
index 0000000000..5b147fe11f
--- /dev/null
+++ b/examples/adaptive_span/adaptive_span_model_wrapper.py
@@ -0,0 +1,145 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import (
+ FairseqIncrementalDecoder,
+ FairseqLanguageModel,
+ register_model,
+)
+from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AdaptiveSpanSmallConfig(FairseqDataclass):
+ # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
+ vocab_size: int = 50
+ d_model: int = 256
+ n_head: int = 4
+ d_inner: int = 1024
+ n_layer: int = 8
+ attn_span: int = 1024
+ dropout: float = 0.0
+ emb_dropout: float = 0.0
+ adapt_span_ramp: int = 32
+ adapt_span_init: float = 0.0
+ aux_loss_scaler: float = 0.000002
+ adapt_span_layer: bool = False
+
+
+@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
+class AdaptiveSpanTransformer(FairseqLanguageModel):
+ @classmethod
+ def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
+ return cls(AdaptiveSpanDecoder(cfg, task))
+
+ def get_aux_loss(self):
+ return self.decoder.get_aux_loss()
+
+ def get_current_max_span(self):
+ return self.decoder.get_current_max_span()
+
+ def get_current_avg_span(self):
+ return self.decoder.get_current_avg_span()
+
+
+class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
+ def __init__(self, cfg, task):
+
+ super().__init__(task.target_dictionary)
+
+ self.config = cfg
+ config = AdaptiveSpanSmallConfig(
+ vocab_size=len(task.target_dictionary),
+ d_model=cfg.d_model,
+ n_head=cfg.n_head,
+ d_inner=cfg.d_inner,
+ n_layer=cfg.n_layer,
+ attn_span=cfg.attn_span,
+ dropout=cfg.dropout,
+ emb_dropout=cfg.emb_dropout,
+ adapt_span_ramp=cfg.adapt_span_ramp,
+ adapt_span_init=cfg.adapt_span_init,
+ aux_loss_scaler=cfg.aux_loss_scaler,
+ adapt_span_layer=cfg.adapt_span_layer,
+ )
+ logger.info(config)
+ self.model = AdaptiveSpanTransformerModel(**config.__dict__)
+
+ self._mems = None
+
+ def forward(
+ self,
+ src_tokens,
+ incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
+ encoder_out=None,
+ ):
+ bsz = src_tokens.size(0)
+ if incremental_state is not None: # used during inference
+ mems = self.get_incremental_state("mems")
+ src_tokens = src_tokens[:, -1:] # only keep the most recent token
+ else:
+ mems = self._mems
+
+ if mems is None:
+ # first time init
+ mems = self.init_hid_cache(bsz)
+ output = self.model(x=src_tokens, h_cache=mems,)
+ if incremental_state is not None:
+ self.set_incremental_state(incremental_state, "mems", output[1])
+ else:
+ self._mems = output[1]
+ return (output[0],)
+
+ def max_positions(self):
+ return self.config.attn_span
+
+ def init_hid_cache(self, batch_sz):
+ hid = []
+ for layer in self.model.layers:
+ param = next(self.model.parameters())
+ h = torch.zeros(
+ batch_sz,
+ layer.get_cache_size(),
+ self.config.d_model,
+ dtype=param.dtype,
+ device=param.device,
+ )
+ hid.append(h)
+ return hid
+
+ def get_aux_loss(self):
+ return self.model.get_aux_loss()
+
+ def get_current_max_span(self):
+ return self.model.get_current_max_span()
+
+ def get_current_avg_span(self):
+ return self.model.get_current_avg_span()
+
+ def reorder_incremental_state(
+ self,
+ incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
+ new_order: torch.Tensor,
+ ):
+ """Reorder incremental state.
+
+ This will be called when the order of the input has changed from the
+ previous time step. A typical use case is beam search, where the input
+ order changes between time steps based on the selection of beams.
+ """
+ raise NotImplementedError("This is required for generation/beam search")
+ # mems = self.get_incremental_state(incremental_state, "mems")
+ # if mems is not None:
+ # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
+ # self.set_incremental_state(incremental_state, "mems", new_mems)
diff --git a/examples/adaptive_span/truncated_bptt_lm_task.py b/examples/adaptive_span/truncated_bptt_lm_task.py
new file mode 120000
index 0000000000..a92da3a298
--- /dev/null
+++ b/examples/adaptive_span/truncated_bptt_lm_task.py
@@ -0,0 +1 @@
+../truncated_bptt/truncated_bptt_lm_task.py
\ No newline at end of file
diff --git a/examples/attention_head_selection/README.md b/examples/attention_head_selection/README.md
new file mode 100644
index 0000000000..2434f1fb21
--- /dev/null
+++ b/examples/attention_head_selection/README.md
@@ -0,0 +1,161 @@
+# Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021)
+
+[https://arxiv.org/pdf/2106.10840.pdf](https://arxiv.org/pdf/2106.10840.pdf)
+
+## Introduction
+
+We present attention head selection strategies in multilingual and multi-domain sequence modeling, including text translation, speech recognition and speech translation tasks.
+
+Below is an example of training multilingual/multi-domain speech recognition models.
+
+## Data Preparation
+Prepare mTEDx data as in [mTEDx example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/mtedx_example.md) and CoVoST data as in [CoVoST example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/covost_example.md). Similarly prepare EuroParl data.
+
+
+## Training a multilingual ASR model with attention head selection
+
+```bash
+data_dir=
+train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
+valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx"
+strategy=
+
+fairseq-train ${data_dir} \
+ --user-dir examples/attention_head_selection/src \
+ --train-subset "${train_subset}" \
+ --valid-subset "${valid_subset}" \
+ --config-yaml 'config_asr.yaml' \
+ --arch 'head_selection_s2t_transformer_s' \
+ --task 'speech_to_text_head_selection' \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
+ --lr 5e-4 \
+ --clip-norm 10.0 \
+ --seed 1 \
+ --max-epoch 400 \
+ --max-tokens 32000 \
+ --ignore-prefix-size 1 \
+ --dropout 0.3 \
+ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
+ --skip-invalid-size-inputs-valid-test \
+ --encoder-attn-head-select \
+ --total-encoder-attention-heads 8 \
+ --decoder-self-attn-head-select \
+ --total-decoder-attention-heads 8 \
+ --attn-head-select-strategy ${strategy} \
+ --task-type lang \
+```
+
+## Training a multi-domain ASR model with attention head selection
+
+```bash
+data_dir=
+train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
+valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep"
+strategy=
+
+fairseq-train ${data_dir} \
+ --user-dir examples/attention_head_selection/src \
+ --train-subset "${train_subset}" \
+ --valid-subset "${valid_subset}" \
+ --config-yaml 'config_asr.yaml' \
+ --arch head_selection_s2t_transformer_s \
+ --task speech_to_text_head_selection \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
+ --lr 5e-4 \
+ --clip-norm 10.0 \
+ --seed 1 \
+ --max-epoch 400 \
+ --max-tokens 32000 \
+ --ignore-prefix-size 1 \
+ --dropout 0.3 \
+ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
+ --skip-invalid-size-inputs-valid-test \
+ --encoder-attn-head-select \
+ --total-encoder-attention-heads 8 \
+ --decoder-self-attn-head-select \
+ --total-decoder-attention-heads 8 \
+ --attn-head-select-strategy ${strategy} \
+ --task-type domain
+```
+
+## Inference in multilingual setting
+
+```bash
+MODEL_DIR=
+data_dir=
+gen_subset=
+train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
+last_n=10
+CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
+CHECKPOINT="_avg"
+RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
+if [ ! -d $RESULTS ]; then
+ mkdir -p $RESULTS
+fi;
+
+python scripts/average_checkpoints.py \
+ --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
+ --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
+
+fairseq-generate ${data_dir} \
+ --user-dir examples/attention_head_selection/src \
+ --arch 'head_selection_s2t_transformer_s' \
+ --task 'speech_to_text_head_selection' \
+ --train-subset ${train_subset} \
+ --gen-subset ${gen_subset} \
+ --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
+ --config-yaml 'config_asr.yaml' \
+ --prefix-size 1 \
+ --max-tokens 40000 --beam 5 \
+ --skip-invalid-size-inputs-valid-test \
+ --results-path ${RESULTS} \
+ --scoring wer --wer-tokenizer 13a \
+ --wer-lowercase --wer-remove-punct --remove-bpe
+```
+
+## Inference in multi-domain setting
+
+```bash
+MODEL_DIR=
+data_dir=
+gen_subset=
+train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
+last_n=10
+CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
+CHECKPOINT="_avg"
+RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
+if [ ! -d $RESULTS ]; then
+ mkdir -p $RESULTS
+fi;
+
+python scripts/average_checkpoints.py \
+ --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
+ --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
+
+fairseq-generate ${data_dir} \
+ --user-dir examples/attention_head_selection/src \
+ --arch 'head_selection_s2t_transformer_s' \
+ --task 'speech_to_text_head_selection' \
+ --train-subset ${train_subset} \
+ --gen-subset ${gen_subset} \
+ --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
+ --config-yaml 'config_asr.yaml' \
+ --prefix-size 1 \
+ --max-tokens 40000 --beam 5 \
+ --skip-invalid-size-inputs-valid-test \
+ --results-path ${RESULTS} \
+ --scoring wer --wer-tokenizer 13a \
+ --wer-lowercase --wer-remove-punct --remove-bpe
+```
+
+## Citation
+```bibtex
+@article{gong2021pay,
+ title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
+ author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
+ journal={arXiv preprint arXiv:2106.10840},
+ year={2021}
+}
+```
diff --git a/examples/latent_depth/src/loss/__init__.py b/examples/attention_head_selection/src/__init__.py
similarity index 100%
rename from examples/latent_depth/src/loss/__init__.py
rename to examples/attention_head_selection/src/__init__.py
diff --git a/examples/latent_depth/src/models/__init__.py b/examples/attention_head_selection/src/data/__init__.py
similarity index 100%
rename from examples/latent_depth/src/models/__init__.py
rename to examples/attention_head_selection/src/data/__init__.py
diff --git a/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py b/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py
new file mode 100644
index 0000000000..1f1823a7ac
--- /dev/null
+++ b/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py
@@ -0,0 +1,242 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from pathlib import Path
+from typing import Dict, List, Optional
+from dataclasses import dataclass
+
+import torch
+from fairseq.data import (
+ ConcatDataset,
+ Dictionary,
+ FairseqDataset,
+ ResamplingDataset
+)
+from fairseq.data.audio.data_cfg import S2TDataConfig
+from fairseq.data.audio.speech_to_text_dataset import (
+ SpeechToTextDatasetItem,
+ SpeechToTextDataset,
+ SpeechToTextDatasetCreator
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SpeechToTextDatasetItemWithDomain(SpeechToTextDatasetItem):
+ src_lang_id: Optional[torch.Tensor] = None
+ tgt_lang_id: Optional[torch.Tensor] = None
+ domain_id: Optional[torch.Tensor] = None
+
+
+class SpeechToTextDatasetWithDomain(SpeechToTextDataset):
+
+ def __init__(
+ self,
+ split: str,
+ is_train_split: bool,
+ cfg: S2TDataConfig,
+ audio_paths: List[str],
+ n_frames: List[int],
+ src_texts: Optional[List[str]] = None,
+ tgt_texts: Optional[List[str]] = None,
+ speakers: Optional[List[str]] = None,
+ src_langs: Optional[List[str]] = None,
+ tgt_langs: Optional[List[str]] = None,
+ ids: Optional[List[str]] = None,
+ tgt_dict: Optional[Dictionary] = None,
+ pre_tokenizer=None,
+ bpe_tokenizer=None,
+ n_frames_per_step=1,
+ speaker_to_id=None,
+ src_lang_ids: Optional[List[int]] = None,
+ tgt_lang_ids: Optional[List[int]] = None,
+ domain_ids: Optional[List[int]] = None
+ ):
+ super().__init__(
+ split, is_train_split, cfg, audio_paths, n_frames,
+ src_texts, tgt_texts, speakers, src_langs, tgt_langs,
+ ids, tgt_dict, pre_tokenizer, bpe_tokenizer,
+ n_frames_per_step, speaker_to_id
+ )
+ assert src_lang_ids is None or len(src_lang_ids) == self.n_samples
+ assert tgt_lang_ids is None or len(tgt_lang_ids) == self.n_samples
+ assert domain_ids is None or len(domain_ids) == self.n_samples
+
+ self.src_lang_ids = src_lang_ids
+ self.tgt_lang_ids = tgt_lang_ids
+ self.domain_ids = domain_ids
+
+ def __getitem__(self, index: int) -> SpeechToTextDatasetItemWithDomain:
+ item = super().__getitem__(index)
+ src_lang_id = self.src_lang_ids[index]
+ tgt_lang_id = self.tgt_lang_ids[index]
+ domain_id = self.domain_ids[index]
+ return SpeechToTextDatasetItemWithDomain(
+ index=item.index, source=item.source,
+ target=item.target, speaker_id=item.speaker_id,
+ src_lang_id=src_lang_id,
+ tgt_lang_id=tgt_lang_id,
+ domain_id=domain_id
+ )
+
+ def collater(
+ self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
+ ) -> Dict:
+ if len(samples) == 0:
+ return {}
+ out = super().collater(samples, return_order=True)
+ order = out["order"]
+ src_lang_ids = torch.tensor([x.src_lang_id for x in samples], dtype=torch.long).index_select(0, order)
+ tgt_lang_ids = torch.tensor([x.tgt_lang_id for x in samples], dtype=torch.long).index_select(0, order)
+ domain_ids = torch.tensor([x.domain_id for x in samples], dtype=torch.long).index_select(0, order)
+
+ out["src_lang_ids"] = src_lang_ids
+ out["tgt_lang_ids"] = tgt_lang_ids
+ out["domain_ids"] = domain_ids
+ if not return_order:
+ del out["order"]
+ return out
+
+
+class SpeechToTextDatasetCreatorWithDomain(SpeechToTextDatasetCreator):
+ KEY_SRC_LANG_ID, KEY_TGT_LANG_ID = "src_lang_id", "tgt_lang_id"
+ KEY_DOMAIN_ID = "domain_id"
+ # default values
+ DEFAULT_SRC_LANG_ID, DEFAULT_TGT_LANG_ID, DEFAULT_DOMAIN_ID = 0, 0, 0
+
+ @classmethod
+ def _from_list(
+ cls,
+ split_name: str,
+ is_train_split,
+ samples: List[Dict],
+ cfg: S2TDataConfig,
+ tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ n_frames_per_step,
+ speaker_to_id
+ ) -> SpeechToTextDatasetWithDomain:
+ audio_root = Path(cfg.audio_root)
+ ids = [s[cls.KEY_ID] for s in samples]
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
+ src_lang_ids = [s.get(cls.KEY_SRC_LANG_ID, cls.DEFAULT_SRC_LANG_ID) for s in samples]
+ tgt_lang_ids = [s.get(cls.KEY_TGT_LANG_ID, cls.DEFAULT_TGT_LANG_ID) for s in samples]
+ domain_ids = [s.get(cls.KEY_DOMAIN_ID, cls.DEFAULT_DOMAIN_ID) for s in samples]
+ return SpeechToTextDatasetWithDomain(
+ split_name,
+ is_train_split,
+ cfg,
+ audio_paths,
+ n_frames,
+ src_texts=src_texts,
+ tgt_texts=tgt_texts,
+ speakers=speakers,
+ src_langs=src_langs,
+ tgt_langs=tgt_langs,
+ ids=ids,
+ tgt_dict=tgt_dict,
+ pre_tokenizer=pre_tokenizer,
+ bpe_tokenizer=bpe_tokenizer,
+ n_frames_per_step=n_frames_per_step,
+ speaker_to_id=speaker_to_id,
+ src_lang_ids=src_lang_ids,
+ tgt_lang_ids=tgt_lang_ids,
+ domain_ids=domain_ids
+ )
+
+ @classmethod
+ def _load_samples_from_tsv(
+ cls,
+ root: str,
+ split: str,
+ src_lang_map,
+ tgt_lang_map,
+ domain_map
+ ):
+ # metadata from split
+ _, src_lang, tgt_lang, domain = split.split("_")
+ src_lang_id = src_lang_map[src_lang]
+ tgt_lang_id = tgt_lang_map[tgt_lang]
+ domain_id = domain_map[domain]
+
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
+ for s in samples:
+ s.update({
+ cls.KEY_SRC_LANG_ID: src_lang_id,
+ cls.KEY_TGT_LANG_ID: tgt_lang_id,
+ cls.KEY_DOMAIN_ID: domain_id
+ })
+ return samples
+
+ @classmethod
+ def _from_tsv(
+ cls,
+ root: str,
+ cfg: S2TDataConfig,
+ split: str,
+ tgt_dict,
+ is_train_split: bool,
+ pre_tokenizer,
+ bpe_tokenizer,
+ n_frames_per_step,
+ speaker_to_id,
+ src_lang_map: Dict[str, int],
+ tgt_lang_map: Dict[str, int],
+ domain_map: Dict[str, int]
+    ) -> SpeechToTextDatasetWithDomain:
+ samples = cls._load_samples_from_tsv(
+ root, split, src_lang_map,
+ tgt_lang_map, domain_map
+ )
+ return cls._from_list(
+ split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer,
+ bpe_tokenizer, n_frames_per_step, speaker_to_id
+ )
+
+ @classmethod
+ def from_tsv(
+ cls,
+ root: str,
+ cfg: S2TDataConfig,
+ splits: str,
+ tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ is_train_split: bool,
+ epoch: int,
+ seed: int,
+ src_lang_map: Dict[str, int],
+ tgt_lang_map: Dict[str, int],
+ domain_map: Dict[str, int],
+ n_frames_per_step: int = 1,
+ speaker_to_id=None
+ ) -> SpeechToTextDatasetWithDomain:
+ datasets = [
+ cls._from_tsv(
+ root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, src_lang_map, tgt_lang_map, domain_map
+ )
+ for split in splits.split(",")
+ ]
+
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
+ # temperature-based sampling
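+ # alpha < 1.0 flattens the sampling distribution, so smaller splits are upsampled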
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
+ datasets = [
+ ResamplingDataset(
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
+ )
+ for r, d in zip(size_ratios, datasets)
+ ]
+
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
diff --git a/examples/latent_depth/src/modules/__init__.py b/examples/attention_head_selection/src/loss/__init__.py
similarity index 100%
rename from examples/latent_depth/src/modules/__init__.py
rename to examples/attention_head_selection/src/loss/__init__.py
diff --git a/examples/attention_head_selection/src/loss/attention_head_selection.py b/examples/attention_head_selection/src/loss/attention_head_selection.py
new file mode 100644
index 0000000000..4ba33954d0
--- /dev/null
+++ b/examples/attention_head_selection/src/loss/attention_head_selection.py
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class HeadSelectionLoss(_Loss):
+
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.kl_weight = getattr(args, "kl_weight", 0.0)
+
+ def forward(self, head_samples, sample_sizes, prior=0.5, eps=1e-7):
+ """
+ head_samples: (num_tasks, num_layers, num_heads)
+ sample_sizes: (num_tasks, )
+ """
+ kl_loss = (head_samples * (torch.log(head_samples + eps) - math.log(prior))).sum(-1).sum(-1)
+ kl_loss /= (torch.numel(head_samples) / head_samples.size(0))
+ kl_loss = self.kl_weight * torch.matmul(kl_loss, sample_sizes)
+ return kl_loss
diff --git a/examples/linformer/src/models/__init__.py b/examples/attention_head_selection/src/models/__init__.py
similarity index 100%
rename from examples/linformer/src/models/__init__.py
rename to examples/attention_head_selection/src/models/__init__.py
diff --git a/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py b/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py
new file mode 100644
index 0000000000..2c7ed89e89
--- /dev/null
+++ b/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py
@@ -0,0 +1,170 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Dict, List, Optional
+from pathlib import Path
+import torch.nn as nn
+from torch import Tensor
+from fairseq import checkpoint_utils
+
+from fairseq.models import register_model, register_model_architecture
+from fairseq.utils import safe_hasattr
+from fairseq.models.speech_to_text.s2t_transformer import (
+ S2TTransformerModel,
+ S2TTransformerEncoder,
+ TransformerDecoderScriptable
+)
+from fairseq.models.speech_to_text.s2t_transformer import base_architecture as s2t_base_architecture
+
+from ..modules.attn_head_selector import AttnHeadSelector
+from ..modules.head_selection_transformer_layer import HeadSelectionTransformerEncoderLayer
+from .head_selection_transformer import HeadSelectionTransformerDecoder
+
+
+logger = logging.getLogger(__name__)
+
+
+@register_model("head_selection_s2t_transformer")
+class HeadSelectionS2TTransformerModel(S2TTransformerModel):
+ """
+ Head selection implemented in S2TTransformer
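+ Enable with --encoder-attn-head-select, --decoder-self-attn-head-select
+ and/or --dec-enc-attn-head-select (see add_args below for all options).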
+ """
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+
+ @staticmethod
+ def add_args(parser):
+ S2TTransformerModel.add_args(parser)
+ # encoder head selection
+ parser.add_argument(
+ "--encoder-attn-head-select",
+ action="store_true",
+ default=False,
+ help="encoder head selection"
+ )
+ parser.add_argument(
+ "--total-encoder-attention-heads",
+ type=int,
+ help="total number of encoder attention heads"
+ )
+ # decoder self attention selection
+ parser.add_argument(
+ "--decoder-self-attn-head-select",
+ action="store_true",
+ default=False,
+ help="decoder self-attention head selection"
+ )
+ # decoder-encoder attention selection
+ parser.add_argument(
+ "--dec-enc-attn-head-select",
+ action="store_true",
+ default=False,
+ help="decoder-encoder attention head selection"
+ )
+ parser.add_argument(
+ "--total-decoder-attention-heads",
+ type=int,
+ help="total number of decoder attention heads"
+ )
+ # selection strategy
+ parser.add_argument(
+ "--attn-head-select-strategy",
+ type=str,
+ help="attention head selection strategy, subset or group"
+ )
+
+ @classmethod
+ def build_encoder(cls, args):
+ if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
+ encoder = HeadSelectionS2TTransformerEncoder(args)
+ else:
+ encoder = S2TTransformerEncoder(args)
+ pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
+ if pretraining_path is not None:
+ if not Path(pretraining_path).exists():
+ logger.warning(
+ f"skipped pretraining because {pretraining_path} does not exist"
+ )
+ else:
+ encoder = checkpoint_utils.load_pretrained_component_from_model(
+ component=encoder, checkpoint=pretraining_path
+ )
+ logger.info(f"loaded pretrained encoder from: {pretraining_path}")
+ return encoder
+
+ @classmethod
+ def build_decoder(cls, args, task, embed_tokens):
+ if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
+ return HeadSelectionTransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
+ else:
+ return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
+
+
+class HeadSelectionS2TTransformerEncoder(S2TTransformerEncoder):
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.attn_head_selector = AttnHeadSelector(
+ args.encoder_tasks,
+ args.encoder_layers,
+ args.total_encoder_attention_heads,
+ args.encoder_attention_heads,
+ args.attn_head_select_strategy,
+ )
+ self.task_ids = None
+ self.transformer_layers = nn.ModuleList([
+ HeadSelectionTransformerEncoderLayer(args, layer_idx, attn_head_selector=self.attn_head_selector) for layer_idx in range(args.encoder_layers)
+ ])
+
+ def set_task_ids(self, task_ids):
+ self.task_ids = task_ids
+
+ def _forward(self, src_tokens, src_lengths, return_all_hiddens=False):
+ self.attn_head_selector.head_select(self.task_ids)
+ return super()._forward(src_tokens, src_lengths, return_all_hiddens)
+
+
+class HeadSelectionTransformerDecoderScriptable(HeadSelectionTransformerDecoder):
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ ):
+ # call scriptable method from parent class
+ x, _ = self.extract_features_scriptable(
+ prev_output_tokens,
+ encoder_out,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ )
+ return x, None
+
+
+@register_model_architecture(model_name="head_selection_s2t_transformer", arch_name="head_selection_s2t_transformer")
+def base_architecture(args):
+ s2t_base_architecture(args)
+ args.encoder_attn_head_select = getattr(args, "encoder_attn_head_select", False)
+ args.decoder_self_attn_head_select = getattr(args, "decoder_self_attn_head_select", False)
+ args.dec_enc_attn_head_select = getattr(args, "dec_enc_attn_head_select", False)
+ args.total_encoder_attention_heads = getattr(args, "total_encoder_attention_heads", 8)
+ args.total_decoder_attention_heads = getattr(args, "total_decoder_attention_heads", 8)
+ args.attn_head_select_strategy = getattr(args, "attn_head_select_strategy", "group")
+
+
+@register_model_architecture("head_selection_s2t_transformer", "head_selection_s2t_transformer_s")
+def head_selection_s2t_transformer_s(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.dropout = getattr(args, "dropout", 0.1)
+ base_architecture(args)
diff --git a/examples/attention_head_selection/src/models/head_selection_transformer.py b/examples/attention_head_selection/src/models/head_selection_transformer.py
new file mode 100644
index 0000000000..b9d595699d
--- /dev/null
+++ b/examples/attention_head_selection/src/models/head_selection_transformer.py
@@ -0,0 +1,215 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Dict, Optional
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from fairseq.utils import safe_hasattr
+from fairseq.models.transformer import (
+ TransformerModel,
+ TransformerEncoder,
+ TransformerDecoder
+)
+
+from ..modules.attn_head_selector import AttnHeadSelector
+from ..modules.head_selection_transformer_layer import (
+ HeadSelectionTransformerEncoderLayer,
+ HeadSelectionTransformerDecoderLayer
+)
+
+
+class HeadSelectionTransformerModel(TransformerModel):
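+ """
+ TransformerModel with latent attention-head selection in the encoder and/or
+ decoder; the actual selection logic lives in AttnHeadSelector.
+ """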
+ def __init__(self, args, encoder, decoder):
+ super().__init__(args, encoder, decoder)
+
+ @staticmethod
+ def add_args(parser):
+ TransformerModel.add_args(parser)
+ # encoder head selection
+ parser.add_argument(
+ "--encoder-attn-head-select",
+ action="store_true",
+ default=False,
+ help="encoder head selection"
+ )
+ parser.add_argument(
+ "--total-encoder-attention-heads",
+ type=int,
+ help="total number of encoder attention heads"
+ )
+ # decoder self attention
+ parser.add_argument(
+ "--decoder-self-attn-head-select",
+ action="store_true",
+ default=False,
+ help="decoder self-attention head selection"
+ )
+ # decoder-encoder attention
+ parser.add_argument(
+ "--dec-enc-attn-head-select",
+ action="store_true",
+ default=False,
+ help="decoder-encoder attention head selection"
+ )
+ parser.add_argument(
+ "--total-decoder-attention-heads",
+ type=int,
+ help="total number of decoder attention heads"
+ )
+ # selection strategy
+ parser.add_argument(
+ "--attn-head-select-strategy",
+ type=str,
+ help="attention head selection strategy, subset or group"
+ )
+
+ @classmethod
+ def build_encoder(cls, args, src_dict, embed_tokens):
+ if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
+ return HeadSelectionTransformerEncoder(
+ args, src_dict, embed_tokens
+ )
+ else:
+ return TransformerEncoder(args, src_dict, embed_tokens)
+
+ @classmethod
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
+ if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
+ return HeadSelectionTransformerDecoder(
+ args, tgt_dict, embed_tokens
+ )
+ else:
+ return TransformerDecoder(args, tgt_dict, embed_tokens)
+
+
+class HeadSelectionTransformerEncoder(TransformerEncoder):
+
+ def __init__(self, args, dictionary, embed_tokens):
+ self.num_tasks = args.encoder_tasks
+ self.num_layers = args.encoder_layers
+ self.total_num_heads = args.total_encoder_attention_heads
+ self.num_heads = args.encoder_attention_heads
+ self.select_strategy = args.attn_head_select_strategy
+
+ super().__init__(args, dictionary, embed_tokens)
+ self.attn_head_selector = AttnHeadSelector(
+ self.num_tasks,
+ self.num_layers,
+ self.total_num_heads,
+ self.num_heads,
+ self.select_strategy
+ )
+ self.task_ids = None
+ self.layers = nn.ModuleList(
+ [self.build_encoder_layer(args, i) for i in range(args.encoder_layers)]
+ )
+
+ def set_task_ids(self, task_ids):
+ self.task_ids = task_ids
+
+ def build_encoder_layer(self, args, layer_idx=None):
+ return HeadSelectionTransformerEncoderLayer(
+ args,
+ layer_idx,
+ attn_head_selector=self.attn_head_selector
+ )
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths: Optional[torch.Tensor] = None,
+ return_all_hiddens: bool = False,
+ token_embeddings: Optional[torch.Tensor] = None,
+ ):
+ self.attn_head_selector.head_select(self.task_ids)
+ return super().forward(src_tokens, src_lengths, return_all_hiddens, token_embeddings)
+
+
+class HeadSelectionTransformerDecoder(TransformerDecoder):
+
+ def __init__(
+ self,
+ args,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ output_projection=None,
+ ):
+ self.num_tasks = args.decoder_tasks
+ self.num_layers = args.decoder_layers
+ self.total_num_heads = args.total_decoder_attention_heads
+ self.num_heads = args.decoder_attention_heads
+ self.select_strategy = args.attn_head_select_strategy
+ super().__init__(
+ args, dictionary, embed_tokens,
+ no_encoder_attn=no_encoder_attn,
+ output_projection=output_projection
+ )
+ self.self_attn_head_selector = None
+ self.enc_attn_head_selector = None
+ if safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select:
+ self.self_attn_head_selector = AttnHeadSelector(
+ self.num_tasks,
+ self.num_layers,
+ self.total_num_heads,
+ self.num_heads,
+ self.select_strategy
+ )
+ if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select:
+ self.enc_attn_head_selector = AttnHeadSelector(
+ self.num_tasks,
+ self.num_layers,
+ self.total_num_heads,
+ self.num_heads,
+ self.select_strategy
+ )
+ self.task_ids = None
+ self.layers = nn.ModuleList(
+ [
+ self.build_head_selection_decoder_layer(args, no_encoder_attn, idx) for idx in range(args.decoder_layers)
+ ]
+ )
+
+ def set_task_ids(self, task_ids):
+ self.task_ids = task_ids
+
+ def build_head_selection_decoder_layer(self, args, no_encoder_attn=False, layer_idx=None):
+ return HeadSelectionTransformerDecoderLayer(
+ args,
+ layer_idx,
+ self.self_attn_head_selector,
+ self.enc_attn_head_selector,
+ no_encoder_attn=no_encoder_attn
+ )
+
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ features_only: bool = False,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ src_lengths: Optional[Any] = None,
+ return_all_hiddens: bool = False,
+ ):
+ if self.self_attn_head_selector is not None:
+ self.self_attn_head_selector.head_select(self.task_ids)
+ if self.enc_attn_head_selector is not None:
+ self.enc_attn_head_selector.head_select(self.task_ids)
+ return super().forward(
+ prev_output_tokens=prev_output_tokens,
+ encoder_out=encoder_out,
+ incremental_state=incremental_state,
+ features_only=features_only,
+ full_context_alignment=full_context_alignment,
+ alignment_layer=alignment_layer,
+ alignment_heads=alignment_heads,
+ src_lengths=src_lengths,
+ return_all_hiddens=return_all_hiddens
+ )
diff --git a/examples/linformer/src/modules/__init__.py b/examples/attention_head_selection/src/modules/__init__.py
similarity index 100%
rename from examples/linformer/src/modules/__init__.py
rename to examples/attention_head_selection/src/modules/__init__.py
diff --git a/examples/attention_head_selection/src/modules/attn_head_selector.py b/examples/attention_head_selection/src/modules/attn_head_selector.py
new file mode 100644
index 0000000000..346fc62308
--- /dev/null
+++ b/examples/attention_head_selection/src/modules/attn_head_selector.py
@@ -0,0 +1,81 @@
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import math
+
+
+class AttnHeadSelector(nn.Module):
+ """
+ Latent variable modeling of attention head selection
+ """
+ def __init__(
+ self, num_tasks, num_layers,
+ total_num_heads, num_heads,
+ select_strategy="group",
+ head_select_temp=5.0
+ ):
+ super(AttnHeadSelector, self).__init__()
+ self.num_tasks = num_tasks
+ self.num_layers = num_layers
+ self.total_num_heads = total_num_heads
+ self.num_heads = num_heads
+ self.select_strategy = select_strategy
+ self.temp = head_select_temp
+
+ self.head_logits = torch.nn.Parameter(
+ torch.Tensor(self.num_tasks, self.num_layers, total_num_heads),
+ requires_grad=True
+ )
+ nn.init.uniform_(
+ self.head_logits, a=math.log(0.01),
+ b=math.log(1.0)
+ )
+
+ def gumbel_sample(self, logits, tau=1.0):
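+ # relaxed binary sampling: the difference of two Gumbel noises is logistic noise,
+ # so sigmoid((logits + noise) / tau) is a differentiable surrogate for Bernoulli draws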
+ gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
+ gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
+ gumbels1 = (logits + gumbels1 - gumbels2) / tau
+ y_soft = gumbels1.sigmoid()
+ return y_soft
+
+ def subset_select(self, y_soft, topk, dim=-1):
+ top_values, top_inds = torch.topk(y_soft, k=topk, dim=dim)
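+ # straight-through trick: the forward value equals 1.0 exactly, while gradients
+ # still flow through top_values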
+ top_ret = 1.0 - top_values.detach() + top_values
+ return top_inds.detach(), top_ret
+
+ def group_select(self, y_soft, topk, dim=-1):
+ # top_values: (num_tasks, num_layers, topk)
+ top_values, top_inds = torch.max(
+ y_soft.view(self.num_tasks, self.num_layers, -1, topk), dim=2
+ )
+ top_inds = top_inds * topk + torch.arange(topk, device=top_inds.device).unsqueeze(0).unsqueeze(1)
+ top_ret = 1.0 - top_values.detach() + top_values
+ return top_inds.detach(), top_ret
+
+ def head_select(self, task_ids=None):
+ # gumbel_sample
+ self.head_samples = self.gumbel_sample(self.head_logits, tau=self.temp)
+ # head select
+ if self.select_strategy == "subset":
+ self.subset_heads, self.subset_weights = self.subset_select(
+ self.head_samples,
+ topk=self.num_heads,
+ )
+ elif self.select_strategy == "group":
+ self.subset_heads, self.subset_weights = self.group_select(
+ self.head_samples,
+ topk=self.num_heads,
+ )
+ else:
+ raise ValueError("{} is not supported".format(self.select_strategy))
+
+ self.batch_subset = self.subset_heads[task_ids, :, :]
+ self.batch_weights = self.subset_weights[task_ids, :, :]
+
+ def forward(self, layer_idx):
+ assert layer_idx is not None
+ batch_subset = self.batch_subset[:, layer_idx, :]
+ batch_weights = self.batch_weights[:, layer_idx, :]
+ return batch_subset, batch_weights
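+
+# Illustrative usage (shapes assumed, not part of the library code):
+#   selector = AttnHeadSelector(num_tasks=2, num_layers=6, total_num_heads=8, num_heads=4)
+#   selector.head_select(task_ids=torch.tensor([0, 1]))  # sample a head subset per task
+#   heads, weights = selector(layer_idx=0)  # (2, 4) head indices and straight-through weights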
diff --git a/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py b/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py
new file mode 100644
index 0000000000..c792143503
--- /dev/null
+++ b/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py
@@ -0,0 +1,92 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.utils import safe_getattr
+from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
+from ..modules.multihead_attention_selection import MultiheadAttentionSelection
+
+
+class HeadSelectionTransformerEncoderLayer(TransformerEncoderLayer):
+
+ def __init__(self, args, layer_idx, attn_head_selector=None):
+ super().__init__(args)
+ self.layer_idx = layer_idx
+ self.self_attn = self.build_self_attention_selection(
+ self.embed_dim, args, attn_head_selector
+ )
+
+ def build_self_attention_selection(self, embed_dim, args, attn_head_selector=None):
+ return MultiheadAttentionSelection(
+ embed_dim,
+ args.total_encoder_attention_heads,
+ args.encoder_attention_heads,
+ dropout=args.attention_dropout,
+ self_attention=True,
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ layer_idx=self.layer_idx,
+ attn_head_selector=attn_head_selector
+ )
+
+
+class HeadSelectionTransformerDecoderLayer(TransformerDecoderLayer):
+
+ def __init__(
+ self,
+ args,
+ layer_idx,
+ self_attn_head_selector=None,
+ enc_attn_head_selector=None,
+ no_encoder_attn=False,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ ):
+ self.layer_idx = layer_idx
+ super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
+ if self_attn_head_selector is not None:
+ self.self_attn = self.build_self_attention_selection(
+ self.embed_dim, args,
+ self_attn_head_selector=self_attn_head_selector,
+ add_bias_kv=add_bias_kv,
+ add_zero_attn=add_zero_attn
+ )
+ if enc_attn_head_selector is not None:
+ self.encoder_attn = self.build_encoder_attention_selection(
+ self.embed_dim, args,
+ enc_attn_head_selector=enc_attn_head_selector
+ )
+
+ def build_self_attention_selection(
+ self, embed_dim, args, self_attn_head_selector=None,
+ add_bias_kv=False, add_zero_attn=False
+ ):
+ return MultiheadAttentionSelection(
+ embed_dim,
+ args.total_decoder_attention_heads,
+ args.decoder_attention_heads,
+ dropout=args.attention_dropout,
+ add_bias_kv=add_bias_kv,
+ add_zero_attn=add_zero_attn,
+ self_attention=not safe_getattr(args, "cross_self_attention"),
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ layer_idx=self.layer_idx,
+ attn_head_selector=self_attn_head_selector,
+ )
+
+ def build_encoder_attention_selection(self, embed_dim, args, enc_attn_head_selector=None):
+ return MultiheadAttentionSelection(
+ embed_dim,
+ args.total_decoder_attention_heads,
+ args.decoder_attention_heads,
+ kdim=args.encoder_embed_dim,
+ vdim=args.encoder_embed_dim,
+ dropout=args.attention_dropout,
+ encoder_decoder_attention=True,
+ q_noise=self.quant_noise,
+ qn_block_size=self.quant_noise_block_size,
+ layer_idx=self.layer_idx,
+ attn_head_selector=enc_attn_head_selector,
+ )
diff --git a/examples/attention_head_selection/src/modules/multihead_attention_selection.py b/examples/attention_head_selection/src/modules/multihead_attention_selection.py
new file mode 100644
index 0000000000..566ad822ac
--- /dev/null
+++ b/examples/attention_head_selection/src/modules/multihead_attention_selection.py
@@ -0,0 +1,355 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, Optional, Tuple
+import torch
+from fairseq import utils
+from fairseq.modules.quant_noise import quant_noise
+from torch import Tensor, nn
+from torch.nn import Parameter
+
+from fairseq.modules.multihead_attention import MultiheadAttention
+from ..modules.multihead_functional import multi_head_attention_forward
+
+
+class MultiheadAttentionSelection(MultiheadAttention):
+
+ def __init__(
+ self,
+ embed_dim,
+ total_num_heads,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ q_noise=0.0,
+ qn_block_size=8,
+ layer_idx=0,
+ attn_head_selector=None
+ ):
+ super().__init__(
+ embed_dim,
+ num_heads,
+ kdim=kdim,
+ vdim=vdim,
+ dropout=dropout,
+ bias=bias,
+ add_bias_kv=add_bias_kv,
+ add_zero_attn=add_zero_attn,
+ self_attention=self_attention,
+ encoder_decoder_attention=encoder_decoder_attention,
+ q_noise=q_noise,
+ qn_block_size=qn_block_size,
+ )
+ self.layer_idx = layer_idx
+ self.attn_head_selector = attn_head_selector
+ self.total_num_heads = total_num_heads
+ self.total_embed_dim = self.head_dim * total_num_heads
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(embed_dim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+ self.reset_parameters()
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ # subset_heads: Optional[Tensor] = None,
+ # subset_weights: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ if need_head_weights:
+ need_weights = True
+
+ is_tpu = query.device.type == "xla"
+
+ subset_heads, subset_weights = self.attn_head_selector(self.layer_idx)
+
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ assert list(query.size()) == [tgt_len, bsz, self.embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert key_bsz == bsz
+ assert value is not None
+ assert (src_len, bsz) == value.shape[:2]
+
+ if (
+ not self.onnx_trace
+ and not is_tpu # don't use PyTorch version on TPUs
+ and incremental_state is None
+ and not static_kv
+ # A workaround for quantization to work. Otherwise JIT compilation
+ # treats bias in linear module as method.
+ and not torch.jit.is_scripting()
+ ):
+ assert key is not None and value is not None
+ return multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.total_num_heads,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training or self.dropout_module.apply_during_inference,
+ key_padding_mask,
+ need_weights,
+ attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ subset_heads=subset_heads,
+ subset_weights=subset_weights
+ )
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+ ],
+ dim=1,
+ )
+
+ q = (
+ q.contiguous()
+ .view(tgt_len, bsz * self.total_num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if k is not None:
+ k = (
+ k.contiguous()
+ .view(-1, bsz * self.total_num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if v is not None:
+ v = (
+ v.contiguous()
+ .view(-1, bsz * self.total_num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ prev_key = _prev_key.view(bsz * self.total_num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ prev_value = _prev_value.view(bsz * self.total_num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(bsz, self.total_num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(bsz, self.total_num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.size(1) == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
+ key_padding_mask
+ ),
+ ],
+ dim=1,
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.total_num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ if self.onnx_trace:
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.total_num_heads, tgt_len, src_len)
+ if not is_tpu:
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+ else:
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.view(bsz * self.total_num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = utils.softmax(
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
+ )
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+
+ # evaluation
+ if subset_heads is not None and subset_heads.numel() == 1:
+ subset_heads = subset_heads.repeat(bsz)
+ subset_weights = subset_weights.repeat(bsz)
+
+ if subset_heads is None:
+ attn = torch.bmm(attn_probs, v)
+ else:
+ # training with head selection
+ mixed_attn = torch.bmm(attn_probs, v).contiguous().view(bsz, self.total_num_heads, tgt_len, self.head_dim)
+ attn = torch.stack(
+ [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
+ )
+ attn = attn * subset_weights.unsqueeze(2).unsqueeze(3)
+ attn = attn.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)
+
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ if self.onnx_trace and attn.size(1) == 1:
+ # when ONNX tracing a single decoder step (sequence length == 1)
+ # the transpose is a no-op copy before view, thus unnecessary
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
+ else:
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ if subset_heads is None:
+ attn_weights = attn_weights_float.view(
+ bsz, self.num_heads, tgt_len, src_len
+ ).transpose(1, 0)
+ else:
+ mixed_attn_weights = attn_weights_float.view(
+ bsz, self.total_num_heads, tgt_len, src_len
+ )
+ attn_weights = torch.stack(
+ [mixed_attn_weights[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
+ ).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights
diff --git a/examples/attention_head_selection/src/modules/multihead_functional.py b/examples/attention_head_selection/src/modules/multihead_functional.py
new file mode 100644
index 0000000000..d5edc777e3
--- /dev/null
+++ b/examples/attention_head_selection/src/modules/multihead_functional.py
@@ -0,0 +1,278 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+import torch
+from torch import Tensor
+from torch.nn.functional import (
+ linear, softmax, dropout, pad,
+ has_torch_function,
+ handle_torch_function,
+ _in_projection_packed,
+)
+import math
+import warnings
+
+
+def _scaled_dot_product_attention(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ attn_mask: Optional[Tensor] = None,
+ dropout_p: float = 0.0,
+ bsz: int = 1,
+ subset_heads: Optional[Tensor] = None,
+ subset_weights: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+ B, Nt, E = q.shape
+ q = q / math.sqrt(E)
+ # B: bsz * total_num_heads
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
+ attn = torch.bmm(q, k.transpose(-2, -1))
+ if attn_mask is not None:
+ attn += attn_mask
+ attn = softmax(attn, dim=-1)
+ if dropout_p > 0.0:
+ attn = dropout(attn, p=dropout_p)
+ if subset_heads is None:
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
+ output = torch.bmm(attn, v)
+ else:
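+ # gather the outputs of the selected heads for each batch element, then rescale
+ # them with the straight-through selection weights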
+ mixed_output = torch.bmm(attn, v).contiguous().view(bsz, -1, Nt, E)
+ output = torch.stack(
+ [mixed_output[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))],
+ dim=1
+ )
+ output = output * subset_weights.unsqueeze(2).unsqueeze(3)
+ output = output.contiguous().view(-1, Nt, E)
+ if subset_heads is not None:
+ _, Nt, Ns = attn.size()
+ mixed_attn = attn.view(bsz, -1, Nt, Ns)
+ attn = torch.stack(
+ [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
+ )
+ return output, attn
+
+
+def _in_projection(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w_q: Tensor,
+ w_k: Tensor,
+ w_v: Tensor,
+ b_q: Optional[Tensor] = None,
+ b_k: Optional[Tensor] = None,
+ b_v: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
+def multi_head_attention_forward(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ total_num_heads: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Optional[Tensor],
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Optional[Tensor],
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+ subset_heads: Optional[Tensor] = None,
+ subset_weights: Optional[Tensor] = None,
+):
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+ if has_torch_function(tens_ops):
+ return handle_torch_function(
+ multi_head_attention_forward,
+ tens_ops,
+ query,
+ key,
+ value,
+ embed_dim_to_check,
+ total_num_heads,
+ num_heads,
+ in_proj_weight,
+ in_proj_bias,
+ bias_k,
+ bias_v,
+ add_zero_attn,
+ dropout_p,
+ out_proj_weight,
+ out_proj_bias,
+ training=training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=use_separate_proj_weight,
+ q_proj_weight=q_proj_weight,
+ k_proj_weight=k_proj_weight,
+ v_proj_weight=v_proj_weight,
+ static_k=static_k,
+ static_v=static_v,
+ subset_heads=subset_heads,
+ subset_weights=subset_weights
+ )
+
+ # set up shape vars
+ tgt_len, bsz, embed_dim = query.shape
+ src_len, _, _ = key.shape
+ assert embed_dim == embed_dim_to_check, \
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ if isinstance(embed_dim, torch.Tensor):
+ # embed_dim can be a tensor when JIT tracing
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+ else:
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+ if use_separate_proj_weight:
+ # allow MHA to have different embedding dimensions when separate projection weights are used
+ assert key.shape[:2] == value.shape[:2], \
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ else:
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
+
+ #
+ # compute in-projection
+ #
+ if not use_separate_proj_weight:
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
+ else:
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
+ if in_proj_bias is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
+
+ # prep attention mask
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+ attn_mask = attn_mask.to(torch.bool)
+ else:
+ assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
+ f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
+ # ensure attn_mask's dim is 3
+ if attn_mask.dim() == 2:
+ correct_2d_size = (tgt_len, src_len)
+ if attn_mask.shape != correct_2d_size:
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+ attn_mask = attn_mask.unsqueeze(0)
+ elif attn_mask.dim() == 3:
+ correct_3d_size = (bsz * total_num_heads, tgt_len, src_len)
+ if attn_mask.shape != correct_3d_size:
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+ else:
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+
+ # prep key padding mask
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ # add bias along batch dimension (currently second)
+ if bias_k is not None and bias_v is not None:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ #
+ # reshape q, k, v for multihead attention and make em batch first
+ #
+ q = q.contiguous().view(tgt_len, bsz * total_num_heads, head_dim).transpose(0, 1)
+ if static_k is None:
+ k = k.contiguous().view(k.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_k.size(0) == bsz * total_num_heads, \
+ f"expecting static_k.size(0) of {bsz * total_num_heads}, but got {static_k.size(0)}"
+ assert static_k.size(2) == head_dim, \
+ f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+ k = static_k
+ if static_v is None:
+ v = v.contiguous().view(v.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_v.size(0) == bsz * total_num_heads, \
+ f"expecting static_v.size(0) of {bsz * total_num_heads}, but got {static_v.size(0)}"
+ assert static_v.size(2) == head_dim, \
+ f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+ v = static_v
+
+ # add zero attention along batch dimension (now first)
+ if add_zero_attn:
+ zero_attn_shape = (bsz * total_num_heads, 1, head_dim)
+ k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
+ v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+
+ # update source sequence length after adjustments
+ src_len = k.size(1)
+
+ # merge key padding and attention masks
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (bsz, src_len), \
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
+ expand(-1, total_num_heads, -1, -1).reshape(bsz * total_num_heads, 1, src_len)
+ if attn_mask is None:
+ attn_mask = key_padding_mask
+ elif attn_mask.dtype == torch.bool:
+ attn_mask = attn_mask.logical_or(key_padding_mask)
+ else:
+ attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
+
+ # convert mask to float
+ if attn_mask is not None and attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ attn_mask = new_attn_mask
+
+ # adjust dropout probability
+ if not training:
+ dropout_p = 0.0
+
+ #
+ # (deep breath) calculate attention and out projection
+ #
+ attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, bsz, subset_heads, subset_weights)
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
diff --git a/examples/attention_head_selection/src/speech_to_text_head_selection.py b/examples/attention_head_selection/src/speech_to_text_head_selection.py
new file mode 100644
index 0000000000..6e0ce11d63
--- /dev/null
+++ b/examples/attention_head_selection/src/speech_to_text_head_selection.py
@@ -0,0 +1,180 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from fairseq.optim.amp_optimizer import AMPOptimizer
+from fairseq.tasks import register_task
+from fairseq.tasks.speech_to_text import SpeechToTextTask
+
+from .data.speech_to_text_dataset_with_domain import SpeechToTextDatasetCreatorWithDomain
+from .loss.attention_head_selection import HeadSelectionLoss
+
+
+@register_task("speech_to_text_head_selection")
+class SpeechToTextHeadSelectionTask(SpeechToTextTask):
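+ """
+ SpeechToTextTask variant that conditions attention-head selection on the
+ language pair or the domain of each sample and adds a KL regularizer
+ (HeadSelectionLoss) on the sampled head distributions during training.
+ """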
+
+ @classmethod
+ def add_args(cls, parser):
+ SpeechToTextTask.add_args(parser)
+ parser.add_argument(
+ "--task-type",
+ type=str,
+ default="lang",
+ help="task type for head selection, lang or domain"
+ )
+ parser.add_argument(
+ "--kl-weight",
+ type=float,
+ default=0.0,
+ help="the weight of KL loss"
+ )
+
+ def __init__(self, args, tgt_dict):
+ super().__init__(args, tgt_dict)
+ self.task_type = args.task_type
+ assert self.task_type in ["lang", "domain"], "invalid task_type: {}, should be either lang or domain".format(self.task_type)
+ self.map_task_to_id(args.train_subset)
+ self.encoder_head_prior = float(args.encoder_attention_heads) / args.total_encoder_attention_heads
+ self.decoder_head_prior = float(args.decoder_attention_heads) / args.total_decoder_attention_heads
+ self.kl_loss = HeadSelectionLoss(args)
+
+ def map_task_to_id(self, train_subset):
+ src_lang_set, tgt_lang_set, domain_set = set(), set(), set()
+ for split in train_subset.split(","):
+ seq = split.split("_")
+ assert len(seq) == 4, "subset {} should be in the format of train_src_tgt_domain".format(split)
+ _, src_lang, tgt_lang, domain = seq
+ src_lang_set.add(src_lang)
+ tgt_lang_set.add(tgt_lang)
+ domain_set.add(domain)
+ src_langs = sorted(src_lang_set)
+ tgt_langs = sorted(tgt_lang_set)
+ domains = sorted(domain_set)
+ self.src_lang_map = {src_lang: i for (i, src_lang) in enumerate(src_langs)}
+ self.tgt_lang_map = {tgt_lang: i for (i, tgt_lang) in enumerate(tgt_langs)}
+ self.domain_map = {domain: i for (i, domain) in enumerate(domains)}
+ if self.task_type == "lang":
+ self.encoder_tasks = len(self.src_lang_map)
+ self.decoder_tasks = len(self.tgt_lang_map)
+ elif self.task_type == "domain":
+ self.encoder_tasks = len(self.domain_map)
+ self.decoder_tasks = len(self.domain_map)
+
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+ is_train_split = split.startswith("train")
+ pre_tokenizer = self.build_tokenizer(self.args)
+ bpe_tokenizer = self.build_bpe(self.args)
+ self.datasets[split] = SpeechToTextDatasetCreatorWithDomain.from_tsv(
+ self.args.data,
+ self.data_cfg,
+ split,
+ self.tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ is_train_split=is_train_split,
+ epoch=epoch,
+ seed=self.args.seed,
+ src_lang_map=self.src_lang_map,
+ tgt_lang_map=self.tgt_lang_map,
+ domain_map=self.domain_map,
+ speaker_to_id=self.speaker_to_id
+ )
+
+ def build_model(self, args):
+ args.encoder_tasks = self.encoder_tasks
+ args.decoder_tasks = self.decoder_tasks
+ return super(SpeechToTextHeadSelectionTask, self).build_model(args)
+
+ def get_sample_sizes(self, sample, task_ids, num_tasks):
+ """
+ task_ids: (bsz,)
+ get sample sizes for each task
+ """
+ bsz = task_ids.size(0)
+ mat = torch.zeros((num_tasks, bsz), device=task_ids.device)
+ mat[task_ids, torch.arange(bsz)] = 1.0
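+ # count non-padding target tokens per sentence (1 is fairseq's default padding index)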
+ ntokens = torch.sum(sample['target'] != 1, dim=-1)
+ sample_sizes = torch.matmul(mat, ntokens.float())
+ return sample_sizes
+
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
+ model.train()
+ model.set_num_updates(update_num)
+ # task ids
+ if self.task_type == "lang":
+ encoder_task_ids = sample["src_lang_ids"]
+ decoder_task_ids = sample["tgt_lang_ids"]
+ elif self.task_type == "domain":
+ encoder_task_ids = sample["domain_ids"]
+ decoder_task_ids = sample["domain_ids"]
+ model.encoder.set_task_ids(encoder_task_ids)
+ model.decoder.set_task_ids(decoder_task_ids)
+
+ with torch.autograd.profiler.record_function("forward"):
+ with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
+ loss, sample_size, logging_output = criterion(model, sample)
+ # KL loss
+ if self.args.encoder_attn_head_select:
+ sample_sizes = self.get_sample_sizes(sample, encoder_task_ids, self.encoder_tasks)
+ loss += self.kl_loss(
+ model.encoder.attn_head_selector.head_samples,
+ sample_sizes,
+ self.encoder_head_prior
+ )
+ if self.args.decoder_self_attn_head_select:
+ sample_sizes = self.get_sample_sizes(sample, decoder_task_ids, self.decoder_tasks)
+ loss += self.kl_loss(
+ model.decoder.self_attn_head_selector.head_samples,
+ sample_sizes,
+ self.decoder_head_prior
+ )
+ if self.args.dec_enc_attn_head_select:
+ sample_sizes = self.get_sample_sizes(sample, decoder_task_ids, self.decoder_tasks)
+ loss += self.kl_loss(
+ model.decoder.enc_attn_head_selector.head_samples,
+ sample_sizes,
+ self.decoder_head_prior
+ )
+
+ if ignore_grad:
+ loss *= 0
+ with torch.autograd.profiler.record_function("backward"):
+ optimizer.backward(loss)
+ return loss, sample_size, logging_output
+
+ def valid_step(self, sample, model, criterion):
+ model.eval()
+ # task ids
+ if self.task_type == "lang":
+ encoder_task_ids = sample["src_lang_ids"]
+ decoder_task_ids = sample["tgt_lang_ids"]
+ elif self.task_type == "domain":
+ encoder_task_ids = sample["domain_ids"]
+ decoder_task_ids = sample["domain_ids"]
+ model.encoder.set_task_ids(encoder_task_ids)
+ model.decoder.set_task_ids(decoder_task_ids)
+ with torch.no_grad():
+ loss, sample_size, logging_output = criterion(model, sample)
+ return loss, sample_size, logging_output
+
+ def inference_step(
+ self, generator, models, sample, prefix_tokens=None, constraints=None
+ ):
+ with torch.no_grad():
+ # task ids
+ if self.task_type == "lang":
+ encoder_task_ids = sample["src_lang_ids"][:1]
+ decoder_task_ids = sample["tgt_lang_ids"][:1]
+ elif self.task_type == "domain":
+ encoder_task_ids = sample["domain_ids"][:1]
+ decoder_task_ids = sample["domain_ids"][:1]
+ for model in models:
+ model.encoder.set_task_ids(encoder_task_ids)
+ model.decoder.set_task_ids(decoder_task_ids)
+ return generator.generate(
+ models, sample, prefix_tokens=prefix_tokens, constraints=constraints
+ )
diff --git a/examples/audio_nlp/nlu/README.md b/examples/audio_nlp/nlu/README.md
new file mode 100644
index 0000000000..a11b3f3065
--- /dev/null
+++ b/examples/audio_nlp/nlu/README.md
@@ -0,0 +1,53 @@
+# End-to-end NLU
+
+End-to-end spoken language understanding (SLU) predicts intent directly from audio using a single model. It promises to improve the performance of assistant systems by leveraging acoustic information lost in the intermediate textual representation and preventing cascading errors from Automatic Speech Recognition (ASR). Further, having one unified model has efficiency advantages when deploying assistant systems on-device.
+
+This page releases the code for reproducing the results in [STOP: A dataset for Spoken Task Oriented Semantic Parsing](https://arxiv.org/abs/2207.10643).
+
+The dataset can be downloaded here: [download link](https://dl.fbaipublicfiles.com/stop/stop.tar.gz)
+
+The low-resource splits can be downloaded here: [download link](http://dl.fbaipublicfiles.com/stop/low_resource_splits.tar.gz)
+
+## Pretrained End-to-end NLU Models
+
+| Speech Pretraining | ASR Pretraining | Test EM Accuracy | Test EM-Tree Accuracy | Link |
+| ----------- | ----------- |----------|----------|----------|
+| None | None | 36.54 | 57.01 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-none-none.pt) |
+| Wav2Vec | None | 68.05 | 82.53 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-none.pt) |
+| HuBERT | None | 68.40 | 82.85 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-none.pt) |
+| Wav2Vec | STOP | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-stop.pt) |
+| HuBERT | STOP | 69.23 | 82.87 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-stop.pt) |
+| Wav2Vec | Librispeech | 68.47 | 82.49 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-ls.pt) |
+| HuBERT | Librispeech | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-ls.pt) |
+
+## Pretrained ASR Models
+| Speech Pre-training | ASR Dataset | STOP Eval WER | STOP Test WER | dev\_other WER | dev\_clean WER | test\_clean WER | test\_other WER | Link |
+| ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- |
+| HuBERT | Librispeech | 8.47 | 2.99 | 3.25 | 8.06 | 25.68 | 26.19 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls.pt) |
+| Wav2Vec | Librispeech | 9.215 | 3.204 | 3.334 | 9.006 | 27.257 | 27.588 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls.pt) |
+| HuBERT | STOP | 46.31 | 31.30 | 31.52 | 47.16 | 4.29 | 4.26 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-stop.pt) |
+| Wav2Vec | STOP | 43.103 | 27.833 | 28.479 | 28.479 | 4.679 | 4.667 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-stop.pt) |
+| HuBERT | Librispeech + STOP | 9.015 | 3.211 | 3.372 | 8.635 | 5.133 | 5.056 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls-stop.pt) |
+| Wav2Vec | Librispeech + STOP | 9.549 | 3.537 | 3.625 | 9.514 | 5.59 | 5.562 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls-stop.pt) |
+
+## Creating the fairseq datasets from STOP
+
+First, create the audio file manifests and label files:
+
+```
+python examples/audio_nlp/nlu/generate_manifests.py --stop_root $STOP_DOWNLOAD_DIR/stop --output $FAIRSEQ_DATASET_OUTPUT/
+```
+
+
+Run `./examples/audio_nlp/nlu/create_dict_stop.sh $FAIRSEQ_DATASET_OUTPUT` to generate the fairseq dictionaries.
+
+
+## Training an End-to-end NLU Model
+
+
+Download a pretrained wav2vec 2.0 or HuBERT model from [wav2vec 2.0](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec) or [HuBERT](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
+
+
+```
+python fairseq_cli/hydra_train.py --config-dir examples/audio_nlp/nlu/configs/ --config-name nlu_finetuning task.data=$FAIRSEQ_DATASET_OUTPUT model.w2v_path=$PRETRAINED_MODEL_PATH
+```
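+
+## Loading a trained model in Python
+
+The sketch below (illustrative only; the checkpoint and data paths are placeholders) loads one of the released end-to-end NLU checkpoints with fairseq's `checkpoint_utils`. It assumes a fairseq installation that provides the `nlu_finetuning` task and a data directory containing the dictionaries generated above.
+
+```
+from fairseq import checkpoint_utils
+
+# placeholder paths: point them at your downloaded checkpoint and $FAIRSEQ_DATASET_OUTPUT
+models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+ ["/path/to/end-to-end-nlu-hubert-stop.pt"],
+ arg_overrides={"data": "/path/to/fairseq_dataset_output"},
+)
+model = models[0].eval()  # ready for use with the task's sequence generator
+```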
diff --git a/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml b/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml
new file mode 100644
index 0000000000..bb90f45a30
--- /dev/null
+++ b/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 10
+ tensorboard_logdir: tb
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: em_error
+ save_interval: 10
+
+task:
+ _name: nlu_finetuning
+ data: ???
+ labels: parse
+ eval_wer_parse: true
+ autoregressive: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1600000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: eval,test
+ train_subset: train
+ validate_interval: 10
+
+criterion:
+ _name: label_smoothed_cross_entropy
+
+optimization:
+ max_update: 320000
+ lr: [0.0001]
+ sentence_avg: true
+ update_freq: [1]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_seq2seq
+ w2v_path: ???
+ autoregressive: true
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
diff --git a/examples/audio_nlp/nlu/create_dict_stop.sh b/examples/audio_nlp/nlu/create_dict_stop.sh
new file mode 100755
index 0000000000..753393284d
--- /dev/null
+++ b/examples/audio_nlp/nlu/create_dict_stop.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+### Script handling creation of data binaries
+### for model training within fairseq
+
+
+fairseq_root="."
+
+data_root=$1
+train_prefix="${data_root}/train"
+valid_prefix="${data_root}/eval"
+test_prefix="${data_root}/test"
+
+dest_dir="$data_root/"
+
+#echo "src dict: $src_dict" > "$dest_dir/src_dict.txt"
+#echo "trg dict: $tgt_dict" > "$dest_dir/tgt_dict.txt"
+
+ #--tgtdict $tgt_dict \
+PYTHONPATH=$fairseq_root \
+ python $fairseq_root/fairseq_cli/preprocess.py \
+ --source-lang "parse" \
+ --trainpref "$train_prefix" \
+ --validpref "$valid_prefix" \
+ --destdir "$dest_dir" \
+ --only-source \
+ --dict-only \
+ --workers 60;
+
+PYTHONPATH=$fairseq_root \
+ python $fairseq_root/fairseq_cli/preprocess.py \
+ --source-lang "ltr" \
+ --trainpref "$train_prefix" \
+ --validpref "$valid_prefix" \
+ --destdir "$dest_dir" \
+ --only-source \
+ --dict-only \
+ --workers 60;
diff --git a/examples/audio_nlp/nlu/generate_manifests.py b/examples/audio_nlp/nlu/generate_manifests.py
new file mode 100644
index 0000000000..e2176099cb
--- /dev/null
+++ b/examples/audio_nlp/nlu/generate_manifests.py
@@ -0,0 +1,83 @@
+import argparse
+from pathlib import Path
+import soundfile
+
+def get_insl_frame(parse):
+    def is_ont_token(tok):
+        return tok[0] in ["[", "]"]
+
+ res = []
+ x = []
+ for tok in parse.split():
+ if is_ont_token(tok):
+ res.extend('_'.join(x))
+ x = []
+ res.append(tok.upper())
+ else:
+ x.append(tok.upper())
+
+ return " ".join(res) + ' | '
+
+def sequencify_utterance(utterance):
+ utterance = utterance.upper()
+ utterance = utterance.replace(' ', '|') + '|'
+ utterance = list(utterance)
+ utterance = ' '.join(utterance)
+ return utterance
+
+
+def generate_fairseq_manifests(manifest, output_path, audio_root=None):
+
+ with open(manifest, 'r') as i:
+ parses = []
+ utterances = []
+ filepaths = []
+ keys = None
+ for (idx, line) in enumerate(i):
+ if idx == 0: keys = line.strip().split('\t')
+ else:
+ data = { k: v for (k, v) in zip(keys, line.split('\t'))}
+ parses.append(get_insl_frame(data['decoupled_normalized_seqlogical']))
+ utterances.append(sequencify_utterance(data['normalized_utterance']))
+ filepaths.append(data['file_id'])
+
+ parses_fp = output_path.with_suffix('.parse')
+ with open(str(parses_fp), 'w') as o:
+ for p in parses:
+ o.write(p + '\n')
+
+ utterances_fp = output_path.with_suffix('.ltr')
+ with open(str(utterances_fp), 'w') as o:
+ for u in utterances:
+ o.write(u + '\n')
+
+ filepaths_fp = output_path.with_suffix('.tsv')
+ with open(str(filepaths_fp), 'w') as o:
+ o.write(str(audio_root) + '\n')
+ for f in filepaths:
+ fullpath = audio_root / f
+ assert fullpath.exists(), f'{fullpath}'
+ frames = soundfile.info(fullpath).frames
+ o.write(f'{f}\t{frames}\n')
+
+def main(args):
+
+ splits = ['train', 'eval', 'test']
+ root = Path(args.stop_root)
+ output_root = Path(args.output)
+
+ for split in splits:
+ stop_manifest_path = root / 'manifests' / (split + '.tsv')
+ output_path = output_root / (split)
+
+ generate_fairseq_manifests(stop_manifest_path, output_path, root)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Generate fairseq manifests and label files from the STOP dataset.')
+ parser.add_argument('--stop_root', type=str,
+ help='path to stop root directory')
+ parser.add_argument('--output', type=str,
+ help='output directory')
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/bart/README.md b/examples/bart/README.md
index 76857a99a2..4050a724ee 100644
--- a/examples/bart/README.md
+++ b/examples/bart/README.md
@@ -1,6 +1,6 @@
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
-[https://arxiv.org/pdf/1910.13461.pdf]
+[https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461)
## Introduction
@@ -100,7 +100,7 @@ bart.predict('mnli', tokens).argmax() # 2: entailment
##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
-logprobs = bart.predict('new_task', tokens)
+logprobs = bart.predict('new_task', tokens)
```
##### Batched prediction:
@@ -137,15 +137,23 @@ BART can be used to fill multiple `` tokens in the input.
```python
bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()
-bart.fill_mask('The cat on the .', topk=3, beam=10)
-# [('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]
+bart.fill_mask(['The cat on the .'], topk=3, beam=10)
+# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
```
Note that by default we enforce the output length to match the input length.
This can be disabled by setting ``match_source_len=False``:
```
-bart.fill_mask('The cat on the .', topk=3, beam=10, match_source_len=False)
-# [('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]
+bart.fill_mask(['The cat on the .'], topk=3, beam=10, match_source_len=False)
+# [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
+```
+
+Example code to fill masks for a batch of sentences using a GPU:
+```
+bart.cuda()
+bart.fill_mask(['The cat on the .', 'The dog on the .'], topk=3, beam=10)
+# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
+('The dog was asleep on the couch', tensor(-0.6796))]]
```
#### Evaluating the `bart.large.mnli` model:
@@ -171,38 +179,23 @@ with open('glue_data/MNLI/dev_matched.tsv') as fin:
```
#### Evaluating the `bart.large.cnn` model:
-Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample.
+- Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` have one line for each non-tokenized sample.
+- For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores.
+- `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
+ In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.
-```python
-bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
-bart.cuda()
-bart.eval()
-bart.half()
-count = 1
-bsz = 32
-with open('test.source') as source, open('test.hypo', 'w') as fout:
- sline = source.readline().strip()
- slines = [sline]
- for sline in source:
- if count % bsz == 0:
- with torch.no_grad():
- hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
-
- for hypothesis in hypotheses_batch:
- fout.write(hypothesis + '\n')
- fout.flush()
- slines = []
-
- slines.append(sline.strip())
- count += 1
- if slines != []:
- hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
- for hypothesis in hypotheses_batch:
- fout.write(hypothesis + '\n')
- fout.flush()
-```
-
-Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
+In `fairseq`, summaries can be generated using:
+
+```bash
+cp data-bin/cnn_dm/dict.source.txt checkpoints/
+python examples/bart/summarize.py \
+ --model-dir pytorch/fairseq \
+ --model-file bart.large.cnn \
+ --src cnn_dm/test.source \
+ --out cnn_dm/test.hypo
+```
+
+For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
```bash
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
diff --git a/examples/bart/README.summarization.md b/examples/bart/README.summarization.md
index d7fecc9ce6..8727584f2b 100644
--- a/examples/bart/README.summarization.md
+++ b/examples/bart/README.summarization.md
@@ -80,42 +80,23 @@ Expected training time is about `5 hours`. Training time can be reduced with dis
Use TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2 for Xsum task
### Inference for CNN-DM test data using above trained checkpoint.
-After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
+After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using `summarize.py`, for example:
-```python
-import torch
-from fairseq.models.bart import BARTModel
-
-bart = BARTModel.from_pretrained(
- 'checkpoints/',
- checkpoint_file='checkpoint_best.pt',
- data_name_or_path='cnn_dm-bin'
-)
-
-bart.cuda()
-bart.eval()
-bart.half()
-count = 1
-bsz = 32
-with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
- sline = source.readline().strip()
- slines = [sline]
- for sline in source:
- if count % bsz == 0:
- with torch.no_grad():
- hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
-
- for hypothesis in hypotheses_batch:
- fout.write(hypothesis + '\n')
- fout.flush()
- slines = []
-
- slines.append(sline.strip())
- count += 1
- if slines != []:
- hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
- for hypothesis in hypotheses_batch:
- fout.write(hypothesis + '\n')
- fout.flush()
+```bash
+cp data-bin/cnn_dm/dict.source.txt checkpoints/
+python examples/bart/summarize.py \
+ --model-dir checkpoints \
+ --model-file checkpoint_best.pt \
+ --src cnn_dm/test.source \
+ --out cnn_dm/test.hypo
+```
+For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10:
+```bash
+cp data-bin/cnn_dm/dict.source.txt checkpoints/
+python examples/bart/summarize.py \
+ --model-dir checkpoints \
+ --model-file checkpoint_best.pt \
+ --src cnn_dm/test.source \
+ --out cnn_dm/test.hypo \
+ --xsum-kwargs
```
-Use beam=6, lenpen=1.0, max_len_b=60, min_len=10 for Xsum Generation
diff --git a/examples/bart/summarize.py b/examples/bart/summarize.py
new file mode 100644
index 0000000000..04435f80e3
--- /dev/null
+++ b/examples/bart/summarize.py
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from fairseq.models.bart import BARTModel
+import argparse
+
+XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
+CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
+
+
+@torch.no_grad()
+def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
+ count = 1
+
+ # if n_obs is not None: bsz = min(bsz, n_obs)
+
+ with open(infile) as source, open(outfile, "w") as fout:
+ sline = source.readline().strip()
+ slines = [sline]
+ for sline in source:
+ if n_obs is not None and count > n_obs:
+ break
+ if count % bsz == 0:
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
+ for hypothesis in hypotheses_batch:
+ fout.write(hypothesis + "\n")
+ fout.flush()
+ slines = []
+
+ slines.append(sline.strip())
+ count += 1
+
+ if slines != []:
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
+ for hypothesis in hypotheses_batch:
+ fout.write(hypothesis + "\n")
+ fout.flush()
+
+
+def main():
+ """
+ Usage::
+
+ python examples/bart/summarize.py \
+ --model-dir $HOME/bart.large.cnn \
+ --model-file model.pt \
+ --src $HOME/data-bin/cnn_dm/test.source
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model-dir",
+ required=True,
+ type=str,
+ default="bart.large.cnn/",
+ help="path containing model file and src_dict.txt",
+ )
+ parser.add_argument(
+ "--model-file",
+ default="checkpoint_best.pt",
+ help="where in model_dir are weights saved",
+ )
+ parser.add_argument(
+ "--src", default="test.source", help="text to summarize", type=str
+ )
+ parser.add_argument(
+ "--out", default="test.hypo", help="where to save summaries", type=str
+ )
+ parser.add_argument("--bsz", default=32, help="where to save summaries", type=int)
+ parser.add_argument(
+ "--n", default=None, help="how many examples to summarize", type=int
+ )
+ parser.add_argument(
+ "--xsum-kwargs",
+ action="store_true",
+ default=False,
+ help="if true use XSUM_KWARGS else CNN_KWARGS",
+ )
+ args = parser.parse_args()
+ eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
+ if args.model_dir == "pytorch/fairseq":
+ bart = torch.hub.load("pytorch/fairseq", args.model_file)
+ else:
+ bart = BARTModel.from_pretrained(
+ args.model_dir,
+ checkpoint_file=args.model_file,
+ data_name_or_path=args.model_dir,
+ )
+ bart = bart.eval()
+ if torch.cuda.is_available():
+ bart = bart.cuda().half()
+ generate(
+ bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/constrained_decoding/README.md b/examples/constrained_decoding/README.md
index cfca9c91fd..e04b8b6a01 100644
--- a/examples/constrained_decoding/README.md
+++ b/examples/constrained_decoding/README.md
@@ -12,7 +12,7 @@ Constrained search is enabled by adding the command-line argument `--constraints
Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens)
is a separate field.
-The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md),
+The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/main/examples/wmt19/README.md),
translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints
"hard" and "to influence".
diff --git a/examples/criss/mining/mine.py b/examples/criss/mining/mine.py
index c86f73ae87..c872da196f 100644
--- a/examples/criss/mining/mine.py
+++ b/examples/criss/mining/mine.py
@@ -7,7 +7,12 @@
import glob
from subprocess import check_call
-import faiss
+try:
+ import faiss
+
+ has_faiss = True
+except ImportError:
+ has_faiss = False
import numpy as np
@@ -40,6 +45,8 @@ def load_batch(emb_file, dim):
def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
+ if not has_faiss:
+ raise ImportError("Please install Faiss")
sims = []
inds = []
xfrom = 0
diff --git a/examples/criss/save_encoder.py b/examples/criss/save_encoder.py
index d911d066e3..24a842e409 100644
--- a/examples/criss/save_encoder.py
+++ b/examples/criss/save_encoder.py
@@ -11,6 +11,7 @@
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
+from fairseq.utils import safe_hasattr
def get_avg_pool(
@@ -109,9 +110,9 @@ def main(args):
shard_id = 0
all_avg_pool = None
encoder_has_langtok = (
- hasattr(task.args, "encoder_langtok")
+ safe_hasattr(task.args, "encoder_langtok")
and task.args.encoder_langtok is not None
- and hasattr(task.args, "lang_tok_replacing_bos_eos")
+ and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
and not task.args.lang_tok_replacing_bos_eos
)
with progress_bar.build_progress_bar(args, itr) as t:
diff --git a/examples/cross_lingual_language_model/README.md b/examples/cross_lingual_language_model/README.md
index a78f86d8da..af9128e39e 100644
--- a/examples/cross_lingual_language_model/README.md
+++ b/examples/cross_lingual_language_model/README.md
@@ -61,14 +61,14 @@ fairseq-train \
--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
--arch xlm_base \
--optimizer adam --lr-scheduler reduce_lr_on_plateau \
---lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
+--lr-shrink 0.5 --lr 0.0001 --stop-min-lr 1e-09 \
--dropout 0.1 \
--criterion legacy_masked_lm_loss \
--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
--dataset-impl lazy --seed 0 \
--masked-lm-only \
--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
---ddp-backend=no_c10d
+--ddp-backend=legacy_ddp
```
Some Notes:
diff --git a/examples/data2vec/README.md b/examples/data2vec/README.md
new file mode 100644
index 0000000000..a0ff21b82a
--- /dev/null
+++ b/examples/data2vec/README.md
@@ -0,0 +1,261 @@
+# data2vec 2.0
+
+data2vec 2.0 improves the training efficiency of the original data2vec algorithm. We make three changes for efficiency: we forward only the unmasked timesteps through the encoder, we use a convolutional decoder, and we use multi-masking to amortize the compute overhead of the teacher model. You can find details in the paper [Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language](https://arxiv.org/abs/2212.07525) and our [blog post](https://ai.facebook.com/blog/ai-self-supervised-learning-data2vec/).
+
+## Pretrained and finetuned models
+### Vision
+| Model | Finetuning split | Link
+|---|---|---
+data2vec ViT-B | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet.pt)
+data2vec ViT-B | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt)
+data2vec ViT-L | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt)
+data2vec ViT-L | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet_ft.pt)
+data2vec ViT-H | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet.pt)
+data2vec ViT-H | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet_ft.pt)
+
+Only the vision models are licensed under CC-BY-NC.
+### Speech
+
+| Model | Finetuning split | Dataset | Link
+|---|---|---|---
+data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri.pt)
+data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt)
+data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox.pt)
+data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox_960h.pt)
+
+### NLP
+
+| Model | Fine-tuning data | Dataset | Link | Dict | BPE
+|---|---|---|---|---|---
+data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt) | [dict](https://dl.fbaipublicfiles.com/fairseq/data2vec2/dict.txt) | [encoder](https://dl.fbaipublicfiles.com/fairseq/data2vec2/encoder.json) / [vocab](https://dl.fbaipublicfiles.com/fairseq/data2vec2/vocab.bpe)
+
+[//]: # (## Data Preparation)
+
+[//]: # ()
+[//]: # (### Vision)
+
+[//]: # (add details)
+
+[//]: # (### Speech)
+
+[//]: # (add details)
+
+[//]: # ()
+[//]: # (### NLP)
+
+[//]: # (add details)
+
+
+## Commands to train different models using data2vec 2.0
+
+### Vision
+
+Commands to pretrain different model configurations
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_images_only_task task.data=/path/to/dir
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name large_images_only_task task.data=/path/to/dir
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name huge_images14_only_task task.data=/path/to/dir
+```
+
+Commands to finetune different model configurations
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
+--config-name mae_imagenet_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
+--config-name mae_imagenet_large_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
+--config-name mae_imagenet_huge_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
+```
+
+### Speech
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_audio_only_task task.data=/path/to/manifests
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name large_audio_only_task task.data=/path/to/manifests
+```
+
+Finetuning:
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name vox_10h \
+task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec
+```
+
+Replace vox_10h with the right config depending on your model and fine-tuning split.
+See examples/wav2vec/config/finetuning for all available configs.
+
+### NLP
+
+Commands to pretrain
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_text_only_task task.data=/path/to/file
+```
+
+Commands to fine-tune all GLUE tasks
+```shell script
+$ task=cola # choose from [cola|qnli|mrpc|rte|sst_2|mnli|qqp|sts_b]
+$ lr=1e-5 # sweep [1e-5|2e-5|4e-5|6e-5] for each task
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \
+--config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]"
+```
+
+# data2vec
+
+data2vec is a framework for self-supervised representation learning for images, speech, and text as described in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language (Baevski et al., 2022)](https://ai.facebook.com/research/data2vec-a-general-framework-for-self-supervised-learning-in-speech-vision-and-language). The algorithm uses the same learning mechanism for different modalities.
+
+
+## Pre-trained models
+
+### Vision
+
+Code and pre-trained models for the data2vec vision experiments can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).
+
+### Speech
+
+| Model | Finetuning split | Dataset | Link
+|---|---|---|---
+data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls.pt)
+data2vec Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls_10m.pt)
+data2vec Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls_100h.pt)
+data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls_960h.pt)
+data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_pretrained.pt)
+data2vec Large | 10 minutes | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_10m.pt)
+data2vec Large | 100 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_100h.pt)
+data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_960h.pt)
+---
+
+### NLP
+
+Model | Fine-tuning data | Dataset | Link
+|---|---|---|---|
+data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/nlp_base.pt)
+
+## Training a new speech model with the CLI tools
+
+Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length):
+
+### Prepare training data manifest:
+
+First, install the `soundfile` library:
+```shell script
+pip install soundfile
+```
+
+Next, run:
+
+```shell script
+$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid
+```
+
+$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read.
+
+$valid should be set to some reasonable percentage (like 0.01) of the training data to use for validation.
+To use a pre-defined validation set (like dev-other from Librispeech), set it to 0 and then overwrite valid.tsv with a
+separately pre-processed manifest file.
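+
+For example, one way to use a pre-defined validation set is sketched below (the dev-other manifest path is a placeholder for any manifest prepared in the same format):
+
+```shell script
+# build train.tsv only, without carving out a random validation split
+$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext flac --valid-percent 0
+
+# then overwrite valid.tsv with a separately prepared manifest, e.g. one built from dev-other
+$ cp /path/to/dev_other_manifest.tsv /manifest/path/valid.tsv
+```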
+
+### Train a data2vec Base model:
+
+This configuration was used for the base model trained on the Librispeech dataset in the data2vec paper.
+
+Note that the input is expected to be single channel, sampled at 16 kHz.
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/audio/pretraining \
+--config-name base_librispeech task.data=/path/to/manifests common.user_dir=examples/data2vec
+```
+
+Note: you can simulate 16 GPUs by using k GPUs and adding command line parameters
+`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 16/k
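+
+For instance, a sketch of the same pretraining run on a single node with 8 GPUs (x = 16/8 = 2):
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/audio/pretraining \
+--config-name base_librispeech task.data=/path/to/manifests common.user_dir=examples/data2vec \
+distributed_training.distributed_world_size=8 +optimization.update_freq='[2]'
+```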
+
+### Fine-tune a pre-trained model with CTC:
+
+Fine-tuning a model requires parallel audio and label files, as well as a vocabulary file in fairseq format.
+A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt).
+An example [script](../wav2vec/libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows:
+
+```shell script
+split=train
+$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split
+```
+
+Fine-tuning on 100h of Librispeech with letter targets:
+```shell script
+$ fairseq-hydra-train \
+ distributed_training.distributed_port=$PORT \
+ task.data=/path/to/data \
+ model.w2v_path=/path/to/model.pt \
+ --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \
+ --config-name base_100h common.user_dir=examples/data2vec
+```
+
+There are other config files in the config/finetuning directory that can be used to fine-tune on other splits.
+You can specify the right config via the `--config-name` parameter.
+
+Decoding with a language model during training requires the flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)).
+If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line.
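+
+For example, a sketch of the 100h fine-tuning command above with KenLM-based word error rate reported during validation (the kenlm and lexicon paths are placeholders):
+
+```shell script
+$ fairseq-hydra-train \
+    distributed_training.distributed_port=$PORT \
+    task.data=/path/to/data \
+    model.w2v_path=/path/to/model.pt \
+    +criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]' \
+    --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \
+    --config-name base_100h common.user_dir=examples/data2vec
+```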
+
+### Evaluating a CTC model:
+
+Evaluating a CTC model with a language model requires the [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) to be installed.
+
+The Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019).
+Be sure to upper-case the language model vocab after downloading it.
+
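+One possible way to do the upper-casing, assuming the vocab is a plain-text fairseq `dict.txt` with one `<word> <count>` entry per line (the file names below are placeholders):
+
+```shell script
+$ tr '[:lower:]' '[:upper:]' < lm_dict.txt > lm_dict.upper.txt
+```
+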
+Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt).
+
+Next, run the evaluation command:
+
+```shell script
+python examples/speech_recognition/new/infer.py --config-dir examples/speech_recognition/new/conf \
+--config-name infer task=audio_finetuning task.data=/path/to/manifests common.user_dir=examples/data2vec \
+task.labels=ltr decoding.type=kenlm \
+decoding.lmweight=${lmweight} decoding.wordscore=${wordscore} decoding.silweight=${silscore} \
+decoding.lexicon=/path/to/lexicon \
+decoding.lmpath=/path/to/lm decoding.unique_wer_file=True \
+dataset.gen_subset=dev_clean,dev_other,test_clean,test_other \
+common_eval.path=/path/to/checkpoint.pt decoding.beam=1500 distributed_training.distributed_world_size=${num_gpus}
+```
+
+To get raw numbers, use `decoding.type=viterbi` and omit the lexicon. To use the transformer language model, use `decoding.type=fairseqlm`.
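+
+For example, a sketch of a viterbi run without a lexicon or language model, based on the command above:
+
+```shell script
+python examples/speech_recognition/new/infer.py --config-dir examples/speech_recognition/new/conf \
+--config-name infer task=audio_finetuning task.data=/path/to/manifests common.user_dir=examples/data2vec \
+task.labels=ltr decoding.type=viterbi \
+dataset.gen_subset=dev_clean,dev_other,test_clean,test_other \
+common_eval.path=/path/to/checkpoint.pt distributed_training.distributed_world_size=${num_gpus}
+```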
+
+## Training a new NLP model with the CLI tools
+
+Please follow the [RoBERTa](../roberta/README.md) instructions to preprocess your data. To train a data2vec model on the preprocessed data, run:
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/text/pretraining \
+--config-name base task.data=/path/to/data common.user_dir=examples/data2vec
+```
+
+As with the speech models, you can simulate 16 GPUs by using the `update_freq` parameter.
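+
+For example, a sketch of the same pretraining run on 8 GPUs with gradient accumulation over 2 update steps:
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/text/pretraining \
+--config-name base task.data=/path/to/data common.user_dir=examples/data2vec \
+distributed_training.distributed_world_size=8 optimization.update_freq='[2]'
+```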
+
+### Finetuning data2vec-text on GLUE
+
+Please use a command similar to this:
+
+```shell
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task task.data=$data_path checkpoint.restore_file="${/path/to/pretrained/model.pt}"
+```
diff --git a/examples/data2vec/__init__.py b/examples/data2vec/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/data2vec/config/audio/classification/base_classification.yaml b/examples/data2vec/config/audio/classification/base_classification.yaml
new file mode 100644
index 0000000000..fdb9c8d3d7
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/base_classification.yaml
@@ -0,0 +1,70 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ all_gather_list_size: 70000
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: mAP
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: audio_classification
+ data: ???
+ normalize: true
+ labels: lbl
+
+dataset:
+ num_workers: 6
+ max_tokens: 2560000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: eval
+ validate_interval: 5
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: model
+ can_sum: false
+ log_keys:
+ - _predictions
+ - _targets
+
+optimization:
+ max_update: 30000
+ lr: [0.00006] # scratch 53-5
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 5000
+
+model:
+ _name: audio_classification
+ model_path: ???
+ apply_mask: true
+ mask_prob: 0.6
+ mask_length: 5 # scratch 1
+ mask_channel_prob: 0
+ mask_channel_length: 64
+ layerdrop: 0.1
+ dropout: 0.1
+ activation_dropout: 0.1
+ attention_dropout: 0.2
+ feature_grad_mult: 0 # scratch 1
+ label_mixup: true
+ source_mixup: 0.5
+ prediction_mode: lin_softmax # scratch average_sigmoid
+
diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml
new file mode 100644
index 0000000000..881a1583f8
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml
new file mode 100644
index 0000000000..de7894d9cf
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 1
+ tasks_per_node: 1
+ mem_gb: 100
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml
new file mode 100644
index 0000000000..b016cac9b5
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/audioset.yaml b/examples/data2vec/config/audio/pretraining/audioset.yaml
new file mode 100644
index 0000000000..dd30fbedd5
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/audioset.yaml
@@ -0,0 +1,91 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: /private/home/abaevski/data/audioset
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 3400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 24
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+# - avg_self_attn
+# - weights
+
+optimization:
+ max_update: 200000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+model:
+ _name: data2vec_audio
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.05
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 1.0
+ encoder_embed_dim: 768
+
+ mask_prob: 0.65
+ mask_length: 10
+
+ loss_beta: 0
+ loss_scale: null
+
+ instance_norm_target_layer: true
+ layer_norm_targets: true
+ average_top_k_layers: 12
+
+ self_attn_norm_type: deepnorm
+ final_norm_type: deepnorm
+
+ pos_conv_depth: 5
+ conv_pos: 95
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: false
+
+ require_same_masks: true
+ mask_dropout: 0
diff --git a/examples/data2vec/config/audio/pretraining/base_librispeech.yaml b/examples/data2vec/config/audio/pretraining/base_librispeech.yaml
new file mode 100644
index 0000000000..c332c5a3f8
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/base_librispeech.yaml
@@ -0,0 +1,83 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: ???
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 3800000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.03,0.9,0.07]
+
+model:
+ _name: data2vec_audio
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.05
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 1.0
+ encoder_embed_dim: 768
+
+ mask_prob: 0.65
+ mask_length: 10
+
+ loss_beta: 0
+ loss_scale: null
+
+ instance_norm_target_layer: true
+ average_top_k_layers: 8
+
+ pos_conv_depth: 5
+ conv_pos: 95
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: true
+
+ require_same_masks: true
+ mask_dropout: 0
diff --git a/examples/data2vec/config/audio/pretraining/run_config/local.yaml b/examples/data2vec/config/audio/pretraining/run_config/local.yaml
new file mode 100644
index 0000000000..45595f9eea
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml
new file mode 100644
index 0000000000..732f018899
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml
new file mode 100644
index 0000000000..e2bab5675a
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 0000000000..ec53dc2a98
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 0000000000..70cc8cbb5b
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 0000000000..14b47d14e6
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 0000000000..c54d735fb2
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml
new file mode 100644
index 0000000000..0231b2690d
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml
new file mode 100644
index 0000000000..9a4e43a987
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml
new file mode 100644
index 0000000000..78c9f57aeb
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/base.yaml b/examples/data2vec/config/text/pretraining/base.yaml
new file mode 100644
index 0000000000..c6b07c4052
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/base.yaml
@@ -0,0 +1,77 @@
+# @package _group_
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ no_epoch_checkpoints: true
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: legacy_ddp
+
+task:
+ _name: masked_lm
+ data: ???
+ sample_break_mode: complete_doc
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ mask_prob: 0.35
+ mask_multiple_length: 4
+
+criterion: model
+
+dataset:
+ max_tokens: 8192
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+
+optimizer:
+ _name: adam
+ weight_decay: 0.01
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+optimization:
+ clip_norm: 5
+ lr: [0.0002]
+ max_update: 1000000
+ update_freq: [1]
+
+model:
+ _name: data2vec_text
+ head_layers: 2
+ average_top_k_layers: 10
+ layer_norm_target_layer: true
+ loss_scale: 1
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 300000
+ loss_beta: 4
+ ema_transformer_layers_only: true
+
+ transformer:
+ dropout: 0.1
+ attention_dropout: 0.1
+ layernorm_embedding: true
+ activation_fn: gelu
+ no_scale_embedding: true
+ max_source_positions: 512
+ encoder:
+ embed_dim: 768
+ ffn_embed_dim: 3072
+ layers: 12
+ attention_heads: 12
+ normalize_before: false
+ learned_pos: true
+ layerdrop: 0
diff --git a/examples/data2vec/config/text/pretraining/run_config/local.yaml b/examples/data2vec/config/text/pretraining/run_config/local.yaml
new file mode 100644
index 0000000000..45595f9eea
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml
new file mode 100644
index 0000000000..4bac45a58d
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 0000000000..006a0f2116
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 0000000000..4292198b4e
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 0000000000..0e1555d20f
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 0000000000..c54d735fb2
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml
new file mode 100644
index 0000000000..5df84cd6da
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml
@@ -0,0 +1,41 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: legacy_ddp
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml
new file mode 100644
index 0000000000..5b32c23a66
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml
@@ -0,0 +1,41 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: pt
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
+
+distributed_training:
+ distributed_world_size: 64
+ ddp_backend: legacy_ddp
diff --git a/examples/data2vec/config/v2/base_audio_only_task.yaml b/examples/data2vec/config/v2/base_audio_only_task.yaml
new file mode 100644
index 0000000000..65a9ab3e73
--- /dev/null
+++ b/examples/data2vec/config/v2/base_audio_only_task.yaml
@@ -0,0 +1,113 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: false
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: /private/home/abaevski/data/librispeech/full
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 8
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 400000
+ lr: [0.00075]
+ debug_param_names: true
+
+optimizer:
+ _name: adam
+ adam_betas: [ 0.9,0.98 ]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 8000
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: null
+
+ depth: 12
+ embed_dim: 768
+ clone_batch: 8
+
+ ema_decay: 0.999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 75000
+ ema_encoder_only: false
+
+ average_top_k_layers: 8
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: false
+
+ layerdrop: 0.05
+ norm_eps: 1e-5
+
+ supported_modality: AUDIO
+
+ modalities:
+ audio:
+ feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+ conv_pos_depth: 5
+ conv_pos_width: 95
+ conv_pos_groups: 16
+ prenet_depth: 0
+ mask_prob: 0.5
+ mask_prob_adjust: 0.05
+ inverse_mask: false
+ mask_length: 5
+ mask_noise_std: 0.01
+ mask_dropout: 0
+ add_masks: false
+ ema_local_encoder: false
+ use_alibi_encoder: true
+ prenet_layerdrop: 0.05
+ prenet_dropout: 0.1
+ learned_alibi_scale: true
+ learned_alibi_scale_per_head: true
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 384
+ decoder_groups: 16
+ decoder_kernel: 7
+ decoder_layers: 4
diff --git a/examples/data2vec/config/v2/base_images_only_task.yaml b/examples/data2vec/config/v2/base_images_only_task.yaml
new file mode 100644
index 0000000000..ff0c247b13
--- /dev/null
+++ b/examples/data2vec/config/v2/base_images_only_task.yaml
@@ -0,0 +1,116 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 16
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 375300
+ lr: [ 0.001 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 1e-3
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 10
+ average_top_k_layers: 10
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ inverse_mask: true
+ mask_prob: 0.8
+ mask_prob_adjust: 0.07
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 2
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ decoder:
+ decoder_dim: 768
+ decoder_groups: 16
+ decoder_kernel: 3
+ decoder_layers: 6
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/base_text_only_task.yaml b/examples/data2vec/config/v2/base_text_only_task.yaml
new file mode 100644
index 0000000000..62f22eb0fe
--- /dev/null
+++ b/examples/data2vec/config/v2/base_text_only_task.yaml
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ no_epoch_checkpoints: true
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: legacy_ddp
+
+task:
+ _name: masked_lm
+ data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
+ sample_break_mode: none
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ include_index: True
+ skip_masking: True
+ d2v2_multi: True
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+dataset:
+ batch_size: 4
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+ disable_validation: true
+
+optimization:
+ clip_norm: 1
+ lr: [0.0002]
+ max_update: 1000000
+ update_freq: [1]
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0002
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: 1
+
+ depth: 12
+ embed_dim: 768
+ clone_batch: 8
+
+ ema_decay: 0.9999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ ema_encoder_only: true
+
+ average_top_k_layers: 12
+ layer_norm_target_layer: false
+ instance_norm_target_layer: true
+ batch_norm_target_layer: false
+ instance_norm_targets: false
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: TEXT
+
+ modalities:
+ text:
+ mask_prob: 0.48
+ mask_length: 1
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 1
+ decoder_kernel: 9
+ decoder_layers: 5
+ decoder_residual: false
+ projection_layers: 2
+ projection_ratio: 2.0
diff --git a/examples/data2vec/config/v2/huge_images14_only_task.yaml b/examples/data2vec/config/v2/huge_images14_only_task.yaml
new file mode 100644
index 0000000000..a8a15253f2
--- /dev/null
+++ b/examples/data2vec/config/v2/huge_images14_only_task.yaml
@@ -0,0 +1,122 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 8
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 500000
+ lr: [ 0.0004 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 4e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 1
+ ema_anneal_end_step: 300000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 32
+ embed_dim: 1280
+ num_heads: 16
+
+ average_top_k_layers: 24
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ patch_size: 14
+ inverse_mask: true
+ mask_prob: 0.75
+ mask_prob_adjust: 0.1
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ embed_dim: 1280
+ decoder:
+ decoder_dim: 1024
+ decoder_groups: 16
+ decoder_kernel: 5
+ decoder_layers: 3
+ final_layer_norm: false
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/huge_images_only_task.yaml b/examples/data2vec/config/v2/huge_images_only_task.yaml
new file mode 100644
index 0000000000..7a352ac3c7
--- /dev/null
+++ b/examples/data2vec/config/v2/huge_images_only_task.yaml
@@ -0,0 +1,120 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 8
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 375300
+ lr: [ 0.0004 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 4e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 0.99995
+ ema_anneal_end_step: 150000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 32
+ embed_dim: 1280
+ num_heads: 16
+
+ average_top_k_layers: 24
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ inverse_mask: true
+ mask_prob: 0.75
+ mask_prob_adjust: 0.1
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ embed_dim: 1280
+ decoder:
+ decoder_dim: 1024
+ decoder_groups: 16
+ decoder_kernel: 5
+ decoder_layers: 3
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/large_audio_only_task.yaml b/examples/data2vec/config/v2/large_audio_only_task.yaml
new file mode 100644
index 0000000000..3f61589721
--- /dev/null
+++ b/examples/data2vec/config/v2/large_audio_only_task.yaml
@@ -0,0 +1,122 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: /fsx-wav2vec/abaevski/data/librivox/no_silence
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 8
+ max_tokens: 320000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 48
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 600000
+ debug_param_names: true
+ clip_norm: 1
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0004
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: null
+
+ depth: 16
+ embed_dim: 1024
+ num_heads: 16
+
+ clone_batch: 12
+
+ ema_decay: 0.9997
+ ema_end_decay: 1
+ ema_anneal_end_step: 300000
+ ema_encoder_only: false
+
+ average_top_k_layers: 16
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: AUDIO
+
+ modalities:
+ audio:
+ feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+ conv_pos_depth: 5
+ conv_pos_width: 95
+ conv_pos_groups: 16
+ prenet_depth: 8
+ mask_prob: 0.55
+ mask_prob_adjust: 0.1
+ inverse_mask: false
+ mask_length: 5
+ mask_noise_std: 0.01
+ mask_dropout: 0
+ add_masks: false
+ ema_local_encoder: false
+ use_alibi_encoder: true
+ prenet_layerdrop: 0
+ prenet_dropout: 0.1
+ learned_alibi_scale: true
+ learned_alibi_scale_per_head: true
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 16
+ decoder_kernel: 7
+ decoder_layers: 4
diff --git a/examples/data2vec/config/v2/large_images_only_task.yaml b/examples/data2vec/config/v2/large_images_only_task.yaml
new file mode 100644
index 0000000000..6b957fc129
--- /dev/null
+++ b/examples/data2vec/config/v2/large_images_only_task.yaml
@@ -0,0 +1,120 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 8
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 375300
+ lr: [ 0.0004 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 4e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 150000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 24
+ embed_dim: 1024
+ num_heads: 16
+
+ average_top_k_layers: 18
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ inverse_mask: true
+ mask_prob: 0.75
+ mask_prob_adjust: 0.1
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ embed_dim: 1024
+ decoder:
+ decoder_dim: 1024
+ decoder_groups: 16
+ decoder_kernel: 5
+ decoder_layers: 3
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/large_text_only_task.yaml b/examples/data2vec/config/v2/large_text_only_task.yaml
new file mode 100644
index 0000000000..fd69048e77
--- /dev/null
+++ b/examples/data2vec/config/v2/large_text_only_task.yaml
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: masked_lm
+ data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
+ sample_break_mode: none
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ include_index: True
+ skip_masking: True
+ d2v2_multi: True
+
+dataset:
+ batch_size: 2
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 600000
+ clip_norm: 1
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0001
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: 1
+
+ depth: 24
+ num_heads: 16
+ embed_dim: 1024
+ clone_batch: 8
+
+ ema_decay: 0.9999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ ema_encoder_only: true
+
+ average_top_k_layers: 24
+ layer_norm_target_layer: true
+ instance_norm_target_layer: false
+ batch_norm_target_layer: false
+ instance_norm_targets: true
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: TEXT
+
+ modalities:
+ text:
+ mask_prob: 0.5
+ mask_length: 1
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 1
+ decoder_kernel: 9
+ decoder_layers: 5
+ decoder_residual: false
+ projection_layers: 2
+ projection_ratio: 2.0
diff --git a/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml b/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
new file mode 100644
index 0000000000..739e6f6724
--- /dev/null
+++ b/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
@@ -0,0 +1,123 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ no_epoch_checkpoints: true
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: legacy_ddp
+
+task:
+ _name: masked_lm
+ data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
+ sample_break_mode: none
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ include_index: True
+ skip_masking: True
+ d2v2_multi: True
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+dataset:
+ batch_size: 2
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+ disable_validation: true
+
+optimization:
+ clip_norm: 1
+ lr: [3e-4]
+ max_update: 1000000
+ update_freq: [1]
+
+optimizer:
+ _name: composite
+ groups:
+ default:
+ lr_float: 1e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+ decoder:
+ lr_float: 1e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 4
+ loss_scale: 1
+
+ depth: 24
+ num_heads: 16
+ embed_dim: 1024
+ clone_batch: 8
+
+ ema_decay: 0.9999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ ema_encoder_only: true
+
+ average_top_k_layers: 24
+ layer_norm_target_layer: true
+ instance_norm_target_layer: false
+ batch_norm_target_layer: false
+ instance_norm_targets: true
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: TEXT
+ decoder_group: true
+
+ modalities:
+ text:
+ mask_prob: 0.5
+ mask_length: 1
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 1
+ decoder_kernel: 9
+ decoder_layers: 5
+ decoder_residual: false
+ projection_layers: 2
+ projection_ratio: 2.0
diff --git a/examples/data2vec/config/v2/run_config/local.yaml b/examples/data2vec/config/v2/run_config/local.yaml
new file mode 100644
index 0000000000..45595f9eea
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/v2/run_config/slurm_1.yaml b/examples/data2vec/config/v2/run_config/slurm_1.yaml
new file mode 100644
index 0000000000..732f018899
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
new file mode 100644
index 0000000000..b2184f8cfa
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_2.yaml b/examples/data2vec/config/v2/run_config/slurm_2.yaml
new file mode 100644
index 0000000000..ec53dc2a98
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
new file mode 100644
index 0000000000..553765597f
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
@@ -0,0 +1,39 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - model.model_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_3.yaml b/examples/data2vec/config/v2/run_config/slurm_3.yaml
new file mode 100644
index 0000000000..14b47d14e6
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_4.yaml b/examples/data2vec/config/v2/run_config/slurm_4.yaml
new file mode 100644
index 0000000000..c54d735fb2
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
new file mode 100644
index 0000000000..a77f62aece
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
new file mode 100644
index 0000000000..20e06582be
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_8.yaml b/examples/data2vec/config/v2/run_config/slurm_8.yaml
new file mode 100644
index 0000000000..e3ec2c2847
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_8.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
new file mode 100644
index 0000000000..a9dce876cc
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/text_finetuning/cola.yaml b/examples/data2vec/config/v2/text_finetuning/cola.yaml
new file mode 100644
index 0000000000..d4ac4ec8b8
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/cola.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: mcc
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ report_mcc: True
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 320
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 5336
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/mnli.yaml b/examples/data2vec/config/v2/text_finetuning/mnli.yaml
new file mode 100644
index 0000000000..1a9d6e52f0
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/mnli.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 3
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ valid_subset: valid,valid1
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 7432
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 123873
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/mrpc.yaml b/examples/data2vec/config/v2/text_finetuning/mrpc.yaml
new file mode 100644
index 0000000000..8f93d9d9ea
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/mrpc.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: acc_and_f1
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ report_acc_and_f1: True
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 137
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 2296
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/qnli.yaml b/examples/data2vec/config/v2/text_finetuning/qnli.yaml
new file mode 100644
index 0000000000..739fb53b69
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/qnli.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 1986
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 33112
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/qqp.yaml b/examples/data2vec/config/v2/text_finetuning/qqp.yaml
new file mode 100644
index 0000000000..9accbaa521
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/qqp.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: acc_and_f1
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ report_acc_and_f1: True
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 28318
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 113272
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/rte.yaml b/examples/data2vec/config/v2/text_finetuning/rte.yaml
new file mode 100644
index 0000000000..ea07764d98
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/rte.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 122
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 2036
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml b/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml
new file mode 100644
index 0000000000..45595f9eea
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/v2/text_finetuning/sst_2.yaml b/examples/data2vec/config/v2/text_finetuning/sst_2.yaml
new file mode 100644
index 0000000000..a273e5b943
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/sst_2.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 1256
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 20935
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/sts_b.yaml b/examples/data2vec/config/v2/text_finetuning/sts_b.yaml
new file mode 100644
index 0000000000..fb009ab95b
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/sts_b.yaml
@@ -0,0 +1,61 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 1
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: pearson_and_spearman
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ regression_target: true
+ report_pearson_and_spearman: True
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 214
+
+optimization:
+ clip_norm: 0.0
+ lr: [4e-05]
+ max_update: 3598
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/imagenet.yaml b/examples/data2vec/config/vision/finetuning/imagenet.yaml
new file mode 100644
index 0000000000..d6d4864cca
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/imagenet.yaml
@@ -0,0 +1,52 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+
+task:
+ _name: image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 64
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 8
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 100000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+model:
+ _name: data2vec_image_classification
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml
new file mode 100644
index 0000000000..17d4c0a8f5
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml
@@ -0,0 +1,65 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: mae_image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 32
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 250200
+ lr: [0.001]
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.001
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 16000
+ min_lr: 1e-6
+
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae_image_classification
+ mixup: 0.7
+ mixup_prob: 0.9
+
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml
new file mode 100644
index 0000000000..2d2eb57bac
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml
@@ -0,0 +1,68 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: mae_image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 32
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 125200
+ lr: [0.0005]
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0005
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 16000
+ min_lr: 1e-20
+
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae_image_classification
+ mixup: 0.7
+ mixup_prob: 0.9
+ layer_decay: 0.75
+ drop_path_rate: 0.2
+
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml
new file mode 100644
index 0000000000..3a9413cef6
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml
@@ -0,0 +1,68 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: mae_image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 32
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 125200
+ lr: [0.0005]
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0005
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 16000
+ min_lr: 1e-7
+
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae_image_classification
+ mixup: 0.7
+ mixup_prob: 0.9
+ layer_decay: 0.75
+ drop_path_rate: 0.2
+
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/run_config/local.yaml b/examples/data2vec/config/vision/finetuning/run_config/local.yaml
new file mode 100644
index 0000000000..45595f9eea
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml
new file mode 100644
index 0000000000..732f018899
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
new file mode 100644
index 0000000000..e2bab5675a
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
new file mode 100644
index 0000000000..c8b0f02a9b
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
new file mode 100644
index 0000000000..93d0d9c20a
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ - model.model_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
new file mode 100644
index 0000000000..14b47d14e6
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml
new file mode 100644
index 0000000000..c54d735fb2
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
new file mode 100644
index 0000000000..d5d11cb755
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
new file mode 100644
index 0000000000..906f08a602
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
new file mode 100644
index 0000000000..d60e13f8ba
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/base_imagenet.yaml b/examples/data2vec/config/vision/pretraining/base_imagenet.yaml
new file mode 100644
index 0000000000..9bfc0f32b6
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/base_imagenet.yaml
@@ -0,0 +1,52 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+
+dataset:
+ num_workers: 6
+ batch_size: 64
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+model:
+ _name: data2vec_vision
diff --git a/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml b/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
new file mode 100644
index 0000000000..5fd399b117
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
@@ -0,0 +1,64 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: image_pretraining
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 128
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+
+optimization:
+  max_update: 375300 # 300 epochs * 1251 updates per epoch
+ lr: [0.0005]
+ clip_norm: 3.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.999)
+ adam_eps: 1e-08
+ weight_decay: 0.05
+
+lr_scheduler:
+ _name: cosine
+  warmup_updates: 12510 # 10 epochs (1251 updates per epoch)
+
+model:
+ _name: data2vec_vision
+
+ attention_dropout: 0.05
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9998
+ layer_norm_targets: True
+ average_top_k_layers: 6
+
+ loss_beta: 2.0
+
+ drop_path: 0.25
diff --git a/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml b/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml
new file mode 100644
index 0000000000..d7872b5e04
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml
@@ -0,0 +1,64 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+
+dataset:
+ num_workers: 6
+ batch_size: 64
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+
+optimization:
+ max_update: 375300
+ lr: [0.0006]
+
+optimizer:
+ _name: composite
+ groups:
+ with_decay:
+ lr_float: 6e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+ no_decay:
+ lr_float: 6e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
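+# pass_through defers LR scheduling to the per-group cosine schedulers defined inside the
+# composite optimizer (with_decay / no_decay) above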
+lr_scheduler: pass_through
+
+model:
+ _name: mae
diff --git a/examples/data2vec/config/vision/pretraining/run_config/local.yaml b/examples/data2vec/config/vision/pretraining/run_config/local.yaml
new file mode 100644
index 0000000000..45595f9eea
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
new file mode 100644
index 0000000000..732f018899
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml
new file mode 100644
index 0000000000..e2bab5675a
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 0000000000..c8b0f02a9b
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 0000000000..032e53a304
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 0000000000..14b47d14e6
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 0000000000..c54d735fb2
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
new file mode 100644
index 0000000000..d5d11cb755
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
new file mode 100644
index 0000000000..906f08a602
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
new file mode 100644
index 0000000000..d60e13f8ba
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/data/__init__.py b/examples/data2vec/data/__init__.py
new file mode 100644
index 0000000000..d76112bfc2
--- /dev/null
+++ b/examples/data2vec/data/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .image_dataset import ImageDataset
+from .path_dataset import PathDataset
+from .mae_image_dataset import MaeImageDataset
+from .mae_finetuning_image_dataset import MaeFinetuningImageDataset
+
+
+__all__ = [
+ "ImageDataset",
+ "MaeImageDataset",
+ "MaeFinetuningImageDataset",
+ "PathDataset",
+]
\ No newline at end of file
diff --git a/examples/data2vec/data/add_class_target_dataset.py b/examples/data2vec/data/add_class_target_dataset.py
new file mode 100644
index 0000000000..c346c83e58
--- /dev/null
+++ b/examples/data2vec/data/add_class_target_dataset.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.data import BaseWrapperDataset, data_utils
+
+
+class AddClassTargetDataset(BaseWrapperDataset):
+ def __init__(
+ self,
+ dataset,
+ labels,
+ multi_class,
+ num_classes=None,
+ label_indices=None,
+ add_to_input=True,
+ ):
+ super().__init__(dataset)
+
+ self.label_indices = label_indices
+ self.labels = labels
+ self.multi_class = multi_class
+ self.add_to_input = add_to_input
+ if num_classes is None and multi_class:
+ assert self.label_indices is not None
+ num_classes = len(self.label_indices)
+
+ self.num_classes = num_classes
+
+ def __getitem__(self, index):
+ item = self.dataset[index]
+ item_labels = self.labels[index]
+ if self.multi_class:
+ item["label"] = torch.zeros(self.num_classes)
+ for il in item_labels:
+ if self.label_indices is not None:
+ il = self.label_indices[il]
+ item["label"][il] = 1.0
+ else:
+ item["label"] = torch.tensor(
+ self.labels[index]
+ if self.label_indices is None
+ else self.label_indices[self.labels[index]]
+ )
+
+ return item
+
+ def collater(self, samples):
+ collated = self.dataset.collater(samples)
+ if len(collated) == 0:
+ return collated
+
+ indices = set(collated["id"].tolist())
+ target = [s["label"] for s in samples if s["id"] in indices]
+ collated["label"] = torch.stack(target, dim=0)
+
+ if self.add_to_input:
+ collated["net_input"]["label"] = collated["label"]
+
+ return collated
diff --git a/examples/data2vec/data/image_dataset.py b/examples/data2vec/data/image_dataset.py
new file mode 100644
index 0000000000..7f551057e8
--- /dev/null
+++ b/examples/data2vec/data/image_dataset.py
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import numpy as np
+import os
+from typing import Optional, Callable, Set
+
+import torch
+
+from torchvision.datasets.vision import VisionDataset
+from torchvision.transforms import ToTensor
+
+from fairseq.data import FairseqDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class ImageDataset(FairseqDataset, VisionDataset):
+ def __init__(
+ self,
+ root: str,
+ extensions: Set[str],
+ load_classes: bool,
+ transform: Optional[Callable] = None,
+ shuffle=True,
+ ):
+ FairseqDataset.__init__(self)
+ VisionDataset.__init__(self, root=root, transform=transform)
+
+ self.shuffle = shuffle
+ self.tensor_transform = ToTensor()
+
+ self.classes = None
+ self.labels = None
+ if load_classes:
+ classes = [d.name for d in os.scandir(root) if d.is_dir()]
+ classes.sort()
+ self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
+ logger.info(f"loaded {len(self.classes)} classes")
+ self.labels = []
+
+ def walk_path(root_path):
+ for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
+ for fname in sorted(fnames):
+ fname_ext = os.path.splitext(fname)
+ if fname_ext[-1].lower() not in extensions:
+ continue
+
+ path = os.path.join(root, fname)
+ yield path
+
+ logger.info(f"finding images in {root}")
+ if self.classes is not None:
+ self.files = []
+ self.labels = []
+ for c, i in self.classes.items():
+ for f in walk_path(os.path.join(root, c)):
+ self.files.append(f)
+ self.labels.append(i)
+ else:
+ self.files = [f for f in walk_path(root)]
+
+ logger.info(f"loaded {len(self.files)} examples")
+
+ def __getitem__(self, index):
+ from PIL import Image
+
+ fpath = self.files[index]
+
+ with open(fpath, "rb") as f:
+ img = Image.open(f).convert("RGB")
+
+ if self.transform is None:
+ img = self.tensor_transform(img)
+ else:
+ img = self.transform(img)
+ assert torch.is_tensor(img)
+
+ res = {"id": index, "img": img}
+
+ if self.labels is not None:
+ res["label"] = self.labels[index]
+
+ return res
+
+ def __len__(self):
+ return len(self.files)
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+
+ collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+ res = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": {
+ "img": collated_img,
+ },
+ }
+
+ if "label" in samples[0]:
+ res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])
+
+ return res
+
+ def num_tokens(self, index):
+ return 1
+
+ def size(self, index):
+ return 1
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ order = [np.random.permutation(len(self))]
+ else:
+ order = [np.arange(len(self))]
+
+ return order[0]
diff --git a/examples/data2vec/data/mae_finetuning_image_dataset.py b/examples/data2vec/data/mae_finetuning_image_dataset.py
new file mode 100644
index 0000000000..28cbcb38ac
--- /dev/null
+++ b/examples/data2vec/data/mae_finetuning_image_dataset.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import numpy as np
+import os
+
+import torch
+
+from torchvision import datasets, transforms
+
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+import PIL
+
+from fairseq.data import FairseqDataset
+from .mae_image_dataset import caching_loader
+
+
+logger = logging.getLogger(__name__)
+
+
+def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
+ mean = IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_DEFAULT_STD
+ # train transform
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ color_jitter=color_jitter,
+ auto_augment=aa,
+ interpolation="bicubic",
+ re_prob=reprob,
+ re_mode=remode,
+ re_count=recount,
+ mean=mean,
+ std=std,
+ )
+ return transform
+
+ # eval transform
+ t = []
+ if input_size <= 224:
+ crop_pct = 224 / 256
+ else:
+ crop_pct = 1.0
+ size = int(input_size / crop_pct)
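+    # e.g. input_size=224 -> crop_pct=224/256, so images are resized to 256 and then
+    # center-cropped to 224 (the standard ImageNet evaluation protocol)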
+ t.append(
+ transforms.Resize(
+ size, interpolation=PIL.Image.BICUBIC
+ ), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(mean, std))
+ return transforms.Compose(t)
+
+
+class MaeFinetuningImageDataset(FairseqDataset):
+ def __init__(
+ self,
+ root: str,
+ split: str,
+ is_train: bool,
+ input_size,
+ color_jitter=None,
+ aa="rand-m9-mstd0.5-inc1",
+ reprob=0.25,
+ remode="pixel",
+ recount=1,
+ local_cache_path=None,
+ shuffle=True,
+ ):
+ FairseqDataset.__init__(self)
+
+ self.shuffle = shuffle
+
+ transform = build_transform(
+ is_train, input_size, color_jitter, aa, reprob, remode, recount
+ )
+
+ path = os.path.join(root, split)
+ loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+ self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)
+
+ logger.info(f"loaded {len(self.dataset)} examples")
+
+ def __getitem__(self, index):
+ img, label = self.dataset[index]
+ return {"id": index, "img": img, "label": label}
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+
+ collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+ res = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": {
+ "imgs": collated_img,
+ },
+ }
+
+ if "label" in samples[0]:
+ res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])
+
+ return res
+
+ def num_tokens(self, index):
+ return 1
+
+ def size(self, index):
+ return 1
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ order = [np.random.permutation(len(self))]
+ else:
+ order = [np.arange(len(self))]
+
+ return order[0]
diff --git a/examples/data2vec/data/mae_image_dataset.py b/examples/data2vec/data/mae_image_dataset.py
new file mode 100644
index 0000000000..4aacb94895
--- /dev/null
+++ b/examples/data2vec/data/mae_image_dataset.py
@@ -0,0 +1,418 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from functools import partial
+import logging
+import math
+import random
+import time
+
+import numpy as np
+import os
+
+import torch
+
+from torchvision import datasets, transforms
+from .path_dataset import PathDataset
+
+from fairseq.data import FairseqDataset
+from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d
+
+from shutil import copyfile
+
+logger = logging.getLogger(__name__)
+
+
+def load(path, loader, cache):
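+    # Stage the file into a local cache and decode it from there. If copying fails
+    # (e.g. a permission error), the cache root is switched to a scratch directory;
+    # the final attempt falls back to decoding directly from the original path.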
+ if hasattr(caching_loader, "cache_root"):
+ cache = caching_loader.cache_root
+
+ cached_path = cache + path
+
+ num_tries = 3
+ for curr_try in range(num_tries):
+ try:
+ if curr_try == 2:
+ return loader(path)
+ if not os.path.exists(cached_path) or curr_try > 0:
+ os.makedirs(os.path.dirname(cached_path), exist_ok=True)
+ copyfile(path, cached_path)
+ os.chmod(cached_path, 0o777)
+ return loader(cached_path)
+ except Exception as e:
+ logger.warning(str(e))
+ if "Errno 13" in str(e):
+ caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
+ logger.warning(f"setting cache root to {caching_loader.cache_root}")
+ cached_path = caching_loader.cache_root + path
+ if curr_try == (num_tries - 1):
+ raise
+ time.sleep(2)
+
+
+def caching_loader(cache_root: str, loader):
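+    # Returns `loader` unchanged when cache_root is None; otherwise wraps it so that
+    # files are first copied under cache_root (or $SLURM_TMPDIR for "slurm_tmpdir").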
+ if cache_root is None:
+ return loader
+
+ if cache_root == "slurm_tmpdir":
+ cache_root = os.environ["SLURM_TMPDIR"]
+ assert len(cache_root) > 0
+
+ if not cache_root.endswith("/"):
+ cache_root += "/"
+
+ return partial(load, loader=loader, cache=cache_root)
+
+
+class RandomResizedCropAndInterpolationWithTwoPic:
+ """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+
+    A crop of random size (default: 0.08 to 1.0 of the original area) and a random
+    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
+    is finally resized to the given size.
+    This is popularly used to train the Inception networks.
+
+    Args:
+        size: expected output size of each edge
+        second_size: optional second output size; when set, __call__ returns two crops
+        scale: range of the area fraction of the original image to crop
+        ratio: range of aspect ratios of the crop
+        interpolation: interpolation mode for the main output (default: bilinear)
+        second_interpolation: interpolation mode for the second output (default: lanczos)
+ """
+
+ def __init__(
+ self,
+ size,
+ second_size=None,
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation="bilinear",
+ second_interpolation="lanczos",
+ ):
+ if isinstance(size, tuple):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if second_size is not None:
+ if isinstance(second_size, tuple):
+ self.second_size = second_size
+ else:
+ self.second_size = (second_size, second_size)
+ else:
+ self.second_size = None
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ logger.warning("range should be of kind (min, max)")
+
+ if interpolation == "random":
+ from PIL import Image
+
+ self.interpolation = (Image.BILINEAR, Image.BICUBIC)
+ else:
+ self.interpolation = self._pil_interp(interpolation)
+
+ self.second_interpolation = (
+ self._pil_interp(second_interpolation)
+ if second_interpolation is not None
+ else None
+ )
+ self.scale = scale
+ self.ratio = ratio
+
+ def _pil_interp(self, method):
+ from PIL import Image
+
+ if method == "bicubic":
+ return Image.BICUBIC
+ elif method == "lanczos":
+ return Image.LANCZOS
+ elif method == "hamming":
+ return Image.HAMMING
+ else:
+ # default bilinear, do we want to allow nearest?
+ return Image.BILINEAR
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ area = img.size[0] * img.size[1]
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= img.size[0] and h <= img.size[1]:
+ i = random.randint(0, img.size[1] - h)
+ j = random.randint(0, img.size[0] - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img.size[0] / img.size[1]
+ if in_ratio < min(ratio):
+ w = img.size[0]
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = img.size[1]
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = img.size[0]
+ h = img.size[1]
+ i = (img.size[1] - h) // 2
+ j = (img.size[0] - w) // 2
+ return i, j, h, w
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+
+        Returns:
+            PIL Image: Randomly cropped and resized image.
+        """
+        import torchvision.transforms.functional as F
+
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolation = random.choice(self.interpolation)
+ else:
+ interpolation = self.interpolation
+ if self.second_size is None:
+ return F.resized_crop(img, i, j, h, w, self.size, interpolation)
+ else:
+ return F.resized_crop(
+ img, i, j, h, w, self.size, interpolation
+ ), F.resized_crop(
+ img, i, j, h, w, self.second_size, self.second_interpolation
+ )
+
+
+class MaeImageDataset(FairseqDataset):
+ def __init__(
+ self,
+ root: str,
+ split: str,
+ input_size,
+ local_cache_path=None,
+ shuffle=True,
+ key="imgs",
+ beit_transforms=False,
+ target_transform=False,
+ no_transform=False,
+ compute_mask=False,
+ patch_size: int = 16,
+ mask_prob: float = 0.75,
+ mask_prob_adjust: float = 0,
+ mask_length: int = 1,
+ inverse_mask: bool = False,
+ expand_adjacent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
+ require_same_masks: bool = True,
+ clone_batch: int = 1,
+ dataset_type: str = "imagefolder",
+ ):
+ FairseqDataset.__init__(self)
+
+ self.shuffle = shuffle
+ self.key = key
+
+ loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+ self.transform_source = None
+ self.transform_target = None
+
+ if target_transform:
+ self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
+ self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)
+
+ if no_transform:
+ if input_size <= 224:
+ crop_pct = 224 / 256
+ else:
+ crop_pct = 1.0
+ size = int(input_size / crop_pct)
+
+ self.transform_train = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=3),
+ transforms.CenterCrop(input_size),
+ ]
+ )
+
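+            # note: this plain resize overwrites the Resize+CenterCrop pipeline built just above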
+ self.transform_train = transforms.Resize((input_size, input_size))
+ elif beit_transforms:
+ beit_transform_list = []
+ if not target_transform:
+ beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
+ beit_transform_list.extend(
+ [
+ transforms.RandomHorizontalFlip(p=0.5),
+ RandomResizedCropAndInterpolationWithTwoPic(
+ size=input_size,
+ second_size=None,
+ interpolation="bicubic",
+ second_interpolation=None,
+ ),
+ ]
+ )
+ self.transform_train = transforms.Compose(beit_transform_list)
+ else:
+ self.transform_train = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ input_size, scale=(0.2, 1.0), interpolation=3
+ ), # 3 is bicubic
+ transforms.RandomHorizontalFlip(),
+ ]
+ )
+ self.final_transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
+ )
+
+ if dataset_type == "imagefolder":
+ self.dataset = datasets.ImageFolder(
+ os.path.join(root, split), loader=loader
+ )
+ elif dataset_type == "path":
+ self.dataset = PathDataset(
+ root,
+ loader,
+ None,
+ None,
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ )
+ else:
+ raise Exception(f"invalid dataset type {dataset_type}")
+
+ logger.info(
+ f"initial transform: {self.transform_train}, "
+ f"source transform: {self.transform_source}, "
+ f"target transform: {self.transform_target}, "
+ f"final transform: {self.final_transform}"
+ )
+ logger.info(f"loaded {len(self.dataset)} examples")
+
+ self.is_compute_mask = compute_mask
+ self.patches = (input_size // patch_size) ** 2
+ self.mask_prob = mask_prob
+ self.mask_prob_adjust = mask_prob_adjust
+ self.mask_length = mask_length
+ self.inverse_mask = inverse_mask
+ self.expand_adjacent = expand_adjacent
+ self.mask_dropout = mask_dropout
+ self.non_overlapping = non_overlapping
+ self.require_same_masks = require_same_masks
+ self.clone_batch = clone_batch
+
+ def __getitem__(self, index):
+ img, _ = self.dataset[index]
+
+ img = self.transform_train(img)
+
+ source = None
+ target = None
+ if self.transform_source is not None:
+ source = self.final_transform(self.transform_source(img))
+ if self.transform_target is not None:
+ target = self.final_transform(self.transform_target(img))
+
+ if source is None:
+ img = self.final_transform(img)
+
+ v = {"id": index, self.key: source if source is not None else img}
+ if target is not None:
+ v["target"] = target
+
+ if self.is_compute_mask:
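+            # precompute one block mask per clone: shape (clone_batch, num_patches); with
+            # require_same_masks every clone masks the same number of patches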
+ if self.mask_length == 1:
+ mask = compute_block_mask_1d(
+ shape=(self.clone_batch, self.patches),
+ mask_prob=self.mask_prob,
+ mask_length=self.mask_length,
+ mask_prob_adjust=self.mask_prob_adjust,
+ inverse_mask=self.inverse_mask,
+ require_same_masks=True,
+ )
+ else:
+ mask = compute_block_mask_2d(
+ shape=(self.clone_batch, self.patches),
+ mask_prob=self.mask_prob,
+ mask_length=self.mask_length,
+ mask_prob_adjust=self.mask_prob_adjust,
+ inverse_mask=self.inverse_mask,
+ require_same_masks=True,
+ expand_adjcent=self.expand_adjacent,
+ mask_dropout=self.mask_dropout,
+ non_overlapping=self.non_overlapping,
+ )
+
+ v["precomputed_mask"] = mask
+
+ return v
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+
+ collated_img = torch.stack([s[self.key] for s in samples], dim=0)
+
+ res = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": {
+ self.key: collated_img,
+ },
+ }
+
+ if "target" in samples[0]:
+ collated_target = torch.stack([s["target"] for s in samples], dim=0)
+ res["net_input"]["target"] = collated_target
+
+ if "precomputed_mask" in samples[0]:
+ collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
+ res["net_input"]["precomputed_mask"] = collated_mask
+
+ return res
+
+ def num_tokens(self, index):
+ return 1
+
+ def size(self, index):
+ return 1
+
+ @property
+ def sizes(self):
+ return np.full((len(self),), 1)
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ order = [np.random.permutation(len(self))]
+ else:
+ order = [np.arange(len(self))]
+
+ return order[0]
diff --git a/examples/data2vec/data/modality.py b/examples/data2vec/data/modality.py
new file mode 100644
index 0000000000..aa23ac94f7
--- /dev/null
+++ b/examples/data2vec/data/modality.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from enum import Enum, auto
+
+
+class Modality(Enum):
+ AUDIO = auto()
+ IMAGE = auto()
+ TEXT = auto()
diff --git a/examples/data2vec/data/path_dataset.py b/examples/data2vec/data/path_dataset.py
new file mode 100644
index 0000000000..02010058e6
--- /dev/null
+++ b/examples/data2vec/data/path_dataset.py
@@ -0,0 +1,64 @@
+import glob
+import os
+from typing import Callable, List, Optional, Tuple
+
+import logging
+import numpy as np
+import torchvision.transforms.functional as TF
+import PIL
+from PIL import Image
+from torchvision.datasets import VisionDataset
+
+logger = logging.getLogger(__name__)
+
+
+class PathDataset(VisionDataset):
+ def __init__(
+ self,
+ root: List[str],
+        loader: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        extra_transform: Optional[Callable] = None,
+ mean: Optional[List[float]] = None,
+ std: Optional[List[float]] = None,
+ ):
+ super().__init__(root=root)
+
+ PIL.Image.MAX_IMAGE_PIXELS = 256000001
+
+ self.files = []
+ for folder in self.root:
+ self.files.extend(
+ sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
+ )
+ self.files.extend(
+ sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
+ )
+
+ self.transform = transform
+ self.extra_transform = extra_transform
+ self.mean = mean
+ self.std = std
+
+ self.loader = loader
+
+ logger.info(f"loaded {len(self.files)} samples from {root}")
+
+ assert (mean is None) == (std is None)
+
+ def __len__(self) -> int:
+ return len(self.files)
+
+ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
+ path = self.files[idx]
+
+ if self.loader is not None:
+ return self.loader(path), None
+
+ img = Image.open(path).convert("RGB")
+ if self.transform is not None:
+ img = self.transform(img)
+ img = TF.to_tensor(img)
+ if self.mean is not None and self.std is not None:
+ img = TF.normalize(img, self.mean, self.std)
+ return img, None
diff --git a/examples/data2vec/fb_convert_beit_cp.py b/examples/data2vec/fb_convert_beit_cp.py
new file mode 100644
index 0000000000..cf42ace762
--- /dev/null
+++ b/examples/data2vec/fb_convert_beit_cp.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import torch
+
+from omegaconf import OmegaConf
+
+from fairseq.criterions.model_criterion import ModelCriterionConfig
+from fairseq.dataclass.configs import FairseqConfig
+
+from tasks import ImageClassificationConfig, ImagePretrainingConfig
+from models.data2vec_image_classification import (
+ Data2VecImageClassificationConfig,
+ Data2VecImageClassificationModel,
+)
+from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+        description="convert a BEiT checkpoint into a data2vec vision checkpoint"
+ )
+ # fmt: off
+ parser.add_argument('checkpoint', help='checkpoint to convert')
+ parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint')
+ parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade')
+ parser.add_argument('--inception_norms', action='store_true', default=False)
+ # fmt: on
+
+ return parser
+
+
+def update_checkpoint(model_dict, prefix, is_nested):
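+    # Remap BEiT-style parameter names to the names used by the data2vec vision models,
+    # optionally prepending `prefix` (e.g. "model.encoder.") when the encoder is nested
+    # inside a classification wrapper.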
+
+ replace_paths = {
+ "cls_token": "model.cls_emb" if is_nested else "cls_emb",
+ "patch_embed": "model.patch_embed" if is_nested else "patch_embed",
+ "mask_token": "mask_emb",
+ }
+
+ starts_with = {
+ "patch_embed.proj": "model.patch_embed.conv"
+ if is_nested
+ else "patch_embed.conv",
+ "lm_head": "final_proj",
+ "fc_norm": "fc_norm",
+ "head": "head",
+ }
+
+ partial = {
+ "mlp.fc1": "mlp.0",
+ "mlp.fc2": "mlp.2",
+ }
+
+ for k in list(model_dict.keys()):
+ for sw, r in starts_with.items():
+ if k.startswith(sw):
+ replace_paths[k] = k.replace(sw, r)
+ for p, r in partial.items():
+ if p in k:
+ replace_paths[k] = prefix + k.replace(p, r)
+
+ if prefix != "":
+ for k in list(model_dict.keys()):
+ if k not in replace_paths:
+ replace_paths[k] = prefix + k
+
+ for k in list(model_dict.keys()):
+ if k in replace_paths:
+ model_dict[replace_paths[k]] = model_dict[k]
+ if k != replace_paths[k]:
+ del model_dict[k]
+
+ return model_dict
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ cp = torch.load(args.checkpoint, map_location="cpu")
+
+ cfg = FairseqConfig(
+ criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
+ )
+
+ if args.type == "image_classification":
+
+ cfg.task = ImageClassificationConfig(
+ _name="image_classification",
+ data=".",
+ )
+
+ if args.inception_norms:
+ cfg.task.normalization_mean = [0.5, 0.5, 0.5]
+ cfg.task.normalization_std = [0.5, 0.5, 0.5]
+
+ cfg.model = Data2VecImageClassificationConfig(
+ _name="data2vec_image_classification",
+ )
+ cfg.model.pretrained_model_args = FairseqConfig(
+ model=Data2VecVisionConfig(
+ _name="data2vec_vision", shared_rel_pos_bias=False
+ ),
+ task=ImagePretrainingConfig(
+ _name="image_pretraining",
+ ),
+ )
+
+ cfg = OmegaConf.create(cfg)
+
+ state = {
+ "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
+ "model": cp["module"],
+ "best_loss": None,
+ "optimizer": None,
+ "extra_state": {},
+ }
+
+ model = Data2VecImageClassificationModel(cfg.model)
+ model.load_state_dict(
+ update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True),
+ strict=True,
+ )
+ elif args.type == "vision":
+ cfg.task = ImagePretrainingConfig(
+ _name="image_pretraining",
+ data=".",
+ )
+
+ if args.inception_norms:
+ cfg.task.normalization_mean = [0.5, 0.5, 0.5]
+ cfg.task.normalization_std = [0.5, 0.5, 0.5]
+
+ cfg.model = Data2VecVisionConfig(
+ _name="data2vec_vision",
+ )
+ cfg = OmegaConf.create(cfg)
+
+ state = {
+ "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
+ "model": cp["model"],
+ "best_loss": None,
+ "optimizer": None,
+ "extra_state": {},
+ }
+
+ model = Data2VecVisionModel(cfg.model)
+ model.load_state_dict(
+ update_checkpoint(state["model"], prefix="encoder.", is_nested=False),
+ strict=True,
+ )
+ else:
+ raise Exception("unsupported type " + args.type)
+
+ print(state["cfg"], state.keys())
+ torch.save(state, args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/data2vec/models/__init__.py b/examples/data2vec/models/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/data2vec/models/audio_classification.py b/examples/data2vec/models/audio_classification.py
new file mode 100644
index 0000000000..06d2158267
--- /dev/null
+++ b/examples/data2vec/models/audio_classification.py
@@ -0,0 +1,614 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from omegaconf import II, MISSING, open_dict
+
+from fairseq import checkpoint_utils, tasks
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.models import (
+ BaseFairseqModel,
+ register_model,
+)
+from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
+from fairseq.modules import TransposeLast
+from fairseq.tasks import FairseqTask
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AudioClassificationConfig(FairseqDataclass):
+ model_path: str = field(
+ default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
+ )
+ no_pretrained_weights: bool = field(
+ default=False, metadata={"help": "if true, does not load pretrained weights"}
+ )
+ dropout_input: float = field(
+ default=0.0,
+ metadata={"help": "dropout to apply to the input (after feat extr)"},
+ )
+ final_dropout: float = field(
+ default=0.0,
+ metadata={"help": "dropout after transformer and before final projection"},
+ )
+ dropout: float = field(
+ default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
+ )
+ attention_dropout: float = field(
+ default=0.0,
+ metadata={
+ "help": "dropout probability for attention weights inside wav2vec 2.0 model"
+ },
+ )
+ activation_dropout: float = field(
+ default=0.0,
+ metadata={
+ "help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
+ },
+ )
+
+ # masking
+ apply_mask: bool = field(
+ default=False, metadata={"help": "apply masking during fine-tuning"}
+ )
+ mask_length: int = field(
+ default=10, metadata={"help": "repeat the mask indices multiple times"}
+ )
+ mask_prob: float = field(
+ default=0.5,
+ metadata={
+ "help": "probability of replacing a token with mask (normalized by length)"
+ },
+ )
+ mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
+ default="static", metadata={"help": "how to choose masks"}
+ )
+ mask_other: float = field(
+ default=0,
+ metadata={
+ "help": "secondary mask argument (used for more complex distributions), "
+ "see help in compute_mask_indices"
+ },
+ )
+ no_mask_overlap: bool = field(
+ default=False, metadata={"help": "whether to allow masks to overlap"}
+ )
+ mask_min_space: Optional[int] = field(
+ default=1,
+ metadata={"help": "min space between spans (if no overlap is enabled)"},
+ )
+ require_same_masks: bool = field(
+ default=True,
+ metadata={
+ "help": "whether to number of masked timesteps must be the same across all "
+ "examples in a batch"
+ },
+ )
+ mask_dropout: float = field(
+ default=0.0,
+ metadata={"help": "percent of masks to unmask for each sample"},
+ )
+
+ # channel masking
+ mask_channel_length: int = field(
+ default=10, metadata={"help": "length of the mask for features (channels)"}
+ )
+ mask_channel_prob: float = field(
+ default=0.0, metadata={"help": "probability of replacing a feature with 0"}
+ )
+ mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
+ default="static",
+ metadata={"help": "how to choose mask length for channel masking"},
+ )
+ mask_channel_other: float = field(
+ default=0,
+ metadata={
+ "help": "secondary mask argument (used for more complex distributions), "
+ "see help in compute_mask_indicesh"
+ },
+ )
+ no_mask_channel_overlap: bool = field(
+ default=False, metadata={"help": "whether to allow channel masks to overlap"}
+ )
+ freeze_finetune_updates: int = field(
+ default=0, metadata={"help": "dont finetune wav2vec for this many updates"}
+ )
+ feature_grad_mult: float = field(
+ default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
+ )
+ layerdrop: float = field(
+ default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
+ )
+ mask_channel_min_space: Optional[int] = field(
+ default=1,
+ metadata={"help": "min space between spans (if no overlap is enabled)"},
+ )
+ mask_channel_before: bool = False
+ normalize: bool = II("task.normalize")
+ data: str = II("task.data")
+ # this holds the loaded wav2vec args
+ d2v_args: Any = None
+ offload_activations: bool = field(
+ default=False, metadata={"help": "offload_activations"}
+ )
+ min_params_to_wrap: int = field(
+ default=int(1e8),
+ metadata={
+ "help": "minimum number of params for a layer to be wrapped with FSDP() when "
+ "training with --ddp-backend=fully_sharded. Smaller values will "
+ "improve memory efficiency, but may make torch.distributed "
+ "communication less efficient due to smaller input sizes. This option "
+ "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
+ "--offload-activations are passed."
+ },
+ )
+
+ checkpoint_activations: bool = field(
+ default=False,
+ metadata={"help": "recompute activations and save memory for extra compute"},
+ )
+ ddp_backend: str = II("distributed_training.ddp_backend")
+
+ prediction_mode: str = "lin_softmax"
+ eval_prediction_mode: Optional[str] = None
+ conv_kernel: int = -1
+ conv_stride: int = 1
+ two_convs: bool = False
+ extreme_factor: float = 1.0
+
+ conv_feature_layers: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "string describing convolutional feature extraction layers in form of a python list that contains "
+ "[(dim, kernel_size, stride), ...]"
+ },
+ )
+
+ mixup_prob: float = 1.0
+ source_mixup: float = -1
+ same_mixup: bool = True
+ label_mixup: bool = False
+
+ gain_mode: str = "none"
+
+
+@register_model("audio_classification", dataclass=AudioClassificationConfig)
+class AudioClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: AudioClassificationConfig, num_classes):
+ super().__init__()
+
+ self.apply_mask = cfg.apply_mask
+ self.cfg = cfg
+
+ arg_overrides = {
+ "dropout": cfg.dropout,
+ "activation_dropout": cfg.activation_dropout,
+ "dropout_input": cfg.dropout_input,
+ "attention_dropout": cfg.attention_dropout,
+ "mask_length": cfg.mask_length,
+ "mask_prob": cfg.mask_prob,
+ "require_same_masks": getattr(cfg, "require_same_masks", True),
+ "mask_dropout": getattr(cfg, "mask_dropout", 0),
+ "mask_selection": cfg.mask_selection,
+ "mask_other": cfg.mask_other,
+ "no_mask_overlap": cfg.no_mask_overlap,
+ "mask_channel_length": cfg.mask_channel_length,
+ "mask_channel_prob": cfg.mask_channel_prob,
+ "mask_channel_before": cfg.mask_channel_before,
+ "mask_channel_selection": cfg.mask_channel_selection,
+ "mask_channel_other": cfg.mask_channel_other,
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+ "encoder_layerdrop": cfg.layerdrop,
+ "feature_grad_mult": cfg.feature_grad_mult,
+ "checkpoint_activations": cfg.checkpoint_activations,
+ "offload_activations": cfg.offload_activations,
+ "min_params_to_wrap": cfg.min_params_to_wrap,
+ "mixup": -1,
+ }
+
+ if cfg.conv_feature_layers is not None:
+ arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
+
+ if cfg.d2v_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(
+ cfg.model_path, arg_overrides
+ )
+ d2v_args = state.get("cfg", None)
+ if d2v_args is None:
+ d2v_args = convert_namespace_to_omegaconf(state["args"])
+ d2v_args.criterion = None
+ d2v_args.lr_scheduler = None
+ cfg.d2v_args = d2v_args
+
+ logger.info(d2v_args)
+
+ else:
+ state = None
+ d2v_args = cfg.d2v_args
+
+ model_normalized = d2v_args.task.get(
+ "normalize", d2v_args.model.get("normalize", False)
+ )
+ assert cfg.normalize == model_normalized, (
+ "Fine-tuning works best when data normalization is the same. "
+ "Please check that --normalize is set or unset for both pre-training and here"
+ )
+
+ if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
+ with open_dict(d2v_args):
+ d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
+
+ d2v_args.task.data = cfg.data
+ task = tasks.setup_task(d2v_args.task)
+ model = task.build_model(d2v_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+
+ if state is not None and not cfg.no_pretrained_weights:
+ self.load_model_weights(state, model, cfg)
+
+ d = d2v_args.model.encoder_embed_dim
+
+ self.d2v_model = model
+
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
+ self.num_updates = 0
+
+ for p in self.parameters():
+ p.param_group = "pretrained"
+
+ if cfg.prediction_mode == "proj_avg_proj":
+ self.proj = nn.Linear(d, d * 2)
+ self.proj2 = nn.Linear(d * 2, num_classes)
+
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ for p in self.proj2.parameters():
+ p.param_group = "projection"
+ elif self.cfg.prediction_mode == "summary_proj":
+ self.proj = nn.Linear(d // 3, num_classes)
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
+            self.proj = nn.Sequential(
+                TransposeLast(),
+                nn.Conv1d(
+                    d,
+                    num_classes,
+                    kernel_size=self.cfg.conv_kernel,
+                    stride=self.cfg.conv_stride,
+                ),
+                TransposeLast(),
+            )
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
+            self.proj = nn.Sequential(
+                TransposeLast(),
+                nn.Conv1d(
+                    d,
+                    d,
+                    kernel_size=self.cfg.conv_kernel,
+                    stride=self.cfg.conv_stride,
+                ),
+                TransposeLast(),
+                nn.GELU(),
+                nn.Linear(d, num_classes),
+            )
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ else:
+ self.proj = nn.Linear(d, num_classes)
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ @classmethod
+ def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
+ """Build a new model instance."""
+
+ assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
+
+ return cls(cfg, len(task.labels))
+
+ def load_model_weights(self, state, model, cfg):
+ if cfg.ddp_backend == "fully_sharded":
+ from fairseq.distributed import FullyShardedDataParallel
+
+ for name, module in model.named_modules():
+ if "encoder.layers" in name and len(name.split(".")) == 3:
+                    # Only for the encoder layers we do special handling and load the weights one by one.
+                    # We don't load all weights together, as that won't be memory efficient and may
+                    # cause OOM
+ new_dict = {
+ k.replace(name + ".", ""): v
+ for (k, v) in state["model"].items()
+ if name + "." in k
+ }
+ assert isinstance(module, FullyShardedDataParallel)
+ with module.summon_full_params():
+ module.load_state_dict(new_dict, strict=True)
+ module._reset_lazy_init()
+
+ # Once layers are loaded, filter them out and load everything else.
+            r = re.compile(r"encoder\.layers\.\d+\.")
+ filtered_list = list(filter(r.match, state["model"].keys()))
+
+ new_big_dict = {
+ k: v for (k, v) in state["model"].items() if k not in filtered_list
+ }
+
+ model.load_state_dict(new_big_dict, strict=False)
+ else:
+ if "_ema" in state["model"]:
+ del state["model"]["_ema"]
+ model.load_state_dict(state["model"], strict=False)
+
+ def set_num_updates(self, num_updates):
+ """Set the number of parameters updates."""
+ super().set_num_updates(num_updates)
+ self.num_updates = num_updates
+
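+    # Gain estimation used to balance mixup between clips: either a simple RMS
+    # energy ("RMSE") or an A-weighted power estimate ("A_weighting") computed
+    # over sliding FFT windows, returned in dB per window.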
+ def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
+ if fs == 16000:
+ n_fft = 2048
+ elif fs == 44100:
+ n_fft = 4096
+ else:
+ raise Exception("Invalid fs {}".format(fs))
+ stride = n_fft // 2
+
+ def a_weight(fs, n_fft, min_db=-80.0):
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
+ freq_sq = np.power(freq, 2)
+ freq_sq[0] = 1.0
+ weight = 2.0 + 20.0 * (
+ 2 * np.log10(12194)
+ + 2 * np.log10(freq_sq)
+ - np.log10(freq_sq + 12194 ** 2)
+ - np.log10(freq_sq + 20.6 ** 2)
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2)
+ )
+ weight = np.maximum(weight, min_db)
+
+ return weight
+
+ gain = []
+ for i in range(0, len(sound) - n_fft + 1, stride):
+ if mode == "RMSE":
+ g = np.mean(sound[i : i + n_fft] ** 2)
+ elif mode == "A_weighting":
+ spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
+ power_spec = np.abs(spec) ** 2
+ a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
+ g = np.sum(a_weighted_spec)
+ else:
+ raise Exception("Invalid mode {}".format(mode))
+ gain.append(g)
+
+ gain = np.array(gain)
+ gain = np.maximum(gain, np.power(10, min_db / 10))
+ gain_db = 10 * np.log10(gain)
+
+ return gain_db
+
+ # adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
+ def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
+ if fs == 16000:
+ n_fft = 2048
+ elif fs == 44100:
+ n_fft = 4096
+ else:
+ raise Exception("Invalid fs {}".format(fs))
+
+ if mode == "A_weighting":
+ if not hasattr(self, f"a_weight"):
+ self.a_weight = {}
+
+ if fs not in self.a_weight:
+
+ def a_weight(fs, n_fft, min_db=-80.0):
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
+ freq_sq = freq ** 2
+ freq_sq[0] = 1.0
+ weight = 2.0 + 20.0 * (
+ 2 * np.log10(12194)
+ + 2 * np.log10(freq_sq)
+ - np.log10(freq_sq + 12194 ** 2)
+ - np.log10(freq_sq + 20.6 ** 2)
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2)
+ )
+ weight = np.maximum(weight, min_db)
+
+ return weight
+
+ self.a_weight[fs] = torch.from_numpy(
+ np.power(10, a_weight(fs, n_fft, min_db) / 10)
+ ).to(device=sound.device)
+
+ sound = sound.unfold(-1, n_fft, n_fft // 2)
+
+ if mode == "RMSE":
+ sound = sound ** 2
+ g = sound.mean(-1)
+ elif mode == "A_weighting":
+ w = torch.hann_window(n_fft, device=sound.device) * sound
+ spec = torch.fft.rfft(w)
+ power_spec = spec.abs() ** 2
+ a_weighted_spec = power_spec * self.a_weight[fs]
+ g = a_weighted_spec.sum(-1)
+ else:
+ raise Exception("Invalid mode {}".format(mode))
+
+ gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
+ gain_db = 10 * torch.log10(gain)
+
+ return gain_db
+
+ def forward(self, source, padding_mask, label=None, **kwargs):
+
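+        # optional waveform-level mixup during training: with probability mixup_prob
+        # each example is mixed with another random example from the batch, using a
+        # mixing weight r drawn once per batch (same_mixup) or per example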
+ if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
+ with torch.no_grad():
+ mixed_source = source
+ mix_mask = None
+ if self.cfg.mixup_prob < 1:
+ mix_mask = (
+ torch.empty((source.size(0),), device=source.device)
+ .bernoulli_(self.cfg.mixup_prob)
+ .bool()
+ )
+ mixed_source = source[mix_mask]
+
+ r = (
+ torch.FloatTensor(
+ 1 if self.cfg.same_mixup else mixed_source.size(0)
+ )
+ .uniform_(max(1e-6, self.cfg.source_mixup), 1)
+ .to(dtype=source.dtype, device=source.device)
+ )
+
+ mixup_perm = torch.randperm(source.size(0))
+ s2 = source[mixup_perm]
+
+ if self.cfg.gain_mode == "none":
+ p = r.unsqueeze(-1)
+ if mix_mask is not None:
+ s2 = s2[mix_mask]
+ else:
+ if self.cfg.gain_mode == "naive_rms":
+ G1 = source.pow(2).mean(dim=-1).sqrt()
+ else:
+ G1, _ = self.compute_gain_torch(
+ source, mode=self.cfg.gain_mode
+ ).max(-1)
+ G1 = G1.to(dtype=source.dtype)
+
+ G2 = G1[mixup_perm]
+
+ if mix_mask is not None:
+ G1 = G1[mix_mask]
+ G2 = G2[mix_mask]
+ s2 = s2[mix_mask]
+
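+                    # gain-aware mixing coefficient (BC-learning style): compensate
+                    # for the level difference G1 - G2 (in dB) so that the effective
+                    # mixing ratio follows r rather than the raw amplitudes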
+ p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
+ p = p.unsqueeze(-1)
+
+ mixed = (p * mixed_source) + (1 - p) * s2
+
+ if mix_mask is None:
+ source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
+ else:
+ source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
+
+ if label is not None and self.cfg.label_mixup:
+ r = r.unsqueeze(-1)
+ if mix_mask is None:
+ label = label * r + (1 - r) * label[mixup_perm]
+ else:
+ label[mix_mask] = (
+ label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
+ )
+
+ d2v_args = {
+ "source": source,
+ "padding_mask": padding_mask,
+ "mask": self.apply_mask and self.training,
+ }
+
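+        # keep the pretrained encoder frozen (no gradients) for the first
+        # freeze_finetune_updates updates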
+ ft = self.freeze_finetune_updates <= self.num_updates
+
+ with torch.no_grad() if not ft else contextlib.ExitStack():
+ res = self.d2v_model.extract_features(**d2v_args)
+
+ x = res["x"]
+ padding_mask = res["padding_mask"]
+ if padding_mask is not None:
+ x[padding_mask] = 0
+
+ x = self.final_dropout(x)
+
+ if self.training or (
+ self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
+ ):
+ prediction_mode = self.cfg.prediction_mode
+ else:
+ prediction_mode = self.cfg.eval_prediction_mode
+
+ if prediction_mode == "average_before":
+ x = x.mean(dim=1)
+
+ if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls":
+ x = self.proj(x)
+
+ logits = True
+ if prediction_mode == "lin_softmax":
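+            # linear-softmax pooling over time, done in log space for stability:
+            # with per-frame probabilities p_t = sigmoid(x_t), this computes
+            # log(sum_t p_t^2 / sum_t p_t) and converts the pooled probability
+            # back to a logit via log(p) - log(1 - p)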
+ x = F.logsigmoid(x.float())
+ x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
+ x = x.clamp(max=0)
+ x = x - torch.log(-(torch.expm1(x)))
+ elif prediction_mode == "extremized_odds":
+ x = x.float().sum(dim=1)
+ x = x * self.cfg.extreme_factor
+ elif prediction_mode == "average_before":
+ x = x.float()
+ elif prediction_mode == "average":
+ x = x.float().mean(dim=1)
+ elif prediction_mode == "average_sigmoid":
+ x = torch.sigmoid(x.float())
+ x = x.mean(dim=1)
+ logits = False
+ elif prediction_mode == "max":
+ x, _ = x.float().max(dim=1)
+ elif prediction_mode == "max_sigmoid":
+ x = torch.sigmoid(x.float())
+ x, _ = x.float().max(dim=1)
+ logits = False
+ elif prediction_mode == "proj_avg_proj":
+ x = x.mean(dim=1)
+ x = self.proj2(x)
+ elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj":
+ x = self.d2v_model.summary(
+ x, padding_mask, proj=prediction_mode == "summary_proj"
+ )
+ x = x.type_as(source)
+ x = self.proj(x)
+ elif prediction_mode == "cls":
+            x = x[:, 0]
+ x = self.proj(x)
+ else:
+ raise Exception(f"unknown prediction mode {prediction_mode}")
+
+ if label is None:
+ return torch.sigmoid(x) if logits else x
+
+ x = torch.nan_to_num(x)
+
+ if logits:
+ loss = F.binary_cross_entropy_with_logits(
+ x, label.float(), reduction="none"
+ )
+ else:
+ loss = F.binary_cross_entropy(x, label.float(), reduction="none")
+
+ result = {
+ "losses": {
+ "main": loss,
+ },
+ "sample_size": label.sum(),
+ }
+
+ if not self.training:
+ result["_predictions"] = torch.sigmoid(x) if logits else x
+ result["_targets"] = label
+
+ return result
diff --git a/examples/data2vec/models/data2vec2.py b/examples/data2vec/models/data2vec2.py
new file mode 100644
index 0000000000..0c61b37081
--- /dev/null
+++ b/examples/data2vec/models/data2vec2.py
@@ -0,0 +1,813 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+from dataclasses import dataclass, field
+from typing import Optional, Callable
+from functools import partial
+import numpy as np
+
+from omegaconf import II
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from fairseq.modules import EMAModule, EMAModuleConfig
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+
+from examples.data2vec.data.modality import Modality
+
+from examples.data2vec.models.modalities.base import (
+ MaskSeed,
+ D2vModalityConfig,
+ ModalitySpecificEncoder,
+ get_annealed_rate,
+)
+from examples.data2vec.models.modalities.modules import (
+ D2vDecoderConfig,
+ AltBlock,
+ Decoder1d,
+)
+
+from examples.data2vec.models.modalities.audio import (
+ D2vAudioConfig,
+ AudioEncoder,
+)
+from examples.data2vec.models.modalities.images import (
+ D2vImageConfig,
+ ImageEncoder,
+)
+from examples.data2vec.models.modalities.text import (
+ D2vTextConfig,
+ TextEncoder,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class D2vModalitiesConfig(FairseqDataclass):
+ audio: D2vAudioConfig = D2vAudioConfig()
+ image: D2vImageConfig = D2vImageConfig()
+ text: D2vTextConfig = D2vTextConfig()
+
+
+@dataclass
+class Data2VecMultiConfig(FairseqDataclass):
+
+ loss_beta: float = field(
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+ )
+ loss_scale: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+ },
+ )
+
+ depth: int = 8
+ start_drop_path_rate: float = 0
+ end_drop_path_rate: float = 0
+ num_heads: int = 12
+ norm_eps: float = 1e-6
+ norm_affine: bool = True
+ encoder_dropout: float = 0.1
+ post_mlp_drop: float = 0.1
+ attention_dropout: float = 0.1
+ activation_dropout: float = 0.0
+ dropout_input: float = 0.0
+ layerdrop: float = 0.0
+ embed_dim: int = 768
+ mlp_ratio: float = 4
+ layer_norm_first: bool = False
+
+ average_top_k_layers: int = field(
+ default=8, metadata={"help": "how many layers to average"}
+ )
+
+ end_of_block_targets: bool = False
+
+ clone_batch: int = 1
+
+ layer_norm_target_layer: bool = False
+ batch_norm_target_layer: bool = False
+ instance_norm_target_layer: bool = False
+ instance_norm_targets: bool = False
+ layer_norm_targets: bool = False
+
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+ ema_same_dtype: bool = True
+ log_norms: bool = True
+ ema_end_decay: float = field(
+ default=0.9999, metadata={"help": "final ema decay rate"}
+ )
+
+ # when to finish annealing ema decay rate
+ ema_anneal_end_step: int = II("optimization.max_update")
+
+ ema_encoder_only: bool = field(
+ default=True,
+ metadata={
+ "help": "whether to momentum update only the shared transformer encoder"
+ },
+ )
+
+ max_update: int = II("optimization.max_update")
+
+ modalities: D2vModalitiesConfig = D2vModalitiesConfig()
+
+ shared_decoder: Optional[D2vDecoderConfig] = None
+
+ min_target_var: float = field(
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
+ )
+ min_pred_var: float = field(
+ default=0.01,
+ metadata={"help": "stop training if prediction var falls below this"},
+ )
+
+ supported_modality: Optional[Modality] = None
+ mae_init: bool = False
+
+ seed: int = II("common.seed")
+
+ skip_ema: bool = False
+
+ cls_loss: float = 0
+ recon_loss: float = 0
+ d2v_loss: float = 1
+
+ decoder_group: bool = False
+
+
+@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
+class Data2VecMultiModel(BaseFairseqModel):
+ def make_modality_encoder(
+ self,
+ cfg: D2vModalityConfig,
+ embed_dim: int,
+ make_block: Callable[[float], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases,
+ task,
+ ) -> ModalitySpecificEncoder:
+ if cfg.type == Modality.AUDIO:
+ enc_cls = AudioEncoder
+ elif cfg.type == Modality.IMAGE:
+ enc_cls = ImageEncoder
+ elif cfg.type == Modality.TEXT:
+ enc_cls = TextEncoder
+ if hasattr(task, "text_task"):
+ task = task.text_task
+ else:
+ raise Exception(f"unsupported modality {cfg.type}")
+
+ return enc_cls(
+ cfg,
+ embed_dim,
+ make_block,
+ norm_layer,
+ layer_norm_first,
+ alibi_biases,
+ task,
+ )
+
+ def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
+ super().__init__()
+ self.cfg = cfg
+ self.modalities = modalities
+ self.task = task
+
+ make_layer_norm = partial(
+ nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
+ )
+
+ def make_block(drop_path, dim=None, heads=None):
+ return AltBlock(
+ cfg.embed_dim if dim is None else dim,
+ cfg.num_heads if heads is None else heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ drop=cfg.encoder_dropout,
+ attn_drop=cfg.attention_dropout,
+ mlp_drop=cfg.activation_dropout,
+ post_mlp_drop=cfg.post_mlp_drop,
+ drop_path=drop_path,
+ norm_layer=make_layer_norm,
+ layer_norm_first=cfg.layer_norm_first,
+ ffn_targets=not cfg.end_of_block_targets,
+ )
+
+ self.alibi_biases = {}
+ self.modality_encoders = nn.ModuleDict()
+ for mod in self.modalities:
+ mod_cfg = getattr(cfg.modalities, mod.name.lower())
+ enc = self.make_modality_encoder(
+ mod_cfg,
+ cfg.embed_dim,
+ make_block,
+ make_layer_norm,
+ cfg.layer_norm_first,
+ self.alibi_biases,
+ task,
+ )
+ self.modality_encoders[mod.name] = enc
+
+ self.ema = None
+
+ self.average_top_k_layers = cfg.average_top_k_layers
+ self.loss_beta = cfg.loss_beta
+ self.loss_scale = cfg.loss_scale
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+ dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
+
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
+
+ self.norm = None
+ if cfg.layer_norm_first:
+ self.norm = make_layer_norm(cfg.embed_dim)
+
+ if self.cfg.mae_init:
+ self.apply(self._init_weights)
+ else:
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ self.apply(init_bert_params)
+
+ for mod_enc in self.modality_encoders.values():
+ mod_enc.reset_parameters()
+
+ if not skip_ema:
+ self.ema = self.make_ema_teacher(cfg.ema_decay)
+ self.shared_decoder = (
+ Decoder1d(cfg.shared_decoder, cfg.embed_dim)
+ if self.cfg.shared_decoder is not None
+ else None
+ )
+ if self.shared_decoder is not None:
+ self.shared_decoder.apply(self._init_weights)
+
+ self.recon_proj = None
+ if cfg.recon_loss > 0:
+ self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
+
+ for pn, p in self.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+ if cfg.decoder_group and "decoder" in pn:
+ p.param_group = "decoder"
+
+ self.num_updates = 0
+
+ def _init_weights(self, m):
+
+ try:
+ from apex.normalization import FusedLayerNorm
+
+ fn = FusedLayerNorm
+        except ImportError:
+ fn = nn.LayerNorm
+
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.no_grad()
+ def make_ema_teacher(self, ema_decay):
+ ema_config = EMAModuleConfig(
+ ema_decay=ema_decay,
+ ema_fp32=True,
+ log_norms=self.cfg.log_norms,
+ add_missing_params=False,
+ )
+
+ model_copy = self.make_target_model()
+
+ return EMAModule(
+ model_copy,
+ ema_config,
+ copy_model=False,
+ )
+
+ def make_target_model(self):
+ logger.info("making target model")
+
+ model_copy = Data2VecMultiModel(
+ self.cfg, self.modalities, skip_ema=True, task=self.task
+ )
+
+ if self.cfg.ema_encoder_only:
+ model_copy = model_copy.blocks
+ for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
+ p_t.data.copy_(p_s.data)
+ else:
+ for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
+ p_t.data.copy_(p_s.data)
+
+ for mod_enc in model_copy.modality_encoders.values():
+ mod_enc.decoder = None
+ if not mod_enc.modality_cfg.ema_local_encoder:
+ mod_enc.local_encoder = None
+ mod_enc.project_features = None
+
+ model_copy.requires_grad_(False)
+ return model_copy
+
+ def set_num_updates(self, num_updates):
+ super().set_num_updates(num_updates)
+
+ if self.ema is not None and (
+ (self.num_updates == 0 and num_updates > 1)
+ or self.num_updates >= num_updates
+ ):
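+            # skip the EMA update when the update counter jumps (e.g. after loading
+            # a checkpoint) or does not advance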
+ pass
+ elif self.training and self.ema is not None:
+ ema_weight_decay = None
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
+ if num_updates >= self.cfg.ema_anneal_end_step:
+ decay = self.cfg.ema_end_decay
+ else:
+ decay = get_annealed_rate(
+ self.cfg.ema_decay,
+ self.cfg.ema_end_decay,
+ num_updates,
+ self.cfg.ema_anneal_end_step,
+ )
+ self.ema.set_decay(decay, weight_decay=ema_weight_decay)
+ if self.ema.get_decay() < 1:
+ self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
+
+ self.num_updates = num_updates
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ state = super().state_dict(destination, prefix, keep_vars)
+
+ if self.ema is not None:
+ state[prefix + "_ema"] = self.ema.fp32_params
+
+ return state
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ k = prefix + "_ema"
+ if self.ema is not None:
+ assert k in state_dict
+ self.ema.restore(state_dict[k], True)
+ del state_dict[k]
+ elif k in state_dict:
+ del state_dict[k]
+
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecMultiConfig, task=None):
+ """Build a new model instance."""
+ if task is None or not hasattr(task, "supported_modalities"):
+ modalities = (
+ [cfg.supported_modality]
+ if cfg.supported_modality is not None
+ else [
+ Modality.AUDIO,
+ Modality.IMAGE,
+ Modality.TEXT,
+ ]
+ )
+ else:
+ modalities = task.supported_modalities
+ return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
+
+ def forward(
+ self,
+ source,
+ target=None,
+ id=None,
+ mode=None,
+ padding_mask=None,
+ mask=True,
+ features_only=False,
+ force_remove_masked=False,
+ remove_extra_tokens=True,
+ precomputed_mask=None,
+ ):
+ if mode is None:
+ assert self.cfg.supported_modality is not None
+ mode = self.cfg.supported_modality
+
+ if isinstance(mode, Modality):
+ mode = mode.name
+
+ feature_extractor = self.modality_encoders[mode]
+
+ mask_seeds = None
+ if id is not None:
+ mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
+
+ extractor_out = feature_extractor(
+ source,
+ padding_mask,
+ mask,
+ remove_masked=not features_only or force_remove_masked,
+ clone_batch=self.cfg.clone_batch if not features_only else 1,
+ mask_seeds=mask_seeds,
+ precomputed_mask=precomputed_mask,
+ )
+
+ x = extractor_out["x"]
+ encoder_mask = extractor_out["encoder_mask"]
+ masked_padding_mask = extractor_out["padding_mask"]
+ masked_alibi_bias = extractor_out.get("alibi_bias", None)
+ alibi_scale = extractor_out.get("alibi_scale", None)
+
+ if self.dropout_input is not None:
+ x = self.dropout_input(x)
+
+ layer_results = []
+ for i, blk in enumerate(self.blocks):
+ if (
+ not self.training
+ or self.cfg.layerdrop == 0
+ or (np.random.random() > self.cfg.layerdrop)
+ ):
+ ab = masked_alibi_bias
+ if ab is not None and alibi_scale is not None:
+ scale = (
+ alibi_scale[i]
+ if alibi_scale.size(0) > 1
+ else alibi_scale.squeeze(0)
+ )
+ ab = ab * scale.type_as(ab)
+
+ x, lr = blk(
+ x,
+ padding_mask=masked_padding_mask,
+ alibi_bias=ab,
+ )
+ if features_only:
+ layer_results.append(lr)
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ if features_only:
+ if remove_extra_tokens:
+ x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
+ if masked_padding_mask is not None:
+ masked_padding_mask = masked_padding_mask[
+ :, feature_extractor.modality_cfg.num_extra_tokens :
+ ]
+
+ return {
+ "x": x,
+ "padding_mask": masked_padding_mask,
+ "layer_results": layer_results,
+ "mask": encoder_mask,
+ }
+
+ xs = []
+
+ if self.shared_decoder is not None:
+ dx = self.forward_decoder(
+ x,
+ feature_extractor,
+ self.shared_decoder,
+ encoder_mask,
+ )
+ xs.append(dx)
+ if feature_extractor.decoder is not None:
+ dx = self.forward_decoder(
+ x,
+ feature_extractor,
+ feature_extractor.decoder,
+ encoder_mask,
+ )
+ xs.append(dx)
+ orig_x = x
+
+ assert len(xs) > 0
+
+ p = next(self.ema.model.parameters())
+ device = x.device
+ dtype = x.dtype
+ ema_device = p.device
+ ema_dtype = p.dtype
+
+ if not self.cfg.ema_same_dtype:
+ dtype = ema_dtype
+
+ if ema_device != device or ema_dtype != dtype:
+ logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
+ self.ema.model = self.ema.model.to(dtype=dtype, device=device)
+ ema_dtype = dtype
+
+ def to_device(d):
+ for k, p in d.items():
+ if isinstance(d[k], dict):
+ to_device(d[k])
+ else:
+ d[k] = p.to(device=device)
+
+ to_device(self.ema.fp32_params)
+ tm = self.ema.model
+
+ with torch.no_grad():
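+            # teacher forward pass: run the EMA copy in eval mode on the unmasked
+            # input to produce the regression targets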
+ tm.eval()
+
+ if self.cfg.ema_encoder_only:
+ assert target is None
+ ema_input = extractor_out["local_features"]
+ ema_input = feature_extractor.contextualized_features(
+ ema_input.to(dtype=ema_dtype),
+ padding_mask,
+ mask=False,
+ remove_masked=False,
+ )
+ ema_blocks = tm
+ else:
+ ema_blocks = tm.blocks
+ if feature_extractor.modality_cfg.ema_local_encoder:
+ inp = (
+ target.to(dtype=ema_dtype)
+ if target is not None
+ else source.to(dtype=ema_dtype)
+ )
+ ema_input = tm.modality_encoders[mode](
+ inp,
+ padding_mask,
+ mask=False,
+ remove_masked=False,
+ )
+ else:
+ assert target is None
+ ema_input = extractor_out["local_features"]
+ ema_feature_enc = tm.modality_encoders[mode]
+ ema_input = ema_feature_enc.contextualized_features(
+ ema_input.to(dtype=ema_dtype),
+ padding_mask,
+ mask=False,
+ remove_masked=False,
+ )
+
+ ema_padding_mask = ema_input["padding_mask"]
+ ema_alibi_bias = ema_input.get("alibi_bias", None)
+ ema_alibi_scale = ema_input.get("alibi_scale", None)
+ ema_input = ema_input["x"]
+
+ y = []
+ ema_x = []
+ extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
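+        # drop extra prefix tokens (e.g. CLS) so teacher targets align with the
+        # masked input positions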
+ for i, blk in enumerate(ema_blocks):
+ ab = ema_alibi_bias
+ if ab is not None and alibi_scale is not None:
+ scale = (
+ ema_alibi_scale[i]
+ if ema_alibi_scale.size(0) > 1
+ else ema_alibi_scale.squeeze(0)
+ )
+ ab = ab * scale.type_as(ab)
+
+ ema_input, lr = blk(
+ ema_input,
+ padding_mask=ema_padding_mask,
+ alibi_bias=ab,
+ )
+ y.append(lr[:, extra_tokens:])
+ ema_x.append(ema_input[:, extra_tokens:])
+
+ y = self.make_targets(y, self.average_top_k_layers)
+ orig_targets = y
+
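+        # when clone_batch > 1 each example is masked multiple ways in the student,
+        # so teacher targets are repeated to match; targets and predictions are then
+        # restricted to the masked positions only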
+ if self.cfg.clone_batch > 1:
+ y = y.repeat_interleave(self.cfg.clone_batch, 0)
+
+ masked = encoder_mask.mask.unsqueeze(-1)
+ masked_b = encoder_mask.mask.bool()
+ y = y[masked_b]
+
+ if xs[0].size(1) == masked_b.size(1):
+ xs = [x[masked_b] for x in xs]
+ else:
+ xs = [x.reshape(-1, x.size(-1)) for x in xs]
+
+ sample_size = masked.sum().long()
+
+ result = {
+ "losses": {},
+ "sample_size": sample_size,
+ }
+
+ sample_size = result["sample_size"]
+
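+        # up to three loss terms are combined: an optional CLS-token loss against the
+        # mean teacher target, an optional MAE-style reconstruction loss on normalized
+        # patches, and the main data2vec regression loss on masked positions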
+ if self.cfg.cls_loss > 0:
+ assert extra_tokens > 0
+ cls_target = orig_targets.mean(dim=1)
+ if self.cfg.clone_batch > 1:
+ cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
+ cls_pred = x[:, extra_tokens - 1]
+ result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
+ self.cfg.cls_loss * sample_size
+ )
+
+ if self.cfg.recon_loss > 0:
+
+ with torch.no_grad():
+ target = feature_extractor.patchify(source)
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+ if self.cfg.clone_batch > 1:
+ target = target.repeat_interleave(self.cfg.clone_batch, 0)
+
+ if masked_b is not None:
+ target = target[masked_b]
+
+ recon = xs[0]
+ if self.recon_proj is not None:
+ recon = self.recon_proj(recon)
+
+ result["losses"]["recon"] = (
+ self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
+ )
+
+ if self.cfg.d2v_loss > 0:
+ for i, x in enumerate(xs):
+ reg_loss = self.d2v_loss(x, y)
+ n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
+ result["losses"][n] = reg_loss * self.cfg.d2v_loss
+
+ suffix = "" if len(self.modalities) == 1 else f"_{mode}"
+ with torch.no_grad():
+ if encoder_mask is not None:
+ result["masked_pct"] = 1 - (
+ encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
+ )
+ for i, x in enumerate(xs):
+ n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
+ result[n] = self.compute_var(x.float())
+ if self.ema is not None:
+ for k, v in self.ema.logs.items():
+ result[k] = v
+
+ y = y.float()
+ result[f"target_var{suffix}"] = self.compute_var(y)
+
+ if self.num_updates > 5000:
+ if result[f"target_var{suffix}"] < self.cfg.min_target_var:
+ logger.error(
+ f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
+ )
+ raise Exception(
+ f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
+ )
+
+ for k in result.keys():
+ if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
+ logger.error(
+ f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
+ )
+ raise Exception(
+ f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
+ )
+
+ result["ema_decay"] = self.ema.get_decay() * 1000
+
+ return result
+
+ def forward_decoder(
+ self,
+ x,
+ feature_extractor,
+ decoder,
+ mask_info,
+ ):
+ x = feature_extractor.decoder_input(x, mask_info)
+ x = decoder(*x)
+
+ return x
+
+ def d2v_loss(self, x, y):
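+        # regression loss between student predictions x and teacher targets y:
+        # plain MSE when loss_beta == 0, otherwise smooth L1 with beta, scaled by
+        # loss_scale or 1/sqrt(dim) by default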
+ x = x.view(-1, x.size(-1)).float()
+ y = y.view(-1, x.size(-1))
+
+ if self.loss_beta == 0:
+ loss = F.mse_loss(x, y, reduction="none")
+ else:
+ loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
+
+ if self.loss_scale is not None:
+ scale = self.loss_scale
+ else:
+ scale = 1 / math.sqrt(x.size(-1))
+
+ reg_loss = loss * scale
+
+ return reg_loss
+
+ def make_targets(self, y, num_layers):
+
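+        # build teacher targets by averaging the top `num_layers` block outputs,
+        # optionally normalizing each layer (batch/instance/layer norm) before the
+        # average and the final target afterwards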
+ with torch.no_grad():
+ target_layer_results = y[-num_layers:]
+
+ permuted = False
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+ target_layer_results = [
+ tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
+ ]
+ permuted = True
+ if self.cfg.batch_norm_target_layer:
+ target_layer_results = [
+ F.batch_norm(
+ tl.float(), running_mean=None, running_var=None, training=True
+ )
+ for tl in target_layer_results
+ ]
+ if self.cfg.instance_norm_target_layer:
+ target_layer_results = [
+ F.instance_norm(tl.float()) for tl in target_layer_results
+ ]
+ if permuted:
+ target_layer_results = [
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
+ ]
+ if self.cfg.layer_norm_target_layer:
+ target_layer_results = [
+ F.layer_norm(tl.float(), tl.shape[-1:])
+ for tl in target_layer_results
+ ]
+
+ y = target_layer_results[0].float()
+ for tl in target_layer_results[1:]:
+ y.add_(tl.float())
+ y = y.div_(len(target_layer_results))
+
+ if self.cfg.layer_norm_targets:
+ y = F.layer_norm(y, y.shape[-1:])
+
+ if self.cfg.instance_norm_targets:
+ y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
+
+ return y
+
+ @staticmethod
+ def compute_var(y):
+ y = y.view(-1, y.size(-1))
+ if dist.is_initialized():
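+            # aggregate count, sum and sum of squares across workers, then use the
+            # unbiased variance formula var = (sum(y^2) - sum(y)^2 / n) / (n - 1)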
+ zc = torch.tensor(y.size(0)).cuda()
+ zs = y.sum(dim=0)
+ zss = (y**2).sum(dim=0)
+
+ dist.all_reduce(zc)
+ dist.all_reduce(zs)
+ dist.all_reduce(zss)
+
+ var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
+ return torch.sqrt(var + 1e-6).mean()
+ else:
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
+
+ def extract_features(
+ self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
+ ):
+ res = self.forward(
+ source,
+ mode=mode,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ remove_extra_tokens=remove_extra_tokens,
+ )
+ return res
+
+ def remove_pretraining_modules(self, modality=None, keep_decoder=False):
+ self.ema = None
+ self.cfg.clone_batch = 1
+ self.recon_proj = None
+
+ if not keep_decoder:
+ self.shared_decoder = None
+
+ modality = modality.lower() if modality is not None else None
+ for k in list(self.modality_encoders.keys()):
+ if modality is not None and k.lower() != modality:
+ del self.modality_encoders[k]
+ else:
+ self.modality_encoders[k].remove_pretraining_modules(
+ keep_decoder=keep_decoder
+ )
+ if not keep_decoder:
+ self.modality_encoders[k].decoder = None
diff --git a/examples/data2vec/models/data2vec_audio.py b/examples/data2vec/models/data2vec_audio.py
new file mode 100644
index 0000000000..261c2f104c
--- /dev/null
+++ b/examples/data2vec/models/data2vec_audio.py
@@ -0,0 +1,537 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+from dataclasses import dataclass, field
+from typing import Optional
+
+from omegaconf import II
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from fairseq.modules import EMAModule, EMAModuleConfig
+from fairseq.data.data_utils import compute_mask_indices
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.wav2vec import (
+ ConvFeatureExtractionModel,
+ Wav2Vec2Config,
+ TransformerEncoder,
+)
+from fairseq.modules import (
+ GradMultiply,
+ LayerNorm,
+)
+from fairseq.utils import index_put
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecAudioConfig(Wav2Vec2Config):
+
+ loss_beta: float = field(
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+ )
+ loss_scale: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+ },
+ )
+ average_top_k_layers: int = field(
+ default=8, metadata={"help": "how many layers to average"}
+ )
+
+ layer_norm_target_layer: bool = False
+ instance_norm_target_layer: bool = False
+ instance_norm_targets: bool = False
+ layer_norm_targets: bool = False
+ batch_norm_target_layer: bool = False
+ group_norm_target_layer: bool = False
+
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+ ema_end_decay: float = field(
+ default=0.9999, metadata={"help": "final ema decay rate"}
+ )
+
+ # when to finish annealing ema decay rate
+ ema_anneal_end_step: int = II("optimization.max_update")
+
+ ema_transformer_only: bool = field(
+ default=True,
+ metadata={"help": "whether to momentum update only the transformer"},
+ )
+ ema_layers_only: bool = field(
+ default=True,
+ metadata={"help": "whether to momentum update only the transformer layers"},
+ )
+
+ max_update: int = II("optimization.max_update")
+
+ min_target_var: float = field(
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
+ )
+ min_pred_var: float = field(
+ default=0.01,
+ metadata={"help": "stop training if prediction var falls below this"},
+ )
+
+
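+# Linear annealing from `start` to `end` over `total_steps` updates, e.g. with
+# start=0.999, end=0.9999 and total_steps=100000, step 50000 gives 0.99945.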
+def get_annealed_rate(start, end, curr_step, total_steps):
+ r = end - start
+ pct_remaining = 1 - curr_step / total_steps
+ return end - r * pct_remaining
+
+
+@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
+class Data2VecAudioModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecAudioConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ feature_enc_layers = eval(cfg.conv_feature_layers)
+ self.extractor_embed = feature_enc_layers[-1][0]
+
+ self.ema = None
+ self.embed = cfg.encoder_embed_dim
+
+ self.average_top_k_layers = cfg.average_top_k_layers
+ self.loss_beta = cfg.loss_beta
+ self.loss_scale = cfg.loss_scale
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_before = cfg.mask_channel_before
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+
+ self.mask_emb = nn.Parameter(
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+ )
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.extractor_embed)
+
+ self.final_proj = nn.Linear(self.embed, self.embed)
+
+ self.num_updates = 0
+
+ def make_ema_teacher(self):
+ ema_config = EMAModuleConfig(
+ ema_decay=self.cfg.ema_decay,
+ ema_fp32=True,
+ )
+ skip_keys = set()
+ if self.cfg.ema_layers_only:
+ self.cfg.ema_transformer_only = True
+ for k, _ in self.encoder.pos_conv.named_parameters():
+ skip_keys.add(f"pos_conv.{k}")
+
+ self.ema = EMAModule(
+ self.encoder if self.cfg.ema_transformer_only else self,
+ ema_config,
+ skip_keys=skip_keys,
+ )
+
+ def set_num_updates(self, num_updates):
+ super().set_num_updates(num_updates)
+
+ if self.ema is None and self.final_proj is not None:
+ logger.info(f"making ema teacher")
+ self.make_ema_teacher()
+ elif self.training and self.ema is not None:
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
+ if num_updates >= self.cfg.ema_anneal_end_step:
+ decay = self.cfg.ema_end_decay
+ else:
+ decay = get_annealed_rate(
+ self.cfg.ema_decay,
+ self.cfg.ema_end_decay,
+ num_updates,
+ self.cfg.ema_anneal_end_step,
+ )
+ self.ema.set_decay(decay)
+ if self.ema.get_decay() < 1:
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
+
+ self.num_updates = num_updates
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ state = super().state_dict(destination, prefix, keep_vars)
+
+ if self.ema is not None:
+ state[prefix + "_ema"] = self.ema.fp32_params
+
+ return state
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ if self.ema is not None:
+ k = prefix + "_ema"
+ assert k in state_dict
+ self.ema.restore(state_dict[k], True)
+ del state_dict[k]
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def apply_mask(
+ self,
+ x,
+ padding_mask,
+ mask_indices=None,
+ mask_channel_indices=None,
+ ):
+ B, T, C = x.shape
+
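+        # time masking replaces sampled spans with the learned mask embedding;
+        # channel masking zeroes out sampled feature dimensions, either before or
+        # after time masking depending on mask_channel_before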
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ if self.mask_prob > 0:
+ if mask_indices is None:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=1,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ require_same_masks=self.cfg.require_same_masks,
+ mask_dropout=self.cfg.mask_dropout,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x = index_put(x, mask_indices, self.mask_emb)
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
+ if mask_channel_indices is None:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x = index_put(x, mask_channel_indices, 0)
+
+ return x, mask_indices
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ return torch.floor((input_length - kernel_size) / stride + 1)
+
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
+
+ for i in range(len(conv_cfg_list)):
+ input_lengths = _conv_out_length(
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
+ )
+
+ return input_lengths.to(torch.long)
+
+ def forward(
+ self,
+ source,
+ padding_mask=None,
+ mask=True,
+ features_only=False,
+ layer=None,
+ mask_indices=None,
+ mask_channel_indices=None,
+ padding_count=None,
+ ):
+ features = source
+
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(features)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(features)
+
+ features = features.transpose(1, 2)
+
+ features = self.layer_norm(features)
+
+ orig_padding_mask = padding_mask
+
+ if padding_mask is not None and padding_mask.any():
+ input_lengths = (1 - padding_mask.long()).sum(-1)
+ # apply conv formula to get real output_lengths
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
+
+ padding_mask = torch.zeros(
+ features.shape[:2], dtype=features.dtype, device=features.device
+ )
+
+            # these two operations make sure that all values
+ # before the output lengths indices are attended to
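+            # (mark each sequence's last valid index with 1, then a reversed cumulative
+            # sum flags all earlier positions; the complement is the padding mask)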
+ padding_mask[
+ (
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
+ output_lengths - 1,
+ )
+ ] = 1
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
+ else:
+ padding_mask = None
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ pre_encoder_features = None
+ if self.cfg.ema_transformer_only:
+ pre_encoder_features = features.clone()
+
+ features = self.dropout_input(features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(
+ features,
+ padding_mask,
+ mask_indices=mask_indices,
+ mask_channel_indices=mask_channel_indices,
+ )
+ else:
+ x = features
+ mask_indices = None
+
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=layer,
+ )
+
+ if features_only:
+ return {
+ "x": x,
+ "padding_mask": padding_mask,
+ "layer_results": layer_results,
+ }
+
+ result = {
+ "losses": {},
+ }
+
+ with torch.no_grad():
+ self.ema.model.eval()
+
+ if self.cfg.ema_transformer_only:
+ y, layer_results = self.ema.model.extract_features(
+ pre_encoder_features,
+ padding_mask=padding_mask,
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
+ )
+ y = {
+ "x": y,
+ "padding_mask": padding_mask,
+ "layer_results": layer_results,
+ }
+ else:
+ y = self.ema.model.extract_features(
+ source=source,
+ padding_mask=orig_padding_mask,
+ mask=False,
+ )
+
+ target_layer_results = [l[2] for l in y["layer_results"]]
+
+ permuted = False
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+ target_layer_results = [
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
+ ]
+ permuted = True
+
+ if self.cfg.batch_norm_target_layer:
+ target_layer_results = [
+ F.batch_norm(
+ tl.float(), running_mean=None, running_var=None, training=True
+ )
+ for tl in target_layer_results
+ ]
+
+ if self.cfg.instance_norm_target_layer:
+ target_layer_results = [
+ F.instance_norm(tl.float()) for tl in target_layer_results
+ ]
+
+ if permuted:
+ target_layer_results = [
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
+ ]
+
+ if self.cfg.group_norm_target_layer:
+ target_layer_results = [
+ F.layer_norm(tl.float(), tl.shape[-2:])
+ for tl in target_layer_results
+ ]
+
+ if self.cfg.layer_norm_target_layer:
+ target_layer_results = [
+ F.layer_norm(tl.float(), tl.shape[-1:])
+ for tl in target_layer_results
+ ]
+
+ y = sum(target_layer_results) / len(target_layer_results)
+
+ if self.cfg.layer_norm_targets:
+ y = F.layer_norm(y.float(), y.shape[-1:])
+
+ if self.cfg.instance_norm_targets:
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
+
+ if not permuted:
+ y = y.transpose(0, 1)
+
+ y = y[mask_indices]
+
+ x = x[mask_indices]
+ x = self.final_proj(x)
+
+ sz = x.size(-1)
+
+ if self.loss_beta == 0:
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
+ else:
+ loss = F.smooth_l1_loss(
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
+ ).sum(dim=-1)
+
+ if self.loss_scale is not None:
+ scale = self.loss_scale
+ else:
+ scale = 1 / math.sqrt(sz)
+
+ result["losses"]["regression"] = loss.sum() * scale
+
+ if "sample_size" not in result:
+ result["sample_size"] = loss.numel()
+
+ with torch.no_grad():
+ result["target_var"] = self.compute_var(y)
+ result["pred_var"] = self.compute_var(x.float())
+
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
+ logger.error(
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
+ )
+ raise Exception(
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
+ )
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
+ logger.error(
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
+ )
+ raise Exception(
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
+ )
+
+ if self.ema is not None:
+ result["ema_decay"] = self.ema.get_decay() * 1000
+
+ return result
+
+ @staticmethod
+ def compute_var(y):
+ y = y.view(-1, y.size(-1))
+ if dist.is_initialized():
+ zc = torch.tensor(y.size(0)).cuda()
+ zs = y.sum(dim=0)
+ zss = (y ** 2).sum(dim=0)
+
+ dist.all_reduce(zc)
+ dist.all_reduce(zs)
+ dist.all_reduce(zss)
+
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
+ return torch.sqrt(var + 1e-6).mean()
+ else:
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
+
+ def extract_features(
+ self, source, padding_mask, mask=False, layer=None
+ ):
+ res = self.forward(
+ source,
+ padding_mask,
+ mask=mask,
+ features_only=True,
+ layer=layer,
+ )
+ return res
+
+ def remove_pretraining_modules(self, last_layer=None):
+ self.final_proj = None
+ self.ema = None
+ if last_layer is not None:
+ self.encoder.layers = nn.ModuleList(
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
+ )
diff --git a/examples/data2vec/models/data2vec_image_classification.py b/examples/data2vec/models/data2vec_image_classification.py
new file mode 100644
index 0000000000..851c9ce455
--- /dev/null
+++ b/examples/data2vec/models/data2vec_image_classification.py
@@ -0,0 +1,143 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+
+from dataclasses import dataclass
+from typing import Any
+
+from omegaconf import II, MISSING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import checkpoint_utils, tasks
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecImageClassificationConfig(FairseqDataclass):
+ model_path: str = MISSING
+ no_pretrained_weights: bool = False
+ num_classes: int = 1000
+ mixup: float = 0.8
+ cutmix: float = 1.0
+ label_smoothing: float = 0.1
+
+ pretrained_model_args: Any = None
+ data: str = II("task.data")
+
+
+@register_model(
+ "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
+)
+class Data2VecImageClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecImageClassificationConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pretrained_model_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
+ pretrained_args = state.get("cfg", None)
+ pretrained_args.criterion = None
+ pretrained_args.lr_scheduler = None
+ cfg.pretrained_model_args = pretrained_args
+
+ logger.info(pretrained_args)
+ else:
+ state = None
+ pretrained_args = cfg.pretrained_model_args
+
+ pretrained_args.task.data = cfg.data
+ task = tasks.setup_task(pretrained_args.task)
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+
+ self.model = model
+
+ if state is not None and not cfg.no_pretrained_weights:
+ self.load_model_weights(state, model, cfg)
+
+ self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
+ self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
+
+ self.head.weight.data.mul_(1e-3)
+ self.head.bias.data.mul_(1e-3)
+
+ self.mixup_fn = None
+
+ if cfg.mixup > 0 or cfg.cutmix > 0:
+ from timm.data import Mixup
+
+ self.mixup_fn = Mixup(
+ mixup_alpha=cfg.mixup,
+ cutmix_alpha=cfg.cutmix,
+ cutmix_minmax=None,
+ prob=1.0,
+ switch_prob=0.5,
+ mode="batch",
+ label_smoothing=cfg.label_smoothing,
+ num_classes=cfg.num_classes,
+ )
+
+ def load_model_weights(self, state, model, cfg):
+ if "_ema" in state["model"]:
+ del state["model"]["_ema"]
+ model.load_state_dict(state["model"], strict=True)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def forward(
+ self,
+ img,
+ label=None,
+ ):
+ if self.training and self.mixup_fn is not None and label is not None:
+ img, label = self.mixup_fn(img, label)
+
+ x = self.model(img, mask=False)
+ x = x[:, 1:]
+ x = self.fc_norm(x.mean(1))
+ x = self.head(x)
+
+ if label is None:
+ return x
+
+ if self.training and self.mixup_fn is not None:
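+            # mixup/cutmix produces soft label distributions, so use the soft
+            # cross-entropy (expected negative log-likelihood) instead of the
+            # hard-label cross_entropy below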
+ loss = -label * F.log_softmax(x.float(), dim=-1)
+ else:
+ loss = F.cross_entropy(
+ x.float(),
+ label,
+ label_smoothing=self.cfg.label_smoothing if self.training else 0,
+ reduction="none",
+ )
+
+ result = {
+ "losses": {"regression": loss},
+ "sample_size": img.size(0),
+ }
+
+ if not self.training:
+ with torch.no_grad():
+ pred = x.argmax(-1)
+ correct = (pred == label).sum()
+ result["correct"] = correct
+
+ return result
diff --git a/examples/data2vec/models/data2vec_text.py b/examples/data2vec/models/data2vec_text.py
new file mode 100644
index 0000000000..cb3c8b383a
--- /dev/null
+++ b/examples/data2vec/models/data2vec_text.py
@@ -0,0 +1,517 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+from typing import Optional
+import logging
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from omegaconf import II
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.modules import EMAModule, EMAModuleConfig
+from fairseq.models import (
+ FairseqEncoder,
+ FairseqEncoderModel,
+ register_model,
+)
+from fairseq.models.roberta.model import RobertaLMHead, RobertaClassificationHead
+from fairseq.models.transformer import TransformerEncoder, TransformerConfig
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecTextConfig(FairseqDataclass):
+ max_positions: int = II("task.tokens_per_sample")
+
+ head_layers: int = 1
+
+ transformer: TransformerConfig = TransformerConfig()
+
+ load_checkpoint_heads: bool = field(
+ default=False,
+ metadata={"help": "(re-)register and load heads when loading checkpoints"},
+ )
+
+ loss_beta: float = field(
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+ )
+ loss_scale: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+ },
+ )
+ average_top_k_layers: int = field(
+ default=8, metadata={"help": "how many layers to average"}
+ )
+
+ layer_norm_target_layer: bool = False
+ instance_norm_target_layer: bool = False
+ batch_norm_target_layer: bool = False
+ instance_norm_targets: bool = False
+ layer_norm_targets: bool = False
+
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+ ema_end_decay: float = field(
+ default=0.9999, metadata={"help": "final ema decay rate"}
+ )
+
+ # when to finish annealing ema decay rate
+ ema_anneal_end_step: int = II("optimization.max_update")
+
+ ema_transformer_layers_only: bool = field(
+ default=True,
+ metadata={"help": "whether to momentum update only the transformer layers"},
+ )
+
+
+def get_annealed_rate(start, end, curr_step, total_steps):
+ r = end - start
+ pct_remaining = 1 - curr_step / total_steps
+ return end - r * pct_remaining
+
+
+@register_model("data2vec_text", dataclass=Data2VecTextConfig)
+class Data2VecTextModel(FairseqEncoderModel):
+ def __init__(self, cfg: Data2VecTextConfig, encoder):
+ super().__init__(encoder)
+ self.cfg = cfg
+
+ # We follow BERT's random weight initialization
+ self.apply(init_bert_params)
+
+ self.classification_heads = nn.ModuleDict()
+
+ @classmethod
+ def build_model(cls, cfg, task):
+ """Build a new model instance."""
+
+ encoder = Data2VecTextEncoder(cfg, task.source_dictionary, task.cfg.data)
+
+ return cls(cfg, encoder)
+
+ def forward(
+ self,
+ src_tokens,
+ target_tokens=None,
+ features_only=False,
+ return_all_hiddens=False,
+ classification_head_name=None,
+ **kwargs,
+ ):
+ if classification_head_name is not None:
+ features_only = True
+
+ res = self.encoder(
+ src_tokens, target_tokens, features_only, return_all_hiddens, **kwargs
+ )
+
+ if isinstance(res, tuple):
+ x, extra = res
+ else:
+ return res
+
+ if classification_head_name is not None:
+ x = self.classification_heads[classification_head_name](x)
+ return x, extra
+
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
+ """Get normalized probabilities (or log probs) from a net's output."""
+ logits = net_output[0].float()
+ if log_probs:
+ return F.log_softmax(logits, dim=-1)
+ else:
+ return F.softmax(logits, dim=-1)
+
+ def register_classification_head(
+ self, name, num_classes=None, inner_dim=None, **kwargs
+ ):
+ """Register a classification head."""
+ if name in self.classification_heads:
+ prev_num_classes = self.classification_heads[name].out_proj.out_features
+ prev_inner_dim = self.classification_heads[name].dense.out_features
+ if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
+ logger.warning(
+ 're-registering head "{}" with num_classes {} (prev: {}) '
+ "and inner_dim {} (prev: {})".format(
+ name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
+ )
+ )
+ self.classification_heads[name] = RobertaClassificationHead(
+ input_dim=self.cfg.transformer.encoder.embed_dim,
+ inner_dim=inner_dim or self.cfg.transformer.encoder.embed_dim,
+ num_classes=num_classes,
+ activation_fn="tanh",
+ pooler_dropout=0,
+ )
+
+ @property
+ def supported_targets(self):
+ return {"self"}
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ prefix = name + "." if name != "" else ""
+
+ # rename decoder -> encoder before upgrading children modules
+ for k in list(state_dict.keys()):
+ if k.startswith(prefix + "decoder"):
+ new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
+ state_dict[new_k] = state_dict[k]
+ del state_dict[k]
+
+ # rename emb_layer_norm -> layernorm_embedding
+ for k in list(state_dict.keys()):
+ if ".emb_layer_norm." in k:
+ new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
+ state_dict[new_k] = state_dict[k]
+ del state_dict[k]
+
+            if self.encoder.regression_head is not None:
+                if ".lm_head." in k:
+                    new_k = k.replace(".lm_head.", ".regression_head.")
+                    state_dict[new_k] = state_dict[k]
+                    del state_dict[k]
+            else:
+                if ".regression_head." in k:
+                    del state_dict[k]
+
+ # upgrade children modules
+ super().upgrade_state_dict_named(state_dict, name)
+
+ # Handle new classification heads present in the state dict.
+ current_head_names = (
+ []
+ if not hasattr(self, "classification_heads")
+ or self.classification_heads is None
+ else self.classification_heads.keys()
+ )
+ keys_to_delete = []
+ for k in state_dict.keys():
+ if not k.startswith(prefix + "classification_heads."):
+ continue
+
+ head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
+ num_classes = state_dict[
+ prefix + "classification_heads." + head_name + ".out_proj.weight"
+ ].size(0)
+ inner_dim = state_dict[
+ prefix + "classification_heads." + head_name + ".dense.weight"
+ ].size(0)
+
+ if self.cfg.load_checkpoint_heads:
+ if head_name not in current_head_names:
+ self.register_classification_head(head_name, num_classes, inner_dim)
+ else:
+ if head_name not in current_head_names:
+ logger.warning(
+ "deleting classification head ({}) from checkpoint "
+ "not present in current model: {}".format(head_name, k)
+ )
+ keys_to_delete.append(k)
+ elif (
+ num_classes
+ != self.classification_heads[head_name].out_proj.out_features
+ or inner_dim
+ != self.classification_heads[head_name].dense.out_features
+ ):
+ logger.warning(
+ "deleting classification head ({}) from checkpoint "
+ "with different dimensions than current model: {}".format(
+ head_name, k
+ )
+ )
+ keys_to_delete.append(k)
+ for k in keys_to_delete:
+ del state_dict[k]
+
+ # Copy any newly-added classification heads into the state dict
+ # with their current weights.
+ if (
+ hasattr(self, "classification_heads")
+ and self.classification_heads is not None
+ and len(self.classification_heads) > 0
+ ):
+ cur_state = self.classification_heads.state_dict()
+ for k, v in cur_state.items():
+ if prefix + "classification_heads." + k not in state_dict:
+ logger.info("Overwriting " + prefix + "classification_heads." + k)
+ state_dict[prefix + "classification_heads." + k] = v
+
+ for k in list(state_dict.keys()):
+ if k.startswith(prefix + "encoder.lm_head.") or k.startswith(
+ prefix + "encoder.emb_head."
+ ):
+ del state_dict[k]
+
+ self.encoder.lm_head = None
+
+ if self.encoder.target_model is None:
+ for k in list(state_dict.keys()):
+ if k.startswith(prefix + "encoder.target_model."):
+ del state_dict[k]
+
+ if (self.encoder.ema is None) and (prefix + "encoder._ema" in state_dict):
+ del state_dict[prefix + "encoder._ema"]
+
+ def remove_pretraining_modules(self, last_layer=None):
+ self.encoder.lm_head = None
+ self.encoder.regression_head = None
+ self.encoder.ema = None
+ self.classification_heads = None
+
+ if last_layer is not None:
+ self.encoder.sentence_encoder.layers = nn.ModuleList(
+ l
+ for i, l in enumerate(self.encoder.sentence_encoder.layers)
+ if i <= last_layer
+ )
+ self.encoder.sentence_encoder.layer_norm = None
+
+
+class Data2VecTextEncoder(FairseqEncoder):
+ def __init__(self, cfg: Data2VecTextConfig, dictionary, task_data):
+ super().__init__(dictionary)
+
+ self.cfg = cfg
+
+ embed_tokens = self.build_embedding(
+ len(dictionary), cfg.transformer.encoder.embed_dim, dictionary.pad()
+ )
+
+ self.sentence_encoder = self.build_encoder(cfg, dictionary, embed_tokens)
+        self.mask_idx = dictionary.index("<mask>")
+ assert self.mask_idx != dictionary.unk(), dictionary.symbols
+
+ self.ema = None
+ self.average_top_k_layers = cfg.average_top_k_layers
+ self.loss_scale = cfg.loss_scale
+
+ assert self.cfg.head_layers >= 1
+
+ embed_dim = cfg.transformer.encoder.embed_dim
+ curr_dim = embed_dim
+ projs = []
+ for i in range(self.cfg.head_layers - 1):
+ next_dim = embed_dim * 2 if i == 0 else curr_dim
+ projs.append(nn.Linear(curr_dim, next_dim))
+ projs.append(nn.GELU())
+ curr_dim = next_dim
+
+ projs.append(nn.Linear(curr_dim, embed_dim))
+ self.regression_head = nn.Sequential(*projs)
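+        # Layout sketch (illustrative): with head_layers=1 the head is a single
+        # Linear(embed_dim, embed_dim); with head_layers=2 it becomes
+        # Linear(embed_dim, 2 * embed_dim) -> GELU -> Linear(2 * embed_dim, embed_dim).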
+
+ self.num_updates = 0
+
+ def build_embedding(self, vocab_size, embedding_dim, padding_idx):
+ return nn.Embedding(vocab_size, embedding_dim, padding_idx)
+
+ def build_encoder(self, cfg, dictionary, embed_tokens):
+ encoder = TransformerEncoder(cfg.transformer, dictionary, embed_tokens, return_fc=True)
+ encoder.apply(init_bert_params)
+ return encoder
+
+ def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
+ return RobertaLMHead(embed_dim, output_dim, activation_fn, weight)
+
+ def make_ema_teacher(self):
+ ema_config = EMAModuleConfig(
+ ema_decay=self.cfg.ema_decay,
+ ema_fp32=True,
+ )
+ skip_keys = set()
+ if self.cfg.ema_transformer_layers_only:
+            for k, _ in self.sentence_encoder.embed_tokens.named_parameters():
+                skip_keys.add(f"embed_tokens.{k}")
+ for k, _ in self.sentence_encoder.embed_positions.named_parameters():
+ skip_keys.add(f"embed_positions.{k}")
+ if self.sentence_encoder.layernorm_embedding is not None:
+ for (
+ k,
+ _,
+ ) in self.sentence_encoder.layernorm_embedding.named_parameters():
+ skip_keys.add(f"layernorm_embedding.{k}")
+ if self.sentence_encoder.layer_norm is not None:
+ for k, _ in self.sentence_encoder.layer_norm.named_parameters():
+                    skip_keys.add(f"layer_norm.{k}")
+
+ self.ema = EMAModule(
+ self.sentence_encoder,
+ ema_config,
+ skip_keys=skip_keys,
+ )
+
+ def set_num_updates(self, num_updates):
+ super().set_num_updates(num_updates)
+
+ if self.ema is None and self.regression_head is not None:
+ logger.info(f"making ema teacher")
+ self.make_ema_teacher()
+ elif self.training and self.ema is not None:
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
+ if num_updates >= self.cfg.ema_anneal_end_step:
+ decay = self.cfg.ema_end_decay
+ else:
+ decay = get_annealed_rate(
+ self.cfg.ema_decay,
+ self.cfg.ema_end_decay,
+ num_updates,
+ self.cfg.ema_anneal_end_step,
+ )
+ self.ema.set_decay(decay)
+ if self.ema.get_decay() < 1:
+ self.ema.step(self.sentence_encoder)
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ state = super().state_dict(destination, prefix, keep_vars)
+ if self.ema is not None:
+ state[prefix + "_ema"] = self.ema.fp32_params
+ return state
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ if self.ema is not None:
+ k = prefix + "_ema"
+ assert k in state_dict
+ self.ema.restore(state_dict[k], True)
+ del state_dict[k]
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ def forward(
+ self,
+ src_tokens,
+ target_tokens=None,
+ features_only=False,
+ return_all_hiddens=False,
+ masked_tokens=None,
+ **unused,
+ ):
+ """
+ Args:
+ src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
+ features_only (bool, optional): skip LM head and just return
+ features. If True, the output will be of shape
+ `(batch, src_len, embed_dim)`.
+ return_all_hiddens (bool, optional): also return all of the
+ intermediate hidden states (default: False).
+
+ Returns:
+ tuple:
+ - the LM output of shape `(batch, src_len, vocab)`
+ - a dictionary of additional data, where 'inner_states'
+ is a list of hidden states. Note that the hidden
+              states have shape `(src_len, batch, embed_dim)`.
+ """
+
+ x, extra = self.extract_features(
+ src_tokens, return_all_hiddens=return_all_hiddens
+ )
+
+ if features_only:
+ return x, extra
+
+ assert target_tokens is not None
+
+ with torch.no_grad():
+ # use EMA parameter as the teacher
+ self.ema.model.eval()
+
+ encoder_out = self.ema.model(
+ target_tokens,
+ return_all_hiddens=True,
+ )
+ y = encoder_out["fc_results"]
+
+ y = y[-self.average_top_k_layers :]
+
+ permuted = False
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+ y = [tl.permute(1, 2, 0) for tl in y] # TBC -> BCT
+ permuted = True
+
+ if self.cfg.batch_norm_target_layer:
+ y = [
+ F.batch_norm(
+ tl.float(), running_mean=None, running_var=None, training=True
+ )
+ for tl in y
+ ]
+
+ if self.cfg.instance_norm_target_layer:
+ y = [F.instance_norm(tl.float()) for tl in y]
+
+ if permuted:
+ y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
+
+ if self.cfg.layer_norm_target_layer:
+ y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
+
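+            # Illustrative note: the per-layer teacher outputs kept above are
+            # averaged into a single regression target; any normalization
+            # enabled in the config is applied before this average.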
+ y = sum(y) / len(y)
+
+ if not permuted:
+ y = y.transpose(0, 1)
+
+ if self.cfg.layer_norm_targets:
+ y = F.layer_norm(y.float(), y.shape[-1:])
+
+ if self.cfg.instance_norm_targets:
+ y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
+
+ masked_indices = src_tokens.eq(self.mask_idx)
+
+ x = x[masked_indices]
+ y = y[masked_indices]
+
+ x = self.regression_head(x)
+
+ sz = x.size(-1)
+ if self.cfg.loss_beta == 0:
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
+ else:
+ loss = F.smooth_l1_loss(
+ x.float(), y.float(), reduction="none", beta=self.cfg.loss_beta
+ ).sum(dim=-1)
+
+ result = {
+ "losses": {
+ "main": loss.sum() / math.sqrt(sz)
+                if self.loss_scale is None or self.loss_scale <= 0
+ else loss.sum() * self.loss_scale,
+ },
+ "sample_size": loss.numel(),
+ }
+
+ # logging other values
+ other_logs = {
+ "ema_decay": self.ema.get_decay() * 1000
+ }
+ result["logs"] = other_logs
+ return result
+
+ def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
+ encoder_out = self.sentence_encoder(
+ src_tokens,
+ return_all_hiddens=return_all_hiddens,
+ token_embeddings=kwargs.get("token_embeddings", None),
+ )
+ # T x B x C -> B x T x C
+ features = encoder_out["encoder_out"][0].transpose(0, 1)
+ inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
+ return features, {
+ "inner_states": inner_states,
+ "encoder_embedding": encoder_out["encoder_embedding"][0],
+ }
+
+ def output_layer(self, features, masked_tokens=None, **unused):
+ return self.lm_head(features, masked_tokens)
+
+ def max_positions(self):
+ """Maximum output length supported by the encoder."""
+ return self.cfg.max_positions
diff --git a/examples/data2vec/models/data2vec_text_classification.py b/examples/data2vec/models/data2vec_text_classification.py
new file mode 100644
index 0000000000..e787b916dc
--- /dev/null
+++ b/examples/data2vec/models/data2vec_text_classification.py
@@ -0,0 +1,141 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+
+from dataclasses import dataclass
+from typing import Any
+
+from omegaconf import II, MISSING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import checkpoint_utils, tasks
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.roberta.model import RobertaClassificationHead
+
+from examples.data2vec.data.modality import Modality
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecTextClassificationConfig(FairseqDataclass):
+ pooler_dropout: float = 0.0
+ pooler_activation_fn: str = "tanh"
+ quant_noise_pq: int = 0
+ quant_noise_pq_block_size: int = 8
+ spectral_norm_classification_head: bool = False
+
+ model_path: str = MISSING
+ no_pretrained_weights: bool = False
+
+ pretrained_model_args: Any = None
+
+
+@register_model(
+ "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
+)
+class Data2VecTextClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecTextClassificationConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pretrained_model_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
+ pretrained_args = state.get("cfg", None)
+ pretrained_args.criterion = None
+ pretrained_args.lr_scheduler = None
+ cfg.pretrained_model_args = pretrained_args
+
+ logger.info(pretrained_args)
+ else:
+ state = None
+ pretrained_args = cfg.pretrained_model_args
+
+ task = tasks.setup_task(pretrained_args.task)
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+
+ self.model = model
+
+ if state is not None and not cfg.no_pretrained_weights:
+ self.load_model_weights(state, model, cfg)
+
+ self.classification_heads = nn.ModuleDict()
+
+ def load_model_weights(self, state, model, cfg):
+ for k in list(state["model"].keys()):
+ if (
+ k.startswith("shared_decoder") or
+ k.startswith("_ema") or
+ "decoder" in k
+ ):
+ logger.info(f"Deleting {k} from checkpoint")
+ del state["model"][k]
+ model.load_state_dict(state["model"], strict=True)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def register_classification_head(
+ self, name, num_classes=None, inner_dim=None, **kwargs
+ ):
+ """Register a classification head."""
+ if name in self.classification_heads:
+ prev_num_classes = self.classification_heads[name].out_proj.out_features
+ prev_inner_dim = self.classification_heads[name].dense.out_features
+ if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
+ logger.warning(
+ 're-registering head "{}" with num_classes {} (prev: {}) '
+ "and inner_dim {} (prev: {})".format(
+ name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
+ )
+ )
+ embed_dim = self.cfg.pretrained_model_args.model.embed_dim
+ self.classification_heads[name] = RobertaClassificationHead(
+ input_dim=embed_dim,
+ inner_dim=inner_dim or embed_dim,
+ num_classes=num_classes,
+ activation_fn=self.cfg.pooler_activation_fn,
+ pooler_dropout=self.cfg.pooler_dropout,
+ q_noise=self.cfg.quant_noise_pq,
+ qn_block_size=self.cfg.quant_noise_pq_block_size,
+ do_spectral_norm=self.cfg.spectral_norm_classification_head,
+ )
+
+ def forward(
+ self,
+ source,
+ id,
+ padding_mask,
+ features_only=True,
+ remove_extra_tokens=True,
+ classification_head_name=None,
+ ):
+ encoder_out = self.model(
+ source,
+ id=id,
+ mode=Modality.TEXT,
+ padding_mask=padding_mask,
+ mask=False,
+ features_only=features_only,
+ remove_extra_tokens=remove_extra_tokens
+ )
+ logits = self.classification_heads[classification_head_name](encoder_out["x"])
+ return logits, encoder_out
diff --git a/examples/data2vec/models/data2vec_vision.py b/examples/data2vec/models/data2vec_vision.py
new file mode 100644
index 0000000000..2f89894429
--- /dev/null
+++ b/examples/data2vec/models/data2vec_vision.py
@@ -0,0 +1,727 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+import math
+import numpy as np
+import random
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from omegaconf import II
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from fairseq.modules import EMAModule, EMAModuleConfig
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecVisionConfig(FairseqDataclass):
+ layer_scale_init_value: float = field(
+ default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
+ )
+ num_mask_patches: int = field(
+ default=75,
+ metadata={"help": "number of the visual tokens/patches need be masked"},
+ )
+ min_mask_patches_per_block: int = 16
+ max_mask_patches_per_block: int = 196
+ image_size: int = 224
+ patch_size: int = 16
+ in_channels: int = 3
+
+ shared_rel_pos_bias: bool = True
+
+ drop_path: float = 0.1
+ attention_dropout: float = 0.0
+
+ depth: int = 12
+ embed_dim: int = 768
+ num_heads: int = 12
+ mlp_ratio: int = 4
+
+ loss_beta: float = field(
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+ )
+ loss_scale: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+ },
+ )
+ average_top_k_layers: int = field(
+ default=8, metadata={"help": "how many layers to average"}
+ )
+
+ end_of_block_targets: bool = True
+ layer_norm_target_layer: bool = False
+ instance_norm_target_layer: bool = False
+ batch_norm_target_layer: bool = False
+ instance_norm_targets: bool = False
+ layer_norm_targets: bool = False
+
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+ ema_end_decay: float = field(
+ default=0.9999, metadata={"help": "final ema decay rate"}
+ )
+
+ # when to finish annealing ema decay rate
+ ema_anneal_end_step: int = II("optimization.max_update")
+
+ ema_transformer_only: bool = field(
+ default=True,
+ metadata={"help": "whether to momentum update only the transformer layers"},
+ )
+
+
+def get_annealed_rate(start, end, curr_step, total_steps):
+ r = end - start
+ pct_remaining = 1 - curr_step / total_steps
+ return end - r * pct_remaining
+
+
+@register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
+class Data2VecVisionModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecVisionConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ self.ema = None
+
+ self.average_top_k_layers = cfg.average_top_k_layers
+ self.loss_beta = cfg.loss_beta
+ self.loss_scale = (
+ cfg.loss_scale
+ if cfg.loss_scale is not None
+ else 1 / math.sqrt(cfg.embed_dim)
+ )
+
+ self.patch_embed = PatchEmbed(
+ img_size=cfg.image_size,
+ patch_size=cfg.patch_size,
+ in_chans=cfg.in_channels,
+ embed_dim=cfg.embed_dim,
+ )
+
+ patch_size = self.patch_embed.patch_size
+ self.window_size = (
+ cfg.image_size // patch_size[0],
+ cfg.image_size // patch_size[1],
+ )
+
+ self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
+ self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
+
+ nn.init.trunc_normal_(self.cls_emb, 0.02)
+ nn.init.trunc_normal_(self.mask_emb, 0.02)
+
+ self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)
+
+ self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
+ self.num_updates = 0
+
+ def make_ema_teacher(self):
+ ema_config = EMAModuleConfig(
+ ema_decay=self.cfg.ema_decay,
+ ema_fp32=True,
+ )
+ self.ema = EMAModule(
+ self.encoder if self.cfg.ema_transformer_only else self,
+ ema_config,
+ )
+
+ def set_num_updates(self, num_updates):
+ super().set_num_updates(num_updates)
+
+ if self.ema is None and self.final_proj is not None:
+ logger.info(f"making ema teacher")
+ self.make_ema_teacher()
+ elif self.training and self.ema is not None:
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
+ if num_updates >= self.cfg.ema_anneal_end_step:
+ decay = self.cfg.ema_end_decay
+ else:
+ decay = get_annealed_rate(
+ self.cfg.ema_decay,
+ self.cfg.ema_end_decay,
+ num_updates,
+ self.cfg.ema_anneal_end_step,
+ )
+ self.ema.set_decay(decay)
+ if self.ema.get_decay() < 1:
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
+
+ self.num_updates = num_updates
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ state = super().state_dict(destination, prefix, keep_vars)
+
+ if self.ema is not None:
+ state[prefix + "_ema"] = self.ema.fp32_params
+
+ return state
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ if self.ema is not None:
+ k = prefix + "_ema"
+ assert k in state_dict
+ self.ema.restore(state_dict[k], True)
+ del state_dict[k]
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecVisionConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def make_mask(self, bsz, num_masks, min_masks, max_masks):
+ height, width = self.window_size
+
+        masks = np.zeros(shape=(bsz, height, width), dtype=int)
+
+ for i in range(bsz):
+ mask = masks[i]
+ mask_count = 0
+
+ min_aspect = 0.3
+ max_aspect = 1 / min_aspect
+ log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+ def _mask(mask, max_mask_patches):
+ delta = 0
+ for attempt in range(10):
+ target_area = random.uniform(min_masks, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < width and h < height:
+ top = random.randint(0, height - h)
+ left = random.randint(0, width - w)
+
+ num_masked = mask[top : top + h, left : left + w].sum()
+ # Overlap
+ if 0 < h * w - num_masked <= max_mask_patches:
+ for i in range(top, top + h):
+ for j in range(left, left + w):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ while mask_count < num_masks:
+ max_mask_patches = min(num_masks - mask_count, max_masks)
+
+ delta = _mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ return torch.from_numpy(masks)
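+        # Usage sketch (illustrative, using the config defaults above): for
+        # 224x224 inputs with 16x16 patches the window is 14x14, so
+        # make_mask(bsz, 75, 16, 196) returns a (bsz, 14, 14) integer tensor in
+        # which roughly 75 entries per sample are set to 1, grouped into random
+        # rectangular blocks (BEiT-style blockwise masking).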
+
+ def forward(
+ self,
+ img,
+ mask: bool = True,
+ layer_results: bool = False,
+ ):
+ x = self.patch_embed(img)
+ batch_size, seq_len, _ = x.size()
+
+ if mask:
+ mask_indices = self.make_mask(
+ img.size(0),
+ self.cfg.num_mask_patches,
+ self.cfg.min_mask_patches_per_block,
+ self.cfg.max_mask_patches_per_block,
+ )
+ bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
+ else:
+ mask_indices = bool_mask = None
+
+ cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.ema is not None:
+ with torch.no_grad():
+ self.ema.model.eval()
+
+ if self.cfg.ema_transformer_only:
+ y = self.ema.model(
+ x,
+ layer_results="end" if self.cfg.end_of_block_targets else "fc",
+ )
+ else:
+ y = self.ema.model(
+ img,
+ mask=False,
+ layer_results=True,
+ )
+
+ y = y[-self.cfg.average_top_k_layers :]
+
+ permuted = False
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+ y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT
+ permuted = True
+
+ if self.cfg.batch_norm_target_layer:
+ y = [
+ F.batch_norm(
+ tl.float(), running_mean=None, running_var=None, training=True
+ )
+ for tl in y
+ ]
+
+ if self.cfg.instance_norm_target_layer:
+ y = [F.instance_norm(tl.float()) for tl in y]
+
+ if permuted:
+ y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
+
+ if self.cfg.layer_norm_target_layer:
+ y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
+
+ y = sum(y) / len(y)
+
+ if self.cfg.layer_norm_targets:
+ y = F.layer_norm(y.float(), y.shape[-1:])
+
+ if self.cfg.instance_norm_targets:
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
+
+ y = y[bool_mask].float()
+
+ if mask_indices is not None:
+ mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
+ w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
+ x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w
+
+ if layer_results:
+ enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
+ else:
+ enc_layer_results = None
+
+ x = self.encoder(x, layer_results=enc_layer_results)
+ if layer_results or mask_indices is None:
+ return x
+
+ x = x[bool_mask].float()
+
+ if self.loss_beta == 0:
+ loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
+ else:
+ loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
+ dim=-1
+ )
+
+ if self.loss_scale > 0:
+ loss = loss * self.loss_scale
+
+ result = {
+ "losses": {"regression": loss.sum()},
+ "sample_size": loss.numel(),
+ "target_var": self.compute_var(y),
+ "pred_var": self.compute_var(x),
+ "ema_decay": self.ema.get_decay() * 1000,
+ }
+ return result
+
+ @staticmethod
+ def compute_var(y):
+ y = y.view(-1, y.size(-1))
+ if dist.is_initialized():
+ zc = torch.tensor(y.size(0)).cuda()
+ zs = y.sum(dim=0)
+ zss = (y ** 2).sum(dim=0)
+
+ dist.all_reduce(zc)
+ dist.all_reduce(zs)
+ dist.all_reduce(zss)
+
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
+ return torch.sqrt(var + 1e-6).mean()
+ else:
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
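+        # Illustrative note: under distributed training the count, sum and
+        # sum-of-squares are all-reduced so the reported standard deviation is
+        # computed over the global batch rather than each worker's shard.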
+
+ def remove_pretraining_modules(self, last_layer=None):
+ self.final_proj = None
+ self.ema = None
+ self.encoder.norm = nn.Identity()
+ self.mask_emb = None
+ if last_layer is not None:
+ self.encoder.layers = nn.ModuleList(
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
+ )
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ if isinstance(img_size, int):
+ img_size = img_size, img_size
+ if isinstance(patch_size, int):
+ patch_size = patch_size, patch_size
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.conv = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ def forward(self, x):
+ # BCHW -> BTC
+ x = self.conv(x).flatten(2).transpose(1, 2)
+ return x
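+        # Shape sketch (illustrative): a (B, 3, 224, 224) input with
+        # patch_size=16 becomes (B, 196, embed_dim) after the conv, flatten and
+        # transpose above.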
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2,
+ dtype=relative_coords.dtype,
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (
+ self.q_bias,
+ torch.zeros_like(self.v_bias, requires_grad=False),
+ self.v_bias,
+ )
+ )
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+ print("attn.size() :", attn.size())
+ print("rel_pos_bias.size() :", rel_pos_bias.size())
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ print("self.window_size :", self.window_size)
+ print("self.num_relative_distance :", self.num_relative_distance)
+ print("self.relative_position_index :", self.relative_position_index.size(), self.relative_position_index)
+ print("relative_position_bias.size(), relative_position_bias :",relative_position_bias.size(), relative_position_bias)
+ print("self.relative_position_bias_table.size(), self.relative_position_bias_table :",self.relative_position_bias_table.size(), self.relative_position_bias_table)
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_()
+ output = x.div(keep_prob) * random_tensor
+ return output
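+        # Behavior sketch (illustrative): DropPath(0.1) keeps each sample's
+        # residual branch with probability 0.9 during training, rescaling the
+        # survivors by 1 / 0.9, and acts as the identity in eval mode.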
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ window_size=None,
+ ):
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = nn.LayerNorm(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, mlp_hidden_dim),
+ nn.GELU(),
+ nn.Linear(mlp_hidden_dim, dim),
+ nn.Dropout(drop),
+ )
+
+ if init_values > 0:
+ self.gamma_1 = nn.Parameter(
+ init_values * torch.ones((dim)), requires_grad=True
+ )
+ self.gamma_2 = nn.Parameter(
+ init_values * torch.ones((dim)), requires_grad=True
+ )
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ print("inside block :", x.size())
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ fc_feature = self.drop_path(self.mlp(self.norm2(x)))
+ x = x + fc_feature
+ else:
+ x = x + self.drop_path(
+ self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
+ )
+ fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ x = x + fc_feature
+ return x, fc_feature
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
+ super().__init__()
+
+ self.rel_pos_bias = None
+ if cfg.shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(
+ window_size=patch_shape, num_heads=cfg.num_heads
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
+ ] # stochastic depth decay rule
+
+ print("TransformerEncoder > patch_shape :", patch_shape)
+ self.blocks = nn.ModuleList(
+ Block(
+ dim=cfg.embed_dim,
+ num_heads=cfg.num_heads,
+ attn_drop=cfg.attention_dropout,
+ drop_path=dpr[i],
+ init_values=cfg.layer_scale_init_value,
+ window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
+ )
+ for i in range(cfg.depth)
+ )
+
+ self.norm = nn.LayerNorm(cfg.embed_dim)
+
+ self.apply(self.init_weights)
+ self.fix_init_weight()
+
+ def init_weights(self, m):
+ std = 0.02
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=std)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.trunc_normal_(m.weight, std=std)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp[2].weight.data, layer_id + 1)
+
+ def extract_features(self, x, layer_results):
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+
+ z = []
+ for i, blk in enumerate(self.blocks):
+ x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
+ if layer_results == "end":
+ z.append(x)
+ elif layer_results == "fc":
+ z.append(fc_feature)
+
+ return z if layer_results else self.norm(x)
+
+ def forward(self, x, layer_results=None):
+ x = self.extract_features(x, layer_results=layer_results)
+ if layer_results:
+ return [z[:, 1:] for z in x]
+
+ x = x[:, 1:]
+ return x
diff --git a/examples/data2vec/models/mae.py b/examples/data2vec/models/mae.py
new file mode 100644
index 0000000000..a3b5f72a4a
--- /dev/null
+++ b/examples/data2vec/models/mae.py
@@ -0,0 +1,829 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+from dataclasses import dataclass
+from functools import partial
+
+from timm.models.vision_transformer import PatchEmbed, Block
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
+
+try:
+ from apex.normalization import FusedLayerNorm
+except ImportError:
+ FusedLayerNorm = nn.LayerNorm
+
+import torch.nn.functional as F
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MaeConfig(FairseqDataclass):
+ input_size: int = 224
+ in_chans: int = 3
+ patch_size: int = 16
+ embed_dim: int = 768
+ depth: int = 12
+ num_heads: int = 12
+ decoder_embed_dim: int = 512
+ decoder_depth: int = 8
+ decoder_num_heads: int = 16
+ mlp_ratio: int = 4
+ norm_eps: float = 1e-6
+
+ drop_path_rate: float = 0.0
+
+ mask_ratio: float = 0.75
+ norm_pix_loss: bool = True
+
+ w2v_block: bool = False
+ alt_block: bool = False
+ alt_block2: bool = False
+ alt_attention: bool = False
+ block_dropout: float = 0
+ attention_dropout: float = 0
+ activation_dropout: float = 0
+ layer_norm_first: bool = False
+
+ fused_ln: bool = True
+ end_of_block_targets: bool = True
+
+ no_decoder_embed: bool = False
+ no_decoder_pos_embed: bool = False
+ mask_noise_std: float = 0
+
+ single_qkv: bool = False
+ use_rel_pos_bias: bool = False
+ no_cls: bool = False
+
+
+def modify_relative_position_bias(orig_bias, bsz, mask):
+ if mask is None:
+ return orig_bias.unsqueeze(0).repeat(
+ bsz, 1, 1, 1
+ ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
+ heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token
+ mask_for_rel_pos_bias = torch.cat(
+ (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
+ ).bool() # bsz x seqlen (add CLS token)
+ unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
+ unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
+ 1, heads, 1
+ ) # bsz x seq_len => bsz x heads x seq_len
+ b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
+ bsz, 1, 1, 1
+ ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
+ unmasked_for_rel_pos_bias.unsqueeze(-1)
+ )
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
+ new_len = b_t_t_rel_pos_bias.size(-2)
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
+ unmasked_for_rel_pos_bias.unsqueeze(-2)
+ )
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
+ return b_t_t_rel_pos_bias
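+    # Usage sketch (illustrative): given a shared (heads, T, T) bias and a
+    # (bsz, T-1) patch mask (a zero column for the CLS slot is prepended
+    # internally), this returns a (bsz, heads, T_keep, T_keep) bias restricted
+    # to the unmasked positions, so it can be added to attention scores that
+    # are computed only over the kept tokens.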
+
+
+class AltBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=True,
+ ffn_targets=False,
+ use_rel_pos_bias=False,
+ window_size=None,
+ alt_attention=False,
+ ):
+ super().__init__()
+
+ self.layer_norm_first = layer_norm_first
+ self.ffn_targets = ffn_targets
+
+ from timm.models.vision_transformer import Attention, DropPath, Mlp
+
+ self.norm1 = norm_layer(dim)
+ self.use_rel_pos_bias = use_rel_pos_bias
+ if use_rel_pos_bias:
+ self.attn = AltAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ )
+ else:
+ if alt_attention:
+ from .multi.modules import AltAttention as AltAttention2
+ self.attn = AltAttention2(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ else:
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x, rel_pos_bias=None, pos_mask=None):
+ if self.layer_norm_first:
+ if self.use_rel_pos_bias:
+ x = x + self.drop_path(
+ self.attn(
+ self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
+ )
+ )
+ else:
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ t = self.mlp(self.norm2(x))
+ x = x + self.drop_path(t)
+ if not self.ffn_targets:
+ t = x
+ return x, t
+ else:
+ if self.use_rel_pos_bias:
+ x = x + self.drop_path(
+ self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
+ )
+ else:
+ x = x + self.drop_path(self.attn(x))
+ r = x = self.norm1(x)
+ x = self.mlp(x)
+ t = x
+ x = self.norm2(r + self.drop_path(x))
+ if not self.ffn_targets:
+ t = x
+ return x, t
+
+
+class AltAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2,
+ dtype=relative_coords.dtype,
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None, pos_mask=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (
+ self.q_bias,
+ torch.zeros_like(self.v_bias, requires_grad=False),
+ self.v_bias,
+ )
+ )
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + modify_relative_position_bias(
+ relative_position_bias, x.size(0), pos_mask
+ )
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
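+    # Example (illustrative): get_2d_sincos_pos_embed(768, 14, cls_token=True)
+    # returns a (1 + 14 * 14, 768) numpy array whose first row is all zeros for
+    # the CLS slot and whose remaining rows are fixed sin/cos encodings of the
+    # (h, w) grid positions.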
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000 ** omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size, orig_size, new_size, new_size)
+ )
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size, new_size),
+ mode="bicubic",
+ align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+@register_model("mae", dataclass=MaeConfig)
+class MaeModel(BaseFairseqModel):
+ def __init__(self, cfg: MaeConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ self.mask_ratio = cfg.mask_ratio
+
+ # --------------------------------------------------------------------------
+ # MAE encoder specifics
+ self.patch_embed = PatchEmbed(
+ cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
+ ) # fixed sin-cos embedding
+
+ norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
+
+ dpr = [
+ x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
+ ] # stochastic depth decay rule
+
+ def make_block(drop_path):
+ if cfg.w2v_block:
+ return TransformerSentenceEncoderLayer(
+ embedding_dim=cfg.embed_dim,
+ ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
+ num_attention_heads=cfg.num_heads,
+ dropout=cfg.block_dropout,
+ attention_dropout=cfg.attention_dropout,
+ activation_dropout=cfg.activation_dropout,
+ activation_fn="gelu",
+ layer_norm_first=cfg.layer_norm_first,
+ drop_path=drop_path,
+ norm_eps=1e-6,
+ single_qkv=cfg.single_qkv,
+ fused_ln=cfg.fused_ln,
+ )
+ elif cfg.alt_block:
+ window_size = (
+ cfg.input_size // self.patch_embed.patch_size[0],
+ cfg.input_size // self.patch_embed.patch_size[1],
+ )
+ return AltBlock(
+ cfg.embed_dim,
+ cfg.num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ drop_path=drop_path,
+ layer_norm_first=cfg.layer_norm_first,
+ ffn_targets=not cfg.end_of_block_targets,
+ use_rel_pos_bias=cfg.use_rel_pos_bias,
+ window_size=window_size
+ if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
+ else None,
+ alt_attention=cfg.alt_attention,
+ )
+ elif cfg.alt_block2:
+ from .multi.modules import AltBlock as AltBlock2
+ return AltBlock2(
+ cfg.embed_dim,
+ cfg.num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ drop_path=drop_path,
+ layer_norm_first=cfg.layer_norm_first,
+ ffn_targets=not cfg.end_of_block_targets,
+ )
+ else:
+ return Block(
+ cfg.embed_dim,
+ cfg.num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ drop_path=drop_path,
+ )
+
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
+ self.norm = norm_layer(cfg.embed_dim)
+ # --------------------------------------------------------------------------
+
+ # --------------------------------------------------------------------------
+ # MAE decoder specifics
+ self.decoder_embed = (
+ nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
+ if not cfg.no_decoder_embed
+ else None
+ )
+
+ self.mask_token = (
+ nn.Parameter(
+ torch.zeros(
+ 1,
+ 1,
+ cfg.decoder_embed_dim
+ if not cfg.no_decoder_embed
+ else cfg.embed_dim,
+ )
+ )
+ if cfg.mask_noise_std <= 0
+ else None
+ )
+
+ self.decoder_pos_embed = (
+ nn.Parameter(
+ torch.zeros(
+ 1,
+ num_patches + 1,
+ cfg.decoder_embed_dim
+ if not cfg.no_decoder_embed
+ else cfg.embed_dim,
+ ),
+ requires_grad=False,
+ )
+ if not cfg.no_decoder_pos_embed
+ else None
+ )
+
+ self.decoder_blocks = nn.ModuleList(
+ [
+ Block(
+ cfg.decoder_embed_dim,
+ cfg.decoder_num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ )
+ for _ in range(cfg.decoder_depth)
+ ]
+ )
+
+ self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
+ self.decoder_pred = nn.Linear(
+ cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
+ ) # decoder to patch
+ # --------------------------------------------------------------------------
+
+ self.norm_pix_loss = cfg.norm_pix_loss
+
+ self.initialize_weights()
+
+ for pn, p in self.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.param_group = "no_decay"
+ else:
+ p.param_group = "with_decay"
+
+ def initialize_weights(self):
+ # initialization
+ # initialize (and freeze) pos_embed by sin-cos embedding
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ int(self.patch_embed.num_patches ** 0.5),
+ cls_token=not self.cfg.no_cls,
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ if self.decoder_pos_embed is not None:
+ decoder_pos_embed = get_2d_sincos_pos_embed(
+ self.decoder_pos_embed.shape[-1],
+ int(self.patch_embed.num_patches ** 0.5),
+ cls_token=not self.cfg.no_cls,
+ )
+ self.decoder_pos_embed.data.copy_(
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
+ )
+
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.patch_embed.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+ if self.cls_token is not None:
+ torch.nn.init.normal_(self.cls_token, std=0.02)
+
+ if self.mask_token is not None:
+ torch.nn.init.normal_(self.mask_token, std=0.02)
+
+ # initialize nn.Linear and nn.LayerNorm
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
+ p = self.patch_embed.patch_size[0]
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum("nchpwq->nhwpqc", x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
+ return x
+
+ def unpatchify(self, x):
+ """
+ x: (N, L, patch_size**2 *3)
+ imgs: (N, 3, H, W)
+ """
+ p = self.patch_embed.patch_size[0]
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+ return imgs
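+        # Shape sketch (illustrative): for 224x224 inputs with patch_size=16,
+        # patchify maps (N, 3, 224, 224) -> (N, 196, 768) and unpatchify
+        # inverts it back to (N, 3, 224, 224).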
+
+ def random_masking(self, x, mask_ratio):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore # x_masked holds only the kept (visible) patches
+
+ @classmethod
+ def build_model(cls, cfg: MaeConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def forward_encoder(self, x, mask_ratio):
+ # embed patches
+ x = self.patch_embed(x)
+
+ # add pos embed w/o cls token
+ # if self.cls_token is not None:
+ # x = x + self.pos_embed
+ # else:
+ x = x + self.pos_embed[:, 1:, :]
+
+ # masking: length -> length * mask_ratio
+ if mask_ratio > 0:
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
+ else:
+ mask = ids_restore = None
+
+ # append cls token
+ if self.cls_token is not None:
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # apply Transformer blocks
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ return x, mask, ids_restore
+
+ def forward_decoder(self, x, ids_restore):
+ # embed tokens
+ x = self.decoder_embed(x)
+
+ # append mask tokens to sequence
+ mask_tokens = self.mask_token.repeat(
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
+ )
+ if self.cls_token is not None:
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
+ else:
+ x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
+
+ x_ = torch.gather(
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
+ ) # unshuffle
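+ # the gather puts tokens back into their original patch order, so the decoder
+ # positional embedding added below lines up with the right positions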
+
+ if self.cls_token is not None:
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
+
+ # add pos embed
+ x = x + self.decoder_pos_embed
+
+ # apply Transformer blocks
+ for blk in self.decoder_blocks:
+ x = blk(x)
+ x = self.decoder_norm(x)
+
+ # predictor projection
+ x = self.decoder_pred(x)
+
+ if self.cls_token is not None:
+ # remove cls token
+ x = x[:, 1:, :]
+
+ return x
+
+ def forward_loss(self, imgs, pred, mask):
+ """
+ imgs: [N, 3, H, W]
+ pred: [N, L, p*p*3]
+ mask: [N, L], 0 is keep, 1 is remove,
+ """
+ target = self.patchify(imgs)
+ if self.norm_pix_loss:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+ loss = (pred - target) ** 2
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
+
+ loss = (loss * mask).sum()
+ return loss, mask.sum()
+
+ def forward(self, imgs, predictions_only=False):
+ latent, mask, ids_restore = self.forward_encoder(
+ imgs, self.mask_ratio if not predictions_only else 0
+ )
+
+ if predictions_only:
+ return latent
+
+ pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
+ loss, sample_size = self.forward_loss(imgs, pred, mask)
+
+ result = {
+ "losses": {"regression": loss},
+ "sample_size": sample_size,
+ }
+ return result
+
+ def remove_pretraining_modules(self):
+ self.decoder_embed = None
+ self.decoder_blocks = None
+ self.decoder_norm = None
+ self.decoder_pos_embed = None
+ self.decoder_pred = None
+ self.mask_token = None
+ if self.cfg.layer_norm_first:
+ self.norm = None
diff --git a/examples/data2vec/models/mae_image_classification.py b/examples/data2vec/models/mae_image_classification.py
new file mode 100644
index 0000000000..e304618dc5
--- /dev/null
+++ b/examples/data2vec/models/mae_image_classification.py
@@ -0,0 +1,386 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Any, Optional
+
+import numpy as np
+from omegaconf import II, MISSING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import checkpoint_utils, tasks
+from omegaconf import open_dict
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from .mae import interpolate_pos_embed
+
+
+logger = logging.getLogger(__name__)
+
+
+class PredictionMode(Enum):
+ MEAN_POOLING = auto()
+ CLS_TOKEN = auto()
+ LIN_SOFTMAX = auto()
+
+
+@dataclass
+class MaeImageClassificationConfig(FairseqDataclass):
+ model_path: str = MISSING
+ no_pretrained_weights: bool = False
+ linear_classifier: bool = False
+ num_classes: int = 1000
+ mixup: float = 0.8
+ cutmix: float = 1.0
+ label_smoothing: float = 0.1
+
+ drop_path_rate: float = 0.1
+ layer_decay: float = 0.65
+
+ mixup_prob: float = 1.0
+ mixup_switch_prob: float = 0.5
+ mixup_mode: str = "batch"
+
+ pretrained_model_args: Any = None
+ data: str = II("task.data")
+
+ norm_eps: Optional[float] = None
+
+ remove_alibi: bool = False
+
+ # regularization overwrites
+ encoder_dropout: float = 0
+ post_mlp_drop: float = 0
+ attention_dropout: float = 0
+ activation_dropout: float = 0.0
+ dropout_input: float = 0.0
+ layerdrop: float = 0.0
+
+ prenet_layerdrop: float = 0
+ prenet_dropout: float = 0
+
+ use_fc_norm: bool = True
+ prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING
+
+ no_decay_blocks: bool = True
+
+
+def get_layer_id_for_vit(name, num_layers):
+ """
+ Assign each parameter to its layer id
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if name in ["cls_token", "pos_embed"]:
+ return 0
+ elif name.startswith("patch_embed"):
+ return 0
+ elif name.startswith("rel_pos_bias"):
+ return num_layers - 1
+ elif name.startswith("blocks"):
+ return int(name.split(".")[1]) + 1
+ else:
+ return num_layers
+
+
+@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
+class MaeImageClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: MaeImageClassificationConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pretrained_model_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
+ pretrained_args = state.get("cfg", None)
+
+ pretrained_args.criterion = None
+ pretrained_args.lr_scheduler = None
+
+ logger.info(pretrained_args.model)
+
+ with open_dict(pretrained_args.model):
+ pretrained_args.model.drop_path_rate = cfg.drop_path_rate
+ if cfg.norm_eps is not None:
+ pretrained_args.model.norm_eps = cfg.norm_eps
+
+ cfg.pretrained_model_args = pretrained_args
+
+ logger.info(pretrained_args)
+ else:
+ state = None
+ pretrained_args = cfg.pretrained_model_args
+
+ if "data" in pretrained_args.task:
+ pretrained_args.task.data = cfg.data
+ elif "image" in pretrained_args.task:
+ pretrained_args.task.image.data = cfg.data
+
+ if "modalities" in pretrained_args.model:
+ prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
+ model_blocks = pretrained_args.model["depth"]
+ with open_dict(pretrained_args):
+ dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
+ pretrained_args.model["modalities"]["image"][
+ "start_drop_path_rate"
+ ] = dpr[0]
+ pretrained_args.model["modalities"]["image"][
+ "end_drop_path_rate"
+ ] = max(0, dpr[prenet_blocks - 1])
+ pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
+ pretrained_args.model["end_drop_path_rate"] = dpr[-1]
+
+ if "mae_masking" in pretrained_args.model["modalities"]["image"]:
+ del pretrained_args.model["modalities"]["image"]["mae_masking"]
+
+ if cfg.remove_alibi:
+ pretrained_args.model["modalities"]["image"][
+ "use_alibi_encoder"
+ ] = False
+ if (
+ state is not None
+ and "modality_encoders.IMAGE.alibi_bias" in state["model"]
+ ):
+ del state["model"]["modality_encoders.IMAGE.alibi_bias"]
+
+ pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
+ pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
+ pretrained_args.model["attention_dropout"] = cfg.attention_dropout
+ pretrained_args.model["activation_dropout"] = cfg.activation_dropout
+ pretrained_args.model["dropout_input"] = cfg.dropout_input
+ pretrained_args.model["layerdrop"] = cfg.layerdrop
+
+ pretrained_args.model["modalities"]["image"][
+ "prenet_layerdrop"
+ ] = cfg.prenet_layerdrop
+ pretrained_args.model["modalities"]["image"][
+ "prenet_dropout"
+ ] = cfg.prenet_dropout
+ else:
+ # not d2v multi
+ with open_dict(pretrained_args):
+ pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
+ pretrained_args.model["block_dropout"] = cfg.encoder_dropout
+ pretrained_args.model["attention_dropout"] = cfg.attention_dropout
+ pretrained_args.model["activation_dropout"] = cfg.activation_dropout
+
+ task = tasks.setup_task(pretrained_args.task)
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
+
+ self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
+ self.linear_classifier = cfg.linear_classifier
+
+ self.model = model
+
+ if state is not None and not cfg.no_pretrained_weights:
+ interpolate_pos_embed(model, state)
+
+ if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
+ state["model"][
+ "modality_encoders.IMAGE.positional_encoder.positions"
+ ] = state["model"][
+ "modality_encoders.IMAGE.positional_encoder.pos_embed"
+ ]
+ del state["model"][
+ "modality_encoders.IMAGE.positional_encoder.pos_embed"
+ ]
+ if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
+ del state["model"]["modality_encoders.IMAGE.encoder_mask"]
+
+ model.load_state_dict(state["model"], strict=True)
+
+ if self.d2v_multi:
+ model.remove_pretraining_modules(modality="image")
+ else:
+ model.remove_pretraining_modules()
+
+ if self.linear_classifier:
+ model.requires_grad_(False)
+
+ self.fc_norm = None
+ if self.cfg.use_fc_norm:
+ self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
+ nn.init.constant_(self.fc_norm.bias, 0)
+ nn.init.constant_(self.fc_norm.weight, 1.0)
+
+ self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
+
+ nn.init.trunc_normal_(self.head.weight, std=0.02)
+ nn.init.constant_(self.head.bias, 0)
+
+ self.mixup_fn = None
+
+ if cfg.mixup > 0 or cfg.cutmix > 0:
+ from timm.data import Mixup
+
+ self.mixup_fn = Mixup(
+ mixup_alpha=cfg.mixup,
+ cutmix_alpha=cfg.cutmix,
+ cutmix_minmax=None,
+ prob=cfg.mixup_prob,
+ switch_prob=cfg.mixup_switch_prob,
+ mode=cfg.mixup_mode,
+ label_smoothing=cfg.label_smoothing,
+ num_classes=cfg.num_classes,
+ )
+
+ if self.model.norm is not None:
+ for pn, p in self.model.norm.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+
+ if self.fc_norm is not None:
+ for pn, p in self.fc_norm.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+
+ for pn, p in self.head.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+
+ if self.d2v_multi:
+ mod_encs = list(model.modality_encoders.values())
+ assert len(mod_encs) == 1, len(mod_encs)
+ blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
+ else:
+ blocks = model.blocks
+
+ num_layers = len(blocks) + 1
+ layer_scales = list(
+ cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
+ )
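+ # layer-wise lr decay (as in BEiT fine-tuning): parameters at layer id i get an
+ # lr scale of layer_decay ** (num_layers - i), so the scale grows geometrically
+ # with depth and reaches 1.0 for the topmost (id == num_layers) parameters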
+
+ if self.d2v_multi:
+ for n, p in self.model.named_parameters():
+ optimizer_override_dict = {}
+
+ if len(p.shape) == 1 or n.endswith(".bias"):
+ optimizer_override_dict["weight_decay_scale"] = 0
+
+ p.optim_overrides = {"optimizer": optimizer_override_dict}
+
+ if cfg.layer_decay > 0:
+ for i, b in enumerate(blocks):
+ lid = i + 1
+ if layer_scales[lid] == 1.0:
+ continue
+
+ for n, p in b.named_parameters():
+ optim_override = getattr(p, "optim_overrides", {})
+ if "optimizer" not in optim_override:
+ optim_override["optimizer"] = {}
+
+ if cfg.no_decay_blocks:
+ optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
+ p.optim_overrides = optim_override
+ else:
+ optim_override["optimizer"] = {
+ "lr_scale": layer_scales[lid]
+ }
+ p.optim_overrides = optim_override
+
+ else:
+ for n, p in self.model.named_parameters():
+ optimizer_override_dict = {}
+ layer_id = get_layer_id_for_vit(n, num_layers)
+
+ if len(p.shape) == 1 or n.endswith(".bias"):
+ optimizer_override_dict["weight_decay_scale"] = 0
+
+ if cfg.layer_decay > 0:
+ optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
+ p.optim_overrides = {"optimizer": optimizer_override_dict}
+
+ @classmethod
+ def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def forward(
+ self,
+ imgs,
+ labels=None,
+ ):
+ if self.training and self.mixup_fn is not None and labels is not None:
+ imgs, labels = self.mixup_fn(imgs, labels)
+
+ if self.linear_classifier:
+ with torch.no_grad():
+ x = self.model_forward(imgs)
+ else:
+ x = self.model_forward(imgs)
+
+ if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
+ x = x.mean(dim=1)
+ elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
+ x = x[:, 0]
+ elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
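+ # attention-free pooling in log-sigmoid space: per feature dimension this
+ # computes (up to the 1e-6 stabilizer) logit(sum(sigmoid(x)^2) / sum(sigmoid(x)))
+ # over the token dimension, i.e. a sigmoid-weighted average mapped back to a
+ # logit, with nan/inf values zeroed afterwards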
+ dtype = x.dtype
+ x = F.logsigmoid(x.float())
+ x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
+ x = x.clamp(max=0)
+ x = x - torch.log(-(torch.expm1(x)))
+ x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
+ x = x.to(dtype=dtype)
+ else:
+ raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")
+
+ if self.fc_norm is not None:
+ x = self.fc_norm(x)
+
+ x = self.head(x)
+
+ if labels is None:
+ return x
+
+ if self.training and self.mixup_fn is not None:
+ loss = -labels * F.log_softmax(x.float(), dim=-1)
+ else:
+ loss = F.cross_entropy(
+ x.float(),
+ labels,
+ label_smoothing=self.cfg.label_smoothing if self.training else 0,
+ reduction="none",
+ )
+
+ result = {
+ "losses": {"regression": loss},
+ "sample_size": imgs.size(0),
+ }
+
+ if not self.training:
+ with torch.no_grad():
+ pred = x.argmax(-1)
+ correct = (pred == labels).sum()
+ result["correct"] = correct
+
+ return result
+
+ def model_forward(self, imgs):
+ if self.d2v_multi:
+ x = self.model.extract_features(
+ imgs,
+ mode="IMAGE",
+ mask=False,
+ remove_extra_tokens=(
+ self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
+ ),
+ )["x"]
+ else:
+ x = self.model(imgs, predictions_only=True)
+ if (
+ "no_cls" not in self.model.cfg or not self.model.cfg.no_cls
+ ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
+ x = x[:, 1:]
+ return x
diff --git a/examples/data2vec/models/modalities/__init__.py b/examples/data2vec/models/modalities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/data2vec/models/modalities/audio.py b/examples/data2vec/models/modalities/audio.py
new file mode 100644
index 0000000000..80d2857b24
--- /dev/null
+++ b/examples/data2vec/models/modalities/audio.py
@@ -0,0 +1,192 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+import torch
+import torch.nn as nn
+import numpy as np
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Optional
+from fairseq.models.wav2vec import ConvFeatureExtractionModel
+from fairseq.modules import (
+ LayerNorm,
+ SamePad,
+ TransposeLast,
+)
+from fairseq.tasks import FairseqTask
+from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
+from .modules import BlockEncoder, Decoder1d
+from examples.data2vec.data.modality import Modality
+
+
+@dataclass
+class D2vAudioConfig(D2vModalityConfig):
+ type: Modality = Modality.AUDIO
+ extractor_mode: str = "layer_norm"
+ feature_encoder_spec: str = field(
+ default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
+ metadata={
+ "help": "string describing convolutional feature extraction layers in form of a python list that contains "
+ "[(dim, kernel_size, stride), ...]"
+ },
+ )
+ conv_pos_width: int = field(
+ default=95,
+ metadata={"help": "number of filters for convolutional positional embeddings"},
+ )
+ conv_pos_groups: int = field(
+ default=16,
+ metadata={"help": "number of groups for convolutional positional embedding"},
+ )
+ conv_pos_depth: int = field(
+ default=5,
+ metadata={"help": "depth of positional encoder network"},
+ )
+ conv_pos_pre_ln: bool = False
+
+
+class AudioEncoder(ModalitySpecificEncoder):
+
+ modality_cfg: D2vAudioConfig
+
+ def __init__(
+ self,
+ modality_cfg: D2vAudioConfig,
+ embed_dim: int,
+ make_block: Callable[[float], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases: Dict,
+ task: Optional[FairseqTask],
+ ):
+
+ self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
+ feature_embed_dim = self.feature_enc_layers[-1][0]
+
+ local_encoder = ConvFeatureExtractionModel(
+ conv_layers=self.feature_enc_layers,
+ dropout=0.0,
+ mode=modality_cfg.extractor_mode,
+ conv_bias=False,
+ )
+
+ project_features = nn.Sequential(
+ TransposeLast(),
+ nn.LayerNorm(feature_embed_dim),
+ nn.Linear(feature_embed_dim, embed_dim),
+ )
+
+ num_pos_layers = modality_cfg.conv_pos_depth
+ k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
+
+ positional_encoder = nn.Sequential(
+ TransposeLast(),
+ *[
+ nn.Sequential(
+ nn.Conv1d(
+ embed_dim,
+ embed_dim,
+ kernel_size=k,
+ padding=k // 2,
+ groups=modality_cfg.conv_pos_groups,
+ ),
+ SamePad(k),
+ TransposeLast(),
+ LayerNorm(embed_dim, elementwise_affine=False),
+ TransposeLast(),
+ nn.GELU(),
+ )
+ for _ in range(num_pos_layers)
+ ],
+ TransposeLast(),
+ )
+
+ if modality_cfg.conv_pos_pre_ln:
+ positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
+
+ dpr = np.linspace(
+ modality_cfg.start_drop_path_rate,
+ modality_cfg.end_drop_path_rate,
+ modality_cfg.prenet_depth,
+ )
+ context_encoder = BlockEncoder(
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
+ norm_layer(embed_dim) if not layer_norm_first else None,
+ layer_norm_first,
+ modality_cfg.prenet_layerdrop,
+ modality_cfg.prenet_dropout,
+ )
+
+ decoder = (
+ Decoder1d(modality_cfg.decoder, embed_dim)
+ if modality_cfg.decoder is not None
+ else None
+ )
+
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
+
+ super().__init__(
+ modality_cfg=modality_cfg,
+ embed_dim=embed_dim,
+ local_encoder=local_encoder,
+ project_features=project_features,
+ fixed_positional_encoder=None,
+ relative_positional_encoder=positional_encoder,
+ context_encoder=context_encoder,
+ decoder=decoder,
+ get_alibi_bias=alibi_bias_fn,
+ )
+
+ def convert_padding_mask(self, x, padding_mask):
+ def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ return torch.floor((input_length - kernel_size) / stride + 1)
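+ # standard conv output-length formula (no padding, dilation 1):
+ # floor((L_in - kernel_size) / stride) + 1, applied once per extractor layer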
+
+ for i in range(len(self.feature_enc_layers)):
+ input_lengths = _conv_out_length(
+ input_lengths,
+ self.feature_enc_layers[i][1],
+ self.feature_enc_layers[i][2],
+ )
+
+ return input_lengths.to(torch.long)
+
+ if padding_mask is not None:
+ input_lengths = (1 - padding_mask.long()).sum(-1)
+ # apply conv formula to get real output_lengths
+ output_lengths = get_feat_extract_output_lengths(input_lengths)
+
+ if padding_mask.any():
+ padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
+
+ # these two operations make sure that all values
+ # before the output-length indices are attended to
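+ # (flip -> cumsum -> flip turns the single 1 written at the last valid index
+ # into 1s over all valid positions; inverting then marks the padded tail)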
+ padding_mask[
+ (
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
+ output_lengths - 1,
+ )
+ ] = 1
+ padding_mask = (
+ 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
+ ).bool()
+ else:
+ padding_mask = torch.zeros(
+ x.shape[:2], dtype=torch.bool, device=x.device
+ )
+
+ return padding_mask
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ for mod in self.project_features.children():
+ if isinstance(mod, nn.Linear):
+ mod.reset_parameters()
+ if self.decoder is not None:
+ self.decoder.reset_parameters()
diff --git a/examples/data2vec/models/modalities/base.py b/examples/data2vec/models/modalities/base.py
new file mode 100644
index 0000000000..642cc84661
--- /dev/null
+++ b/examples/data2vec/models/modalities/base.py
@@ -0,0 +1,684 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import partial
+from omegaconf import MISSING, II
+from typing import Optional, Callable
+from fairseq.data.data_utils import compute_mask_indices
+from fairseq.modules import GradMultiply
+from fairseq.utils import index_put
+from examples.data2vec.data.modality import Modality
+from .modules import D2vDecoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class D2vModalityConfig:
+ type: Modality = MISSING
+ prenet_depth: int = 4
+ prenet_layerdrop: float = 0
+ prenet_dropout: float = 0
+ start_drop_path_rate: float = 0
+ end_drop_path_rate: float = 0
+
+ num_extra_tokens: int = 0
+ init_extra_token_zero: bool = True
+
+ mask_noise_std: float = 0.01
+ mask_prob_min: Optional[float] = None
+ mask_prob: float = 0.7
+ inverse_mask: bool = False
+ mask_prob_adjust: float = 0
+ keep_masked_pct: float = 0
+
+ mask_length: int = 5
+ add_masks: bool = False
+ remove_masks: bool = False
+ mask_dropout: float = 0.0
+ encoder_zero_mask: bool = True
+
+ mask_channel_prob: float = 0.0
+ mask_channel_length: int = 64
+
+ ema_local_encoder: bool = False # used in data2vec_multi
+ local_grad_mult: float = 1.0
+
+ use_alibi_encoder: bool = False
+ alibi_scale: float = 1.0
+ learned_alibi: bool = False
+ alibi_max_pos: Optional[int] = None
+ learned_alibi_scale: bool = False
+ learned_alibi_scale_per_head: bool = False
+ learned_alibi_scale_per_layer: bool = False
+
+ num_alibi_heads: int = II("model.num_heads")
+ model_depth: int = II("model.depth")
+
+ decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()
+
+
+MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
+MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
+
+
+class ModalitySpecificEncoder(nn.Module):
+ def __init__(
+ self,
+ modality_cfg: D2vModalityConfig,
+ embed_dim: int,
+ local_encoder: nn.Module,
+ project_features: nn.Module,
+ fixed_positional_encoder: Optional[nn.Module],
+ relative_positional_encoder: Optional[nn.Module],
+ context_encoder: nn.Module,
+ decoder: nn.Module,
+ get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
+ ):
+ super().__init__()
+
+ self.modality_cfg = modality_cfg
+ self.local_encoder = local_encoder
+ self.project_features = project_features
+ self.fixed_positional_encoder = fixed_positional_encoder
+ self.relative_positional_encoder = relative_positional_encoder
+ self.context_encoder = context_encoder
+
+ self.decoder = decoder
+ self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
+
+ self.local_grad_mult = self.modality_cfg.local_grad_mult
+
+ self.extra_tokens = None
+ if modality_cfg.num_extra_tokens > 0:
+ self.extra_tokens = nn.Parameter(
+ torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
+ )
+ if not modality_cfg.init_extra_token_zero:
+ nn.init.normal_(self.extra_tokens)
+ elif self.extra_tokens.size(1) > 1:
+ nn.init.normal_(self.extra_tokens[:, 1:])
+
+ self.alibi_scale = None
+ if self.get_alibi_bias is not None:
+ self.alibi_scale = nn.Parameter(
+ torch.full(
+ (
+ (modality_cfg.prenet_depth + modality_cfg.model_depth)
+ if modality_cfg.learned_alibi_scale_per_layer
+ else 1,
+ 1,
+ self.modality_cfg.num_alibi_heads
+ if modality_cfg.learned_alibi_scale_per_head
+ else 1,
+ 1,
+ 1,
+ ),
+ modality_cfg.alibi_scale,
+ dtype=torch.float,
+ ),
+ requires_grad=modality_cfg.learned_alibi_scale,
+ )
+
+ if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
+ assert modality_cfg.alibi_max_pos is not None
+ alibi_bias = self.get_alibi_bias(
+ batch_size=1,
+ time_steps=modality_cfg.alibi_max_pos,
+ heads=modality_cfg.num_alibi_heads,
+ scale=1.0,
+ dtype=torch.float,
+ device="cpu",
+ )
+ self.alibi_bias = nn.Parameter(alibi_bias)
+ self.get_alibi_bias = partial(
+ _learned_alibi_bias, alibi_bias=self.alibi_bias
+ )
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ k = f"{name}.alibi_scale"
+ if k in state_dict and state_dict[k].dim() == 4:
+ state_dict[k] = state_dict[k].unsqueeze(0)
+
+ return state_dict
+
+ def convert_padding_mask(self, x, padding_mask):
+ return padding_mask
+
+ def decoder_input(self, x, mask_info: MaskInfo):
+ inp_drop = self.modality_cfg.decoder.input_dropout
+ if inp_drop > 0:
+ x = F.dropout(x, inp_drop, training=self.training, inplace=True)
+
+ num_extra = self.modality_cfg.num_extra_tokens
+
+ if mask_info is not None:
+ num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra
+
+ mask_tokens = x.new_empty(
+ x.size(0),
+ num_masked,
+ x.size(-1),
+ ).normal_(0, self.modality_cfg.mask_noise_std)
+
+ x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
+ x = torch.gather(x_, dim=1, index=mask_info.ids_restore)
+
+ if self.modality_cfg.decoder.add_positions_masked:
+ assert self.fixed_positional_encoder is not None
+ pos = self.fixed_positional_encoder(x, None)
+ x = x + (pos * mask_info.mask.unsqueeze(-1))
+ else:
+ x = x[:, num_extra:]
+
+ if self.modality_cfg.decoder.add_positions_all:
+ assert self.fixed_positional_encoder is not None
+ x = x + self.fixed_positional_encoder(x, None)
+
+ return x, mask_info
+
+ def local_features(self, features):
+ if self.local_grad_mult > 0:
+ if self.local_grad_mult == 1.0:
+ x = self.local_encoder(features)
+ else:
+ x = GradMultiply.apply(
+ self.local_encoder(features), self.local_grad_mult
+ )
+ else:
+ with torch.no_grad():
+ x = self.local_encoder(features)
+
+ x = self.project_features(x)
+ return x
+
+ def contextualized_features(
+ self,
+ x,
+ padding_mask,
+ mask,
+ remove_masked,
+ clone_batch: int = 1,
+ mask_seeds: Optional[torch.Tensor] = None,
+ precomputed_mask=None,
+ ):
+
+ if padding_mask is not None:
+ padding_mask = self.convert_padding_mask(x, padding_mask)
+
+ local_features = x
+ if mask and clone_batch == 1:
+ local_features = local_features.clone()
+
+ orig_B, orig_T, _ = x.shape
+ pre_mask_B = orig_B
+ mask_info = None
+
+ x_pos = None
+ if self.fixed_positional_encoder is not None:
+ x = x + self.fixed_positional_encoder(x, padding_mask)
+
+ if mask:
+ if clone_batch > 1:
+ x = x.repeat_interleave(clone_batch, 0)
+ if mask_seeds is not None:
+ clone_hash = [
+ int(hash((mask_seeds.seed, ind)) % 1e10)
+ for ind in range(clone_batch - 1)
+ ]
+ clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
+
+ id = mask_seeds.ids
+ id = id.repeat_interleave(clone_batch, 0)
+ id = id.view(-1, clone_batch) + clone_hash.to(id)
+ id = id.view(-1)
+ mask_seeds = MaskSeed(
+ seed=mask_seeds.seed, update=mask_seeds.update, ids=id
+ )
+ if padding_mask is not None:
+ padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
+
+ x, mask_info = self.compute_mask(
+ x,
+ padding_mask,
+ mask_seed=mask_seeds,
+ apply=self.relative_positional_encoder is not None or not remove_masked,
+ precomputed_mask=precomputed_mask,
+ )
+
+ if self.relative_positional_encoder is not None:
+ x_pos = self.relative_positional_encoder(x)
+
+ masked_padding_mask = padding_mask
+ if mask and remove_masked:
+ x = mask_info.x_unmasked
+ if x_pos is not None:
+ x = x + gather_unmasked(x_pos, mask_info)
+
+ if padding_mask is not None and padding_mask.any():
+ masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
+ if not masked_padding_mask.any():
+ masked_padding_mask = None
+ else:
+ masked_padding_mask = None
+
+ elif x_pos is not None:
+ x = x + x_pos
+
+ alibi_bias = None
+ alibi_scale = self.alibi_scale
+
+ if self.get_alibi_bias is not None:
+ alibi_bias = self.get_alibi_bias(
+ batch_size=pre_mask_B,
+ time_steps=orig_T,
+ heads=self.modality_cfg.num_alibi_heads,
+ dtype=torch.float32,
+ device=x.device,
+ )
+
+ if alibi_scale is not None:
+ alibi_scale = alibi_scale.clamp_min(0)
+ if alibi_scale.size(0) == 1:
+ alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
+ alibi_scale = None
+
+ if clone_batch > 1:
+ alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
+
+ if mask_info is not None and remove_masked:
+ alibi_bias = masked_alibi(alibi_bias, mask_info)
+
+ if self.extra_tokens is not None:
+ num = self.extra_tokens.size(1)
+ x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
+ if masked_padding_mask is not None:
+ # B x T
+ masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
+ if alibi_bias is not None:
+ # B x H x T x T
+ alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
+
+ x = self.context_encoder(
+ x,
+ masked_padding_mask,
+ alibi_bias,
+ alibi_scale[: self.modality_cfg.prenet_depth]
+ if alibi_scale is not None
+ else None,
+ )
+
+ return {
+ "x": x,
+ "local_features": local_features,
+ "padding_mask": masked_padding_mask,
+ "alibi_bias": alibi_bias,
+ "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
+ if alibi_scale is not None and alibi_scale.size(0) > 1
+ else alibi_scale,
+ "encoder_mask": mask_info,
+ }
+
+ def forward(
+ self,
+ features,
+ padding_mask,
+ mask: bool,
+ remove_masked: bool,
+ clone_batch: int = 1,
+ mask_seeds: Optional[torch.Tensor] = None,
+ precomputed_mask=None,
+ ):
+ x = self.local_features(features)
+ return self.contextualized_features(
+ x,
+ padding_mask,
+ mask,
+ remove_masked,
+ clone_batch,
+ mask_seeds,
+ precomputed_mask,
+ )
+
+ def reset_parameters(self):
+ pass
+
+ def compute_mask(
+ self,
+ x,
+ padding_mask,
+ mask_seed: Optional[MaskSeed],
+ apply,
+ precomputed_mask,
+ ):
+ if precomputed_mask is not None:
+ mask = precomputed_mask
+ mask_info = self.make_maskinfo(x, mask)
+ else:
+ B, T, C = x.shape
+ cfg = self.modality_cfg
+
+ mask_prob = cfg.mask_prob
+
+ if (
+ cfg.mask_prob_min is not None
+ and cfg.mask_prob_min >= 0
+ and cfg.mask_prob_min < mask_prob
+ ):
+ mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
+
+ if mask_prob > 0:
+ if cfg.mask_length == 1:
+ mask_info = random_masking(x, mask_prob, mask_seed)
+ else:
+ if self.modality_cfg.inverse_mask:
+ mask_prob = 1 - mask_prob
+
+ mask = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ mask_prob,
+ cfg.mask_length,
+ min_masks=1,
+ require_same_masks=True,
+ mask_dropout=cfg.mask_dropout,
+ add_masks=cfg.add_masks,
+ seed=mask_seed.seed if mask_seed is not None else None,
+ epoch=mask_seed.update if mask_seed is not None else None,
+ indices=mask_seed.ids if mask_seed is not None else None,
+ )
+
+ mask = torch.from_numpy(mask).to(device=x.device)
+ if self.modality_cfg.inverse_mask:
+ mask = 1 - mask
+ mask_info = self.make_maskinfo(x, mask)
+ else:
+ mask_info = None
+
+ if apply:
+ x = self.apply_mask(x, mask_info)
+
+ return x, mask_info
+
+ def make_maskinfo(self, x, mask, shape=None):
+ if shape is None:
+ B, T, D = x.shape
+ else:
+ B, T, D = shape
+
+ mask = mask.to(torch.uint8)
+ ids_shuffle = mask.argsort(dim=1)
+ ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
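+ # sorting the 0/1 mask moves kept positions (0) to the front; argsorting that
+ # permutation again gives the indices needed to undo the shuffle later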
+
+ len_keep = T - mask[0].sum()
+ if self.modality_cfg.keep_masked_pct > 0:
+ len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
+
+ ids_keep = ids_shuffle[:, :len_keep]
+
+ if shape is not None:
+ x_unmasked = None
+ else:
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
+
+ mask_info = MaskInfo(
+ x_unmasked=x_unmasked,
+ mask=mask,
+ ids_restore=ids_restore,
+ ids_keep=ids_keep,
+ )
+ return mask_info
+
+ def apply_mask(self, x, mask_info):
+ cfg = self.modality_cfg
+ B, T, C = x.shape
+
+ if mask_info is not None:
+ mask = mask_info.mask
+ if cfg.encoder_zero_mask:
+ x = x * (1 - mask.type_as(x).unsqueeze(-1))
+ else:
+ num_masks = mask.sum().item()
+ masks = x.new_empty(num_masks, x.size(-1)).normal_(
+ 0, cfg.mask_noise_std
+ )
+ x = index_put(x, mask, masks)
+ if cfg.mask_channel_prob > 0:
+ mask_channel = compute_mask_indices(
+ (B, C),
+ None,
+ cfg.mask_channel_prob,
+ cfg.mask_channel_length,
+ )
+ mask_channel = (
+ torch.from_numpy(mask_channel)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x = index_put(x, mask_channel, 0)
+ return x
+
+ def remove_pretraining_modules(self, keep_decoder=False):
+ if not keep_decoder:
+ self.decoder = None
+
+
+def get_annealed_rate(start, end, curr_step, total_steps):
+ if curr_step >= total_steps:
+ return end
+ r = end - start
+ pct_remaining = 1 - curr_step / total_steps
+ return end - r * pct_remaining
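+ # equivalent to linear interpolation: start + (end - start) * curr_step / total_steps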
+
+
+# adapted from MAE
+def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ generator = None
+ if mask_seed is not None:
+ seed = int(
+ hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
+ )
+ generator = torch.Generator(device=x.device)
+ generator.manual_seed(seed)
+
+ noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
+ ids_restore = ids_shuffle.argsort(dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
+
+ return MaskInfo(
+ x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
+ )
+
+
+def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
+ return torch.gather(
+ x,
+ dim=1,
+ index=mask_info.ids_keep,
+ )
+
+
+def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
+ return torch.gather(
+ x,
+ dim=1,
+ index=mask_info.ids_keep[..., 0], # ignore the feature dimension
+ )
+
+
+def get_alibi(
+ max_positions: int,
+ attention_heads: int,
+ dims: int = 1,
+ distance: str = "manhattan",
+):
+ def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio**i for i in range(n)]
+
+ # In the paper, we only train models that have 2^a heads for some
+ # a. This function has some good properties that only occur when
+ # the input is a power of 2. To maintain that even when the number
+ # of heads is not a power of 2, we use this workaround.
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ return (
+ get_slopes_power_of_2(closest_power_of_2)
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+ )
+
+ maxpos = max_positions
+ attn_heads = attention_heads
+ slopes = torch.Tensor(get_slopes(attn_heads))
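+ # e.g. for 8 heads this yields the geometric sequence 1/2, 1/4, ..., 1/256,
+ # matching the slopes used in the ALiBi paper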
+
+ if dims == 1:
+ # prepare the alibi linear position bias. Note that wav2vec2 is a
+ # non-autoregressive model, so we want a symmetric bias with 0 on the
+ # diagonal and otherwise linearly decreasing values
+ pos_bias = (
+ torch.abs(
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
+ )
+ * -1
+ )
+ elif dims == 2:
+ if distance == "manhattan":
+ df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
+ elif distance == "euclidean":
+ df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
+
+ n = math.sqrt(max_positions)
+ assert n.is_integer(), n
+ n = int(n)
+
+ pos_bias = torch.zeros((max_positions, max_positions))
+
+ for i in range(n):
+ for j in range(n):
+ for k in range(n):
+ for l in range(n):
+ new_x = i * n + j
+ new_y = k * n + l
+ pos_bias[new_x, new_y] = -df(i, j, k, l)
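+ # for 2d inputs the bias between two positions is the negative manhattan (or
+ # euclidean) distance between their patch-grid coordinates, giving each head a
+ # locality prior over the image plane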
+
+ else:
+ raise Exception(f"unsupported number of alibi dims: {dims}")
+
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
+ attn_heads, -1, -1
+ )
+
+ return alibi_bias
+
+
+def get_alibi_bias(
+ alibi_biases,
+ batch_size,
+ time_steps,
+ heads,
+ dtype,
+ device,
+ dims=1,
+ distance="manhattan",
+):
+ cache_key = f"{dims}_{heads}_{distance}"
+
+ buffered = alibi_biases.get(cache_key, None)
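+ # the cached bias only ever grows: it is rebuilt when a larger batch/time size
+ # (or a different dtype/device) is requested, and sliced down to size below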
+
+ target_size = heads * batch_size
+ if (
+ buffered is None
+ or buffered.size(0) < target_size
+ or buffered.size(1) < time_steps
+ or buffered.dtype != dtype
+ or buffered.device != device
+ ):
+ bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
+ bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
+
+ buffered = (
+ get_alibi(bt, heads, dims=dims, distance=distance)
+ .to(dtype=dtype, device=device)
+ .repeat(bn, 1, 1)
+ )
+
+ alibi_biases[cache_key] = buffered
+
+ b = buffered[:target_size, :time_steps, :time_steps]
+ b = b.view(batch_size, heads, time_steps, time_steps)
+ return b
+
+
+def _learned_alibi_bias(
+ alibi_bias,
+ batch_size,
+ time_steps,
+ heads,
+ scale,
+ dtype,
+ device,
+):
+ assert alibi_bias.size(1) == heads, alibi_bias.shape
+ assert alibi_bias.dtype == dtype, alibi_bias.dtype
+ assert alibi_bias.device == device, alibi_bias.device
+
+ if alibi_bias.size(-1) < time_steps:
+ psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
+ alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
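+ # replicate-padding extends the bias learned for alibi_max_pos positions so it
+ # can be applied to longer sequences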
+
+ alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
+ return alibi_bias[..., :time_steps, :time_steps]
+
+
+def masked_alibi(alibi_bias, mask_info):
+ H = alibi_bias.size(1)
+
+ orig_bias = alibi_bias
+
+ index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
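+ # gather first the rows (queries) and then the columns (keys) that correspond
+ # to the kept tokens, so the bias stays aligned with the unmasked sequence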
+ alibi_bias = torch.gather(
+ orig_bias,
+ dim=-2,
+ index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
+ )
+ alibi_bias = torch.gather(
+ alibi_bias,
+ dim=-1,
+ index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
+ )
+
+ return alibi_bias
diff --git a/examples/data2vec/models/modalities/images.py b/examples/data2vec/models/modalities/images.py
new file mode 100644
index 0000000000..a6b738cb07
--- /dev/null
+++ b/examples/data2vec/models/modalities/images.py
@@ -0,0 +1,256 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from functools import partial
+from dataclasses import dataclass
+from typing import Callable, Dict, Optional
+from timm.models.layers import to_2tuple
+from fairseq.tasks import FairseqTask
+from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
+from .base import (
+ D2vModalityConfig,
+ ModalitySpecificEncoder,
+ get_alibi_bias,
+ MaskSeed,
+)
+from .modules import (
+ BlockEncoder,
+ Decoder2d,
+ FixedPositionalEncoder,
+ TransformerDecoder,
+ EncDecTransformerDecoder,
+)
+from examples.data2vec.data.modality import Modality
+
+
+@dataclass
+class D2vImageConfig(D2vModalityConfig):
+ type: Modality = Modality.IMAGE
+
+ input_size: int = 224
+ in_chans: int = 3
+ patch_size: int = 16
+ embed_dim: int = 768
+
+ alibi_dims: int = 2
+ alibi_distance: str = "manhattan"
+
+ fixed_positions: bool = True
+
+ transformer_decoder: bool = False
+ enc_dec_transformer: bool = False
+
+
+class ImageEncoder(ModalitySpecificEncoder):
+
+ modality_cfg: D2vImageConfig
+
+ def __init__(
+ self,
+ modality_cfg: D2vImageConfig,
+ embed_dim: int,
+ make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases: Dict,
+ task: Optional[FairseqTask],
+ ):
+
+ img_size = to_2tuple(modality_cfg.input_size)
+ patch_size = to_2tuple(modality_cfg.patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+
+ local_encoder = PatchEmbed(
+ modality_cfg.input_size,
+ modality_cfg.patch_size,
+ modality_cfg.in_chans,
+ modality_cfg.embed_dim,
+ )
+
+ w = local_encoder.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ if modality_cfg.embed_dim != embed_dim:
+ local_encoder = nn.Sequential(
+ local_encoder,
+ nn.Linear(modality_cfg.embed_dim, embed_dim),
+ )
+
+ project_features = nn.Identity()
+
+ pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, embed_dim), requires_grad=False
+ )
+
+ side_n = int(num_patches ** 0.5)
+
+ emb = get_2d_sincos_pos_embed(
+ pos_embed.shape[-1],
+ side_n,
+ cls_token=False,
+ )
+ pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
+ fixed_positional_encoder = (
+ FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
+ )
+
+ dpr = np.linspace(
+ modality_cfg.start_drop_path_rate,
+ modality_cfg.end_drop_path_rate,
+ modality_cfg.prenet_depth,
+ )
+
+ context_encoder = BlockEncoder(
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
+ norm_layer(embed_dim) if not layer_norm_first else None,
+ layer_norm_first,
+ modality_cfg.prenet_layerdrop,
+ modality_cfg.prenet_dropout,
+ )
+
+ if modality_cfg.transformer_decoder:
+ if modality_cfg.enc_dec_transformer:
+ decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
+ else:
+ dec_enc = BlockEncoder(
+ nn.ModuleList(
+ make_block(0, modality_cfg.decoder.decoder_dim, 8)
+ for _ in range(modality_cfg.decoder.decoder_layers)
+ ),
+ None,
+ layer_norm_first,
+ 0,
+ 0,
+ )
+ decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
+ else:
+ decoder = (
+ Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
+ if modality_cfg.decoder is not None
+ else None
+ )
+
+ alibi_bias_fn = partial(
+ get_alibi_bias,
+ alibi_biases=alibi_biases,
+ heads=modality_cfg.num_alibi_heads,
+ dims=modality_cfg.alibi_dims,
+ distance=modality_cfg.alibi_distance,
+ )
+
+ super().__init__(
+ modality_cfg=modality_cfg,
+ embed_dim=embed_dim,
+ local_encoder=local_encoder,
+ project_features=project_features,
+ fixed_positional_encoder=fixed_positional_encoder,
+ relative_positional_encoder=None,
+ context_encoder=context_encoder,
+ decoder=decoder,
+ get_alibi_bias=alibi_bias_fn,
+ )
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ if self.decoder is not None:
+ self.decoder.reset_parameters()
+
+ @torch.no_grad()
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
+ p = self.modality_cfg.patch_size
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum("nchpwq->nhwpqc", x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
+
+ return x
+
+ @torch.no_grad()
+ def unpatchify(self, x):
+ """
+ x: (N, L, patch_size**2 *3)
+ imgs: (N, 3, H, W)
+ """
+ p = self.modality_cfg.patch_size
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+ return imgs
+
+ def compute_mask(
+ self,
+ x,
+ padding_mask,
+ mask_seed: Optional[MaskSeed],
+ apply,
+ shape=None,
+ precomputed_mask=None,
+ ):
+ mlen = self.modality_cfg.mask_length
+ if mlen <= 1:
+ return super().compute_mask(
+ x, padding_mask, mask_seed, apply, precomputed_mask
+ )
+
+ if precomputed_mask is not None:
+ mask = precomputed_mask
+ else:
+ from fairseq.data.data_utils import compute_block_mask_2d
+
+ if shape is not None:
+ B, L, D = shape
+ else:
+ B, L, D = x.shape
+
+ mask = compute_block_mask_2d(
+ shape=(B, L),
+ mask_prob=self.modality_cfg.mask_prob,
+ mask_length=self.modality_cfg.mask_length,
+ mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
+ inverse_mask=self.modality_cfg.inverse_mask,
+ require_same_masks=True,
+ mask_dropout=self.modality_cfg.mask_dropout,
+ )
+
+ mask_info = self.make_maskinfo(x, mask, shape)
+ if apply:
+ x = self.apply_mask(x, mask_info)
+
+ return x, mask_info
+
+ def decoder_input(self, x, mask_info):
+ if (
+ not self.modality_cfg.transformer_decoder
+ or not self.modality_cfg.enc_dec_transformer
+ ):
+ return super().decoder_input(x, mask_info)
+
+ inp_drop = self.modality_cfg.decoder.input_dropout
+ if inp_drop > 0:
+ x = F.dropout(x, inp_drop, training=self.training, inplace=True)
+
+ kv = x[:, self.modality_cfg.num_extra_tokens :]
+
+ assert self.fixed_positional_encoder is not None
+ pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)
+
+ mask = mask_info.mask.bool()
+ if self.modality_cfg.decoder.add_positions_all:
+ kv = kv + pos[~mask].view(kv.shape)
+
+ q = pos[mask].view(x.size(0), -1, x.size(-1))
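+ # queries are the positional embeddings of the masked patches; the encoder
+ # outputs (optionally with positions added) serve as keys/values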
+
+ return q, kv
diff --git a/examples/data2vec/models/modalities/modules.py b/examples/data2vec/models/modalities/modules.py
new file mode 100644
index 0000000000..a4e1a4ea07
--- /dev/null
+++ b/examples/data2vec/models/modalities/modules.py
@@ -0,0 +1,589 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from dataclasses import dataclass
+from fairseq.modules import (
+ LayerNorm,
+ SamePad,
+ SamePad2d,
+ TransposeLast,
+)
+
+
+@dataclass
+class D2vDecoderConfig:
+ decoder_dim: int = 384
+ decoder_groups: int = 16
+ decoder_kernel: int = 5
+ decoder_layers: int = 5
+ input_dropout: float = 0.1
+
+ add_positions_masked: bool = False
+ add_positions_all: bool = False
+
+ decoder_residual: bool = True
+ projection_layers: int = 1
+ projection_ratio: float = 2.0
+
+
+class FixedPositionalEncoder(nn.Module):
+ def __init__(self, pos_embed):
+ super().__init__()
+ self.positions = pos_embed
+
+ def forward(self, x, padding_mask):
+ return self.positions
+
+
+class TextFeatPositionalEncoder(nn.Module):
+ """
+ The original encoder expects (B, T) long inputs. This module wraps it to take
+ local_encoder outputs, which are (B, T, D) float tensors.
+ """
+
+ def __init__(self, pos_encoder):
+ super().__init__()
+ self.pos_encoder = pos_encoder
+
+ def forward(self, x, padding_mask):
+ # assume padded token embeddings are 0s
+ # TODO: consider using padding_mask as input
+ return self.pos_encoder(x[..., 0])
+
+
+class BlockEncoder(nn.Module):
+ def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
+ super().__init__()
+ self.blocks = blocks
+ self.norm = norm_layer
+ self.layer_norm_first = layer_norm_first
+ self.layerdrop = layerdrop
+ self.dropout = nn.Dropout(dropout, inplace=True)
+
+ def forward(self, x, padding_mask, alibi_bias, alibi_scale):
+ if self.norm is not None and not self.layer_norm_first:
+ x = self.norm(x)
+
+ x = self.dropout(x)
+
+ for i, blk in enumerate(self.blocks):
+ if (
+ not self.training
+ or self.layerdrop == 0
+ or (np.random.random() > self.layerdrop)
+ ):
+ ab = alibi_bias
+ if ab is not None and alibi_scale is not None:
+ scale = (
+ alibi_scale[i]
+ if alibi_scale.size(0) > 1
+ else alibi_scale.squeeze(0)
+ )
+ ab = ab * scale.type_as(ab)
+ x, _ = blk(x, padding_mask, ab)
+
+ if self.norm is not None and self.layer_norm_first:
+ x = self.norm(x)
+
+ return x
+
+
+class DecoderBase(nn.Module):
+ decoder_cfg: D2vDecoderConfig
+
+ def __init__(self, cfg: D2vDecoderConfig):
+ super().__init__()
+
+ self.decoder_cfg = cfg
+
+ def reset_parameters(self):
+ for mod in self.proj.modules():
+ if isinstance(mod, nn.Linear):
+ mod.reset_parameters()
+
+ def add_residual(self, x, residual, i, mask_info):
+ if (
+ residual is None
+ or not self.decoder_cfg.decoder_residual
+ or residual.size(1) != x.size(1)
+ ):
+ return x
+
+ ret = x + residual
+
+ return ret
+
+
+class Decoder1d(DecoderBase):
+ def __init__(self, cfg: D2vDecoderConfig, input_dim):
+ super().__init__(cfg)
+
+ def make_block(in_dim):
+ block = [
+ nn.Conv1d(
+ in_dim,
+ cfg.decoder_dim,
+ kernel_size=cfg.decoder_kernel,
+ padding=cfg.decoder_kernel // 2,
+ groups=cfg.decoder_groups,
+ ),
+ SamePad(cfg.decoder_kernel),
+ TransposeLast(),
+ LayerNorm(cfg.decoder_dim, elementwise_affine=False),
+ TransposeLast(),
+ nn.GELU(),
+ ]
+
+ return nn.Sequential(*block)
+
+ self.blocks = nn.Sequential(
+ *[
+ make_block(input_dim if i == 0 else cfg.decoder_dim)
+ for i in range(cfg.decoder_layers)
+ ]
+ )
+
+ projs = []
+ curr_dim = cfg.decoder_dim
+ for i in range(cfg.projection_layers - 1):
+ next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
+ projs.append(nn.Linear(curr_dim, next_dim))
+ projs.append(nn.GELU())
+ curr_dim = next_dim
+ projs.append(nn.Linear(curr_dim, input_dim))
+ if len(projs) == 1:
+ self.proj = projs[0]
+ else:
+ self.proj = nn.Sequential(*projs)
+
+ def forward(self, x, mask_info):
+
+ x = x.transpose(1, 2)
+
+ residual = x
+
+ for i, layer in enumerate(self.blocks):
+ x = layer(x)
+ x = self.add_residual(x, residual, i, mask_info)
+ residual = x
+
+ x = x.transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class Decoder2d(DecoderBase):
+ def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size):
+ super().__init__(cfg)
+
+ self.h_size = h_size
+ self.w_size = w_size
+
+ def make_block(in_dim):
+ block = [
+ nn.Conv2d(
+ in_dim,
+ cfg.decoder_dim,
+ kernel_size=cfg.decoder_kernel,
+ padding=cfg.decoder_kernel // 2,
+ groups=cfg.decoder_groups,
+ ),
+ SamePad2d(cfg.decoder_kernel),
+ TransposeLast(tranpose_dim=-3),
+ LayerNorm(cfg.decoder_dim, elementwise_affine=False),
+ TransposeLast(tranpose_dim=-3),
+ nn.GELU(),
+ ]
+
+ return nn.Sequential(*block)
+
+ self.blocks = nn.Sequential(
+ *[
+ make_block(input_dim if i == 0 else cfg.decoder_dim)
+ for i in range(cfg.decoder_layers)
+ ]
+ )
+
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
+
+ def forward(self, x, mask_info):
+ B, T, C = x.shape
+
+ x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size)
+
+ residual = x
+
+ for i, layer in enumerate(self.blocks):
+ x = layer(x)
+ x = self.add_residual(x, residual, i, mask_info)
+ residual = x
+
+ x = x.reshape(B, -1, T).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class TransformerDecoder(nn.Module):
+ decoder_cfg: D2vDecoderConfig
+
+ def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder):
+ super().__init__()
+
+ self.decoder_cfg = cfg
+
+ self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
+
+ self.encoder = encoder
+
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
+
+ def reset_parameters(self):
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, mask_info):
+ x = self.input_proj(x)
+ x = self.encoder(x, None, None, 1)
+ x = self.proj(x)
+ return x
+
+
+class AltBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ mlp_drop=0.0,
+ post_mlp_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=True,
+ ffn_targets=False,
+ cosine_attention=False,
+ ):
+ super().__init__()
+
+ self.layer_norm_first = layer_norm_first
+ self.ffn_targets = ffn_targets
+
+ from timm.models.vision_transformer import DropPath, Mlp
+
+ self.norm1 = norm_layer(dim)
+ self.attn = AltAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ cosine_attention=cosine_attention,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=mlp_drop,
+ )
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+
+ def forward(self, x, padding_mask=None, alibi_bias=None):
+ if self.layer_norm_first:
+ x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
+ r = x = self.mlp(self.norm2(x))
+ t = x
+ x = r + self.drop_path(self.post_mlp_dropout(x))
+ if not self.ffn_targets:
+ t = x
+ else:
+ x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
+ r = x = self.norm1(x)
+ x = self.mlp(x)
+ t = x
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+ if not self.ffn_targets:
+ t = x
+
+ return x, t
+
+
+class AltAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ cosine_attention=False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.cosine_attention = cosine_attention
+
+ if cosine_attention:
+ self.logit_scale = nn.Parameter(
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+ )
+
+ def forward(self, x, padding_mask=None, alibi_bias=None):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ dtype = q.dtype
+
+ if self.cosine_attention:
+ # cosine attention
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+ logit_scale = torch.clamp(
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+ ).exp()
+ attn = attn * logit_scale
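+ # cosine attention: dot product of l2-normalized q and k scaled by a learned
+ # per-head temperature clamped to at most 1/0.01, similar to Swin V2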
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if alibi_bias is not None:
+ attn = attn.type_as(alibi_bias)
+ attn[:, : alibi_bias.size(1)] += alibi_bias
+
+ if padding_mask is not None and padding_mask.any():
+ attn = attn.masked_fill(
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = x.reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EncDecAttention(nn.Module):
+ def __init__(
+ self,
+ q_dim,
+ kv_dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ cosine_attention=False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = q_dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias)
+ self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(q_dim, q_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.cosine_attention = cosine_attention
+
+ if cosine_attention:
+ self.logit_scale = nn.Parameter(
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+ )
+
+ def forward(self, q, kv, padding_mask=None, alibi_bias=None):
+ B, N, C = q.shape
+
+ q = (
+ self.q_proj(q)
+ .reshape(B, N, self.num_heads, C // self.num_heads)
+ .permute(0, 2, 1, 3)
+ ) # B x H x L x D
+ kv = (
+ self.kv_proj(kv)
+ .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ ) # kv x B x H x L x D
+ k, v = (
+ kv[0],
+ kv[1],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ dtype = q.dtype
+
+ if self.cosine_attention:
+ # cosine attention
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+ logit_scale = torch.clamp(
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+ ).exp()
+ attn = attn * logit_scale
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if alibi_bias is not None:
+ attn = attn.type_as(alibi_bias)
+ attn[:, : alibi_bias.size(1)] += alibi_bias
+
+ if padding_mask is not None and padding_mask.any():
+ attn = attn.masked_fill(
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = x.reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EncDecBlock(nn.Module):
+ def __init__(
+ self,
+ q_dim,
+ kv_dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ mlp_drop=0.0,
+ post_mlp_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=True,
+ cosine_attention=False,
+ first_residual=True,
+ ):
+ super().__init__()
+
+ self.layer_norm_first = layer_norm_first
+
+ from timm.models.vision_transformer import DropPath, Mlp
+
+ self.norm1 = norm_layer(q_dim)
+ self.attn = EncDecAttention(
+ q_dim,
+ kv_dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ cosine_attention=cosine_attention,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(q_dim)
+ mlp_hidden_dim = int(q_dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=q_dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=mlp_drop,
+ )
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+ self.first_residual = first_residual
+
+ def forward(self, q, kv, padding_mask=None, alibi_bias=None):
+ r = q if self.first_residual else 0
+ if self.layer_norm_first:
+ x = r + self.drop_path(
+ self.attn(self.norm1(q), kv, padding_mask, alibi_bias)
+ )
+ r = x = self.mlp(self.norm2(x))
+ x = r + self.drop_path(self.post_mlp_dropout(x))
+ else:
+ x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias))
+ r = x = self.norm1(x)
+ x = self.mlp(x)
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+
+ return x
+
+
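+# Cross-attention decoder: projects the input to cfg.decoder_dim, applies
+# cfg.decoder_layers post-norm EncDecBlocks that attend over the encoder output
+# kv (the first block without the initial residual), then projects back to
+# input_dim.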
+class EncDecTransformerDecoder(nn.Module):
+ def __init__(self, cfg: D2vDecoderConfig, input_dim):
+ super().__init__()
+
+ self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
+
+ self.blocks = nn.Sequential(
+ *[
+ EncDecBlock(
+ q_dim=cfg.decoder_dim,
+ kv_dim=input_dim,
+ num_heads=8,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ mlp_drop=0.0,
+ post_mlp_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=False,
+ cosine_attention=False,
+ first_residual=i > 0,
+ )
+ for i in range(cfg.decoder_layers)
+ ]
+ )
+
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
+
+ def reset_parameters(self):
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, kv):
+ x = self.input_proj(x)
+ for i, layer in enumerate(self.blocks):
+ x = layer(x, kv)
+
+ x = self.proj(x)
+ return x
diff --git a/examples/data2vec/models/modalities/text.py b/examples/data2vec/models/modalities/text.py
new file mode 100644
index 0000000000..adfac1ca48
--- /dev/null
+++ b/examples/data2vec/models/modalities/text.py
@@ -0,0 +1,161 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
+from fairseq.tasks import FairseqTask
+from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
+from .modules import BlockEncoder, Decoder1d
+from examples.data2vec.data.modality import Modality
+
+
+@dataclass
+class D2vTextConfig(D2vModalityConfig):
+ type: Modality = Modality.TEXT
+ max_source_positions: int = 512
+ learned_pos: bool = True
+    dropout: float = 0.1  # used for both the local encoder and the contextualized encoder; tied to the global transformer in data2vec_text
+
+ no_scale_embedding: bool = True
+ layernorm_embedding: bool = True
+ no_token_positional_embeddings: bool = False
+
+
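+# Text modality encoder for multi-modal data2vec: uses TextLocalEncoder as the
+# local feature extractor, a BlockEncoder of prenet_depth transformer blocks as
+# the contextualizer, an optional Decoder1d, and ALiBi biases shared through the
+# alibi_biases dict.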
+class TextEncoder(ModalitySpecificEncoder):
+
+ modality_cfg: D2vTextConfig
+
+ def __init__(
+ self,
+ modality_cfg: D2vTextConfig,
+ embed_dim: int,
+ make_block: Callable[[float], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases: Dict,
+ task: Optional[FairseqTask],
+ ):
+ self.pad_idx = task.source_dictionary.pad()
+ self.vocab_size = len(task.source_dictionary)
+
+ local_encoder = TextLocalEncoder(
+ vocab_size=self.vocab_size,
+ embed_dim=embed_dim,
+ max_source_positions=modality_cfg.max_source_positions,
+ pad_idx=self.pad_idx,
+ no_scale_embedding=modality_cfg.no_scale_embedding,
+ layernorm_embedding=modality_cfg.layernorm_embedding,
+ dropout=modality_cfg.dropout,
+ no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
+ learned_pos=modality_cfg.learned_pos,
+ )
+ dpr = np.linspace(
+ modality_cfg.start_drop_path_rate,
+ modality_cfg.end_drop_path_rate,
+ modality_cfg.prenet_depth,
+ )
+ context_encoder = BlockEncoder(
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
+ norm_layer(embed_dim)
+ if not layer_norm_first and modality_cfg.prenet_depth > 0
+ else None,
+ layer_norm_first,
+ modality_cfg.prenet_layerdrop,
+ modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
+ )
+ decoder = (
+ Decoder1d(modality_cfg.decoder, embed_dim)
+ if modality_cfg.decoder is not None
+ else None
+ )
+
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
+
+ super().__init__(
+ modality_cfg=modality_cfg,
+ embed_dim=embed_dim,
+ local_encoder=local_encoder,
+ project_features=nn.Identity(),
+ fixed_positional_encoder=None,
+ relative_positional_encoder=None,
+ context_encoder=context_encoder,
+ decoder=decoder,
+ get_alibi_bias=alibi_bias_fn,
+ )
+
+ def reset_parameters(self):
+ super().reset_parameters()
+
+ def convert_padding_mask(self, x, padding_mask):
+ if padding_mask is None or padding_mask.size(1) == x.size(1):
+ return padding_mask
+
+ diff = self.downsample - padding_mask.size(1) % self.downsample
+ if 0 < diff < self.downsample:
+ padding_mask = F.pad(padding_mask, (0, diff), value=True)
+
+ padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
+ padding_mask = padding_mask.all(-1)
+ if padding_mask.size(1) > x.size(1):
+ padding_mask = padding_mask[:, : x.size(1)]
+
+ assert x.size(1) == padding_mask.size(
+ 1
+ ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
+
+ return padding_mask
+
+
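+# Embeds source tokens (optionally scaled by sqrt(embed_dim)), adds learned or
+# sinusoidal positional embeddings unless disabled, then applies optional
+# LayerNorm and dropout.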
+class TextLocalEncoder(nn.Module):
+ def __init__(
+ self,
+ vocab_size,
+ embed_dim,
+ max_source_positions,
+ pad_idx,
+ no_scale_embedding,
+ layernorm_embedding,
+ dropout,
+ no_token_positional_embeddings,
+ learned_pos,
+ ):
+ super().__init__()
+ self.pad_idx = pad_idx
+ self.dropout_module = FairseqDropout(dropout)
+
+ self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
+ self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
+ self.embed_positions = (
+ PositionalEmbedding(
+ max_source_positions,
+ embed_dim,
+ pad_idx,
+ learned=learned_pos,
+ )
+ if not no_token_positional_embeddings
+ else None
+ )
+
+ self.layernorm_embedding = None
+ if layernorm_embedding:
+ self.layernorm_embedding = LayerNorm(embed_dim)
+
+ def forward(self, src_tokens):
+ x = self.embed_scale * self.embed_tokens(src_tokens)
+ if self.embed_positions is not None:
+ x = x + self.embed_positions(src_tokens)
+
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+ x = self.dropout_module(x)
+ return x
diff --git a/examples/data2vec/models/utils.py b/examples/data2vec/models/utils.py
new file mode 100644
index 0000000000..0e2f240d4f
--- /dev/null
+++ b/examples/data2vec/models/utils.py
@@ -0,0 +1,55 @@
+import math
+import torch
+
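+# Builds a symmetric (non-causal) ALiBi bias of shape
+# (attention_heads, max_positions, max_positions): per-head slopes from the
+# ALiBi paper multiplied by the negated distance between positions, so
+# get_alibi(T, H)[h, i, j] == -slope_h * |i - j| and the diagonal is 0.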
+def get_alibi(
+ max_positions: int,
+ attention_heads: int,
+):
+ def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio ** i for i in range(n)]
+
+ # In the paper, we only train models that have 2^a heads for some
+ # a. This function has some good properties that only occur when
+ # the input is a power of 2. To maintain that even when the number
+ # of heads is not a power of 2, we use this workaround.
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ return (
+ get_slopes_power_of_2(closest_power_of_2)
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+ )
+
+ maxpos = max_positions
+ attn_heads = attention_heads
+ slopes = torch.Tensor(get_slopes(attn_heads))
+    # prepare the ALiBi positional linear bias. Note that wav2vec2 is a
+    # non-autoregressive model, so we want a symmetric bias with 0 on the
+    # diagonal and otherwise linearly decreasing values
+ pos_bias = (
+ torch.abs(
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
+ )
+ * -1
+ )
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
+ attn_heads, -1, -1
+ )
+ return alibi_bias
+
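+# Restricts a full (B*H, T, T) ALiBi bias to the timesteps selected by
+# mask_indices, returning a (B*H, M, M) bias where M is the number of selected
+# positions per sample.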
+def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
+ alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
+ H = alibi_bias.size(1)
+ alibi_mask = mask_indices.unsqueeze(1)
+ alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
+ alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T)
+ M = alibi_bias.size(-2)
+ alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2))
+ alibi_bias = alibi_bias.view(-1, M, M)
+ return alibi_bias
+
+
diff --git a/examples/data2vec/scripts/convert_audioset_labels.py b/examples/data2vec/scripts/convert_audioset_labels.py
new file mode 100644
index 0000000000..7d720e606a
--- /dev/null
+++ b/examples/data2vec/scripts/convert_audioset_labels.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+
+
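+# Converts an AudioSet segments csv into one tab-separated "start end labels"
+# line per entry of a wav2vec-style manifest, mapping label strings to indices
+# via the descriptor file and preserving the manifest order.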
+def get_parser():
+ parser = argparse.ArgumentParser(description="convert audioset labels")
+ # fmt: off
+ parser.add_argument('in_file', help='audioset csv file to convert')
+ parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest')
+ parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file')
+ parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels')
+ # fmt: on
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ label_descriptors = {}
+ with open(args.descriptors, "r") as ldf:
+ next(ldf)
+ for line in ldf:
+ if line.strip() == "":
+ continue
+
+ items = line.split(",")
+ assert len(items) > 2, line
+ idx = items[0]
+ lbl = items[1]
+ assert lbl not in label_descriptors, lbl
+ label_descriptors[lbl] = idx
+
+ labels = {}
+ with open(args.in_file, "r") as ifd:
+ for line in ifd:
+ if line.lstrip().startswith("#"):
+ continue
+ items = line.rstrip().split(",")
+ id = items[0].strip()
+ start = items[1].strip()
+ end = items[2].strip()
+ lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
+ labels[id] = [start, end, ",".join(lbls)]
+
+ with open(args.manifest, "r") as mf, open(args.output, "w") as of:
+ next(mf)
+ for line in mf:
+ path, _ = line.split("\t")
+ id = os.path.splitext(os.path.basename(path))[0]
+ lbl = labels[id]
+ print("\t".join(lbl), file=of)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
new file mode 100755
index 0000000000..41bcd31fc5
--- /dev/null
+++ b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
+
diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
new file mode 100644
index 0000000000..fc85908b72
--- /dev/null
+++ b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
+
diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh b/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
new file mode 100755
index 0000000000..121226972b
--- /dev/null
+++ b/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env zsh
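+# Local GLUE fine-tuning sweep for a multi-modal data2vec checkpoint: for every
+# GLUE task and each learning rate in $lrs, launches hydra_train with the
+# multi/text_finetuning configs and writes results under $dir/finetune_lr.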
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+tasks[mnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MNLI-bin"
+tasks[qqp]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QQP-bin"
+tasks[sts_b]="/fsx-wav2vec/abaevski/data/nlp/GLUE/STS-B-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $lrs; do
+ echo $lr $task
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/multi/text_finetuning \
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+ model.model_path="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model=text_wrap
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
new file mode 100755
index 0000000000..18b862c240
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/ft_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/ft_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_char_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_fair.sh b/examples/data2vec/scripts/text/finetune_all_fair.sh
new file mode 100755
index 0000000000..34a2df3990
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+job_id=$1
+task_id=$2
+dir="$3"
+cp="$dir/$task_id/checkpoints/checkpoint_last.pt"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+declare -A tasks
+tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin"
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin"
+tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
new file mode 100755
index 0000000000..b417c20024
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+job_id=$1
+task_id=$2
+dir="$3"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
new file mode 100755
index 0000000000..64dbcb111e
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
new file mode 100755
index 0000000000..d75c549573
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env zsh
+
+job_id=$1
+task_id=$2
+dir="$3"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" &
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
new file mode 100755
index 0000000000..8be98c0847
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $lrs; do
+ echo $lr $task
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]"
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh
new file mode 100755
index 0000000000..d02bcc0f75
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin"
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin"
+tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh
new file mode 100755
index 0000000000..75538354e1
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh
new file mode 100755
index 0000000000..16c1358b2f
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh
new file mode 100755
index 0000000000..fb5ddbe22c
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" &
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh
new file mode 100755
index 0000000000..1ffab1c850
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model.encoder_learned_pos=False &
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
new file mode 100755
index 0000000000..c3c58adcb8
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
new file mode 100644
index 0000000000..5efb00e0df
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $lrs; do
+ echo $lr $task
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" \
+ model._name=roberta_large
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
new file mode 100755
index 0000000000..4fb21bce79
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh b/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
new file mode 100755
index 0000000000..d7b43bee80
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+lrs="5e-6 1e-5 2e-5 5e-5 1e-4 2e-4 5e-4 1e-3"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $(echo "$lrs"); do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_sweep/$task/lr_$lr" "optimization.lr=[${lr}]" &
+ done
+done
diff --git a/examples/data2vec/scripts/text/glue.py b/examples/data2vec/scripts/text/glue.py
new file mode 100644
index 0000000000..5382d31834
--- /dev/null
+++ b/examples/data2vec/scripts/text/glue.py
@@ -0,0 +1,34 @@
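+# Collects the best valid_accuracy per GLUE task under a sweep directory (via
+# valids.py) and prints one tab-separated row per run, followed by the average
+# with and without RTE.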
+from valids import parser, main as valids_main
+import os.path as osp
+
+
+args = parser.parse_args()
+args.target = "valid_accuracy"
+args.best_biggest = True
+args.best = True
+args.last = 0
+args.path_contains = None
+
+res = valids_main(args, print_output=False)
+
+grouped = {}
+for k, v in res.items():
+ k = osp.dirname(k)
+ run = osp.dirname(k)
+ task = osp.basename(k)
+ val = v["valid_accuracy"]
+
+ if run not in grouped:
+ grouped[run] = {}
+
+ grouped[run][task] = val
+
+for run, tasks in grouped.items():
+ print(run)
+ avg = sum(float(v) for v in tasks.values()) / len(tasks)
+ avg_norte = sum(float(v) for k,v in tasks.items() if k != 'rte') / (len(tasks) -1)
+ try:
+ print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
+    except KeyError:
+ print(tasks)
+ print()
diff --git a/examples/data2vec/scripts/text/glue_lr.py b/examples/data2vec/scripts/text/glue_lr.py
new file mode 100644
index 0000000000..75bdfe0368
--- /dev/null
+++ b/examples/data2vec/scripts/text/glue_lr.py
@@ -0,0 +1,143 @@
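+# Summarizes GLUE fine-tuning sweeps: groups results by pretraining run, task,
+# and sweep subdirectory (e.g. learning rate), reporting either raw accuracy or
+# the official per-task GLUE metric in best/all/tabular form.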
+import os.path as osp
+import re
+from collections import defaultdict
+
+from valids import parser, main as valids_main
+
+
+TASK_TO_METRIC = {
+ "cola": "mcc",
+ "qnli": "accuracy",
+ "mrpc": "acc_and_f1",
+ "rte": "accuracy",
+ "sst_2": "accuracy",
+ "mnli": "accuracy",
+ "qqp": "acc_and_f1",
+ "sts_b": "pearson_and_spearman",
+}
+TASKS = ["cola", "qnli", "mrpc", "rte", "sst_2", "mnli", "qqp", "sts_b"]
+
+
+def get_best_stat_str(task_vals, show_subdir):
+ task_to_best_val = {}
+ task_to_best_dir = {}
+ for task, subdir_to_val in task_vals.items():
+ task_to_best_val[task] = max(subdir_to_val.values())
+ task_to_best_dir[task] = max(subdir_to_val.keys(), key=lambda x: subdir_to_val[x])
+
+ # import pdb; pdb.set_trace()
+ N1 = len(task_to_best_val)
+ N2 = len([k for k in task_to_best_val if k != "rte"])
+ avg1 = sum(task_to_best_val.values()) / N1
+ avg2 = sum(v for task, v in task_to_best_val.items() if task != "rte") / N2
+
+ try:
+ msg = ""
+ for task in TASKS:
+ dir = task_to_best_dir.get(task, 'null')
+ val = task_to_best_val.get(task, -100)
+ msg += f"({dir}, {val})\t" if show_subdir else f"{val}\t"
+ msg += f"{avg1:.2f}\t{avg2:.2f}"
+ except Exception as e:
+ msg = str(e)
+ msg += str(sorted(task_vals.items()))
+ return msg
+
+def get_all_stat_str(task_vals):
+ msg = ""
+ for task in [task for task in TASKS if task in task_vals]:
+ msg += f"=== {task}\n"
+ for subdir in sorted(task_vals[task].keys()):
+ msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
+ return msg
+
+def get_tabular_stat_str(task_vals):
+ """assume subdir is /run_*/0"""
+ msg = ""
+ for task in [task for task in TASKS if task in task_vals]:
+ msg += f"=== {task}\n"
+ param_to_runs = defaultdict(dict)
+ for subdir in task_vals[task]:
+ match = re.match("(.*)/(run_.*)/0", subdir)
+            assert match, subdir
+ param, run = match.groups()
+ param_to_runs[param][run] = task_vals[task][subdir]
+ params = sorted(param_to_runs, key=lambda x: float(x))
+ runs = sorted(set(run for runs in param_to_runs.values() for run in runs))
+ msg += ("runs:" + "\t".join(runs) + "\n")
+ msg += ("params:" + "\t".join(params) + "\n")
+ for param in params:
+ msg += "\t".join([str(param_to_runs[param].get(run, None)) for run in runs])
+ msg += "\n"
+ # for subdir in sorted(task_vals[task].keys()):
+ # msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
+ return msg
+
+
+
+def main():
+ parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy")
+ parser.add_argument("--print_mode", default="best", help="best|all|tabular")
+ parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run")
+ parser.add_argument("--override_target", default="valid_accuracy", help="override target")
+
+ args = parser.parse_args()
+ args.target = args.override_target
+ args.best_biggest = True
+ args.best = True
+ args.last = 0
+ args.path_contains = None
+
+ res = valids_main(args, print_output=False)
+ grouped_acc = {}
+ grouped_met = {} # use official metric for each task
+ for path, v in res.items():
+ path = "/".join([args.base, path])
+ path = re.sub("//*", "/", path)
+ match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
+ if not match:
+ continue
+ run, task, subdir = match.groups()
+
+ if run not in grouped_acc:
+ grouped_acc[run] = {}
+ grouped_met[run] = {}
+ if task not in grouped_acc[run]:
+ grouped_acc[run][task] = {}
+ grouped_met[run][task] = {}
+
+ if v is not None:
+ grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100))
+ grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100))
+ else:
+ print(f"{path} has None return")
+
+ header = "\t".join(TASKS)
+ for run in sorted(grouped_acc):
+ print(run)
+ if args.print_mode == "all":
+ if args.show_glue:
+ print("===== GLUE =====")
+ print(get_all_stat_str(grouped_met[run]))
+ else:
+ print("===== ACC =====")
+ print(get_all_stat_str(grouped_acc[run]))
+ elif args.print_mode == "best":
+ print(f" {header}")
+ if args.show_glue:
+                print(f"GLUE: {get_best_stat_str(grouped_met[run], args.show_subdir)}")
+ else:
+ print(f"ACC: {get_best_stat_str(grouped_acc[run], args.show_subdir)}")
+ elif args.print_mode == "tabular":
+ if args.show_glue:
+ print("===== GLUE =====")
+ print(get_tabular_stat_str(grouped_met[run]))
+ else:
+ print("===== ACC =====")
+ print(get_tabular_stat_str(grouped_acc[run]))
+ else:
+ raise ValueError(args.print_mode)
+ print()
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/data2vec/scripts/text/unprocess_data.py b/examples/data2vec/scripts/text/unprocess_data.py
new file mode 100644
index 0000000000..f1acb624b8
--- /dev/null
+++ b/examples/data2vec/scripts/text/unprocess_data.py
@@ -0,0 +1,188 @@
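+# Converts binarized fairseq datasets back into raw text: token indices are
+# mapped through the source dictionary and the byte-level BPE vocabulary in
+# encoder.json, with the "\u0120" word-boundary marker replaced by a space.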
+import json
+import os
+import tqdm
+from fairseq.data import Dictionary, data_utils
+
+
+def load_dictionary(dict_path):
+ return Dictionary.load(dict_path)
+
+def load_dataset(split_path, src_dict):
+ dataset = data_utils.load_indexed_dataset(
+ split_path,
+ src_dict,
+ combine=False, # set to true for loading `train*`
+ )
+ if dataset is None:
+ raise FileNotFoundError(f"Dataset not found: {split_path}")
+ return dataset
+
+def load_bpe(enc_path):
+ with open(enc_path) as f:
+ bpe2idx = json.load(f)
+ idx2bpe = {v: k for k, v in bpe2idx.items()}
+ return bpe2idx, idx2bpe
+
+def detokenize(tokens, src_dict, idx2bpe):
+ raw_inds = map(int, src_dict.string(tokens).split())
+ raw_chrs = "".join([idx2bpe[raw_ind] for raw_ind in raw_inds])
+ raw_chrs = raw_chrs.replace("\u0120", " ")
+ return raw_chrs
+
+def _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits):
+ src_dict = load_dictionary(src_dict_path)
+ bpe2idx, idx2bpe = load_bpe(src_bpe_path)
+
+ assert len(src_splits) == len(tgt_splits)
+ for src_split, tgt_split in zip(src_splits, tgt_splits):
+ src_dataset = load_dataset(f"{src_root}/{src_split}", src_dict)
+ tgt_path = f"{tgt_root}/{tgt_split}.txt"
+ print(f"processing {src_split} (dump to {tgt_path})...")
+ os.makedirs(os.path.dirname(tgt_path), exist_ok=True)
+ with open(tgt_path, "w") as f:
+ for tokens in tqdm.tqdm(src_dataset):
+ raw_str = detokenize(tokens, src_dict, idx2bpe)
+ f.write(raw_str + "\n")
+
+def main_pt():
+ src_root = "/datasets01/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin/121219/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin"
+ src_dict_path = f"{src_root}/dict.txt"
+ src_bpe_path = f"{src_root}/encoder.json"
+ src_splits = [
+ "bookwiki_aml-mmap2-bin/shard0/train",
+ "bookwiki_aml-mmap2-bin/shard1/train",
+ "bookwiki_aml-mmap2-bin/shard2/train",
+ "bookwiki_aml-mmap2-bin/shard3/train",
+ "bookwiki_aml-mmap2-bin/shard4/train",
+ "bookwiki_aml-mmap2-bin/valid/valid",
+ ]
+
+ tgt_root = "/checkpoint/wnhsu/data/data2vec2/data/text/bookwiki_aml-full-mmap2-txt"
+ tgt_splits = [
+ "train0",
+ "train1",
+ "train2",
+ "train3",
+ "train4",
+ "valid",
+ ]
+ _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits)
+
+def main_ft():
+ src_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE"
+ src_dict_path = f"{src_root}/dict.txt"
+ src_bpe_path = f"{src_root}/encoder.json"
+ src_splits = [
+ "CoLA-bin/input0/train",
+ "CoLA-bin/input0/valid",
+ "CoLA-bin/input0/test",
+
+ "MNLI-bin/input0/train",
+ "MNLI-bin/input0/valid",
+ "MNLI-bin/input0/test",
+ "MNLI-bin/input0/test1",
+ "MNLI-bin/input1/train",
+ "MNLI-bin/input1/valid",
+ "MNLI-bin/input1/test",
+ "MNLI-bin/input1/test1",
+
+ "MRPC-bin/input0/train",
+ "MRPC-bin/input0/valid",
+ "MRPC-bin/input0/test",
+ "MRPC-bin/input1/train",
+ "MRPC-bin/input1/valid",
+ "MRPC-bin/input1/test",
+
+ "QNLI-bin/input0/train",
+ "QNLI-bin/input0/valid",
+ "QNLI-bin/input0/test",
+ "QNLI-bin/input1/train",
+ "QNLI-bin/input1/valid",
+ "QNLI-bin/input1/test",
+
+ "QQP-bin/input0/train",
+ "QQP-bin/input0/valid",
+ "QQP-bin/input0/test",
+ "QQP-bin/input1/train",
+ "QQP-bin/input1/valid",
+ "QQP-bin/input1/test",
+
+ "RTE-bin/input0/train",
+ "RTE-bin/input0/valid",
+ "RTE-bin/input0/test",
+ "RTE-bin/input1/train",
+ "RTE-bin/input1/valid",
+ "RTE-bin/input1/test",
+
+ "SST-2-bin/input0/train",
+ "SST-2-bin/input0/valid",
+ "SST-2-bin/input0/test",
+
+ "STS-B-bin/input0/train",
+ "STS-B-bin/input0/valid",
+ "STS-B-bin/input0/test",
+ "STS-B-bin/input1/train",
+ "STS-B-bin/input1/valid",
+ "STS-B-bin/input1/test",
+ ]
+
+ tgt_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE_chr"
+ tgt_splits = [
+ "CoLA-bin/input0/train",
+ "CoLA-bin/input0/valid",
+ "CoLA-bin/input0/test",
+
+ "MNLI-bin/input0/train",
+ "MNLI-bin/input0/valid",
+ "MNLI-bin/input0/test",
+ "MNLI-bin/input0/test1",
+ "MNLI-bin/input1/train",
+ "MNLI-bin/input1/valid",
+ "MNLI-bin/input1/test",
+ "MNLI-bin/input1/test1",
+
+ "MRPC-bin/input0/train",
+ "MRPC-bin/input0/valid",
+ "MRPC-bin/input0/test",
+ "MRPC-bin/input1/train",
+ "MRPC-bin/input1/valid",
+ "MRPC-bin/input1/test",
+
+ "QNLI-bin/input0/train",
+ "QNLI-bin/input0/valid",
+ "QNLI-bin/input0/test",
+ "QNLI-bin/input1/train",
+ "QNLI-bin/input1/valid",
+ "QNLI-bin/input1/test",
+
+ "QQP-bin/input0/train",
+ "QQP-bin/input0/valid",
+ "QQP-bin/input0/test",
+ "QQP-bin/input1/train",
+ "QQP-bin/input1/valid",
+ "QQP-bin/input1/test",
+
+ "RTE-bin/input0/train",
+ "RTE-bin/input0/valid",
+ "RTE-bin/input0/test",
+ "RTE-bin/input1/train",
+ "RTE-bin/input1/valid",
+ "RTE-bin/input1/test",
+
+ "SST-2-bin/input0/train",
+ "SST-2-bin/input0/valid",
+ "SST-2-bin/input0/test",
+
+ "STS-B-bin/input0/train",
+ "STS-B-bin/input0/valid",
+ "STS-B-bin/input0/test",
+ "STS-B-bin/input1/train",
+ "STS-B-bin/input1/valid",
+ "STS-B-bin/input1/test",
+ ]
+ _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits)
+
+
+if __name__ == "__main__":
+ main_pt()
+ main_ft()
diff --git a/examples/data2vec/scripts/text/valids.py b/examples/data2vec/scripts/text/valids.py
new file mode 100644
index 0000000000..b2e5cfb25d
--- /dev/null
+++ b/examples/data2vec/scripts/text/valids.py
@@ -0,0 +1,301 @@
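+# Log-scraping utility: walks a base directory for train.log files, extracts a
+# target validation metric from the logged JSON lines, and prints or aggregates
+# the best (or last) values, optionally grouped by a hyperparameter parsed from
+# the run path.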
+import os, argparse, re, json, copy, math
+from collections import OrderedDict
+import numpy as np
+
+parser = argparse.ArgumentParser(description='Extract and aggregate validation metrics from fairseq training logs.')
+parser.add_argument('base', help='base log path')
+parser.add_argument('--file_name', default='train.log', help='the log file name')
+parser.add_argument('--target', default='valid_loss', help='target metric')
+parser.add_argument('--last', type=int, default=999999999, help='print last n matches')
+parser.add_argument('--last_files', type=int, default=None, help='print last x files')
+parser.add_argument('--everything', action='store_true', help='print everything instead of only last match')
+parser.add_argument('--path_contains', help='only consider matching file pattern')
+parser.add_argument('--group_on', help='if set, groups by this metric and shows table of differences')
+parser.add_argument('--epoch', help='epoch for comparison', type=int)
+parser.add_argument('--skip_empty', action='store_true', help='skip empty results')
+parser.add_argument('--skip_containing', help='skips entries containing this attribute')
+parser.add_argument('--unique_epochs', action='store_true', help='only consider the last line for each epoch')
+parser.add_argument('--best', action='store_true', help='print the last best result')
+parser.add_argument('--avg_params', help='average these params through entire log')
+parser.add_argument('--extract_prev', help='extracts this metric from previous line')
+
+parser.add_argument('--remove_metric', help='used with --group_on: remove this param from the row key and append its value to the grouped column value')
+
+parser.add_argument('--compact', action='store_true', help='if true, just prints checkpoint best val')
+parser.add_argument('--hydra', action='store_true', help='if true, uses hydra param conventions')
+
+parser.add_argument('--best_biggest', action='store_true', help='if true, best is the biggest number, not smallest')
+parser.add_argument('--key_len', type=int, default=10, help='max length of key')
+
+parser.add_argument('--best_only', action='store_true', help='if set, only prints the best value')
+parser.add_argument('--flat', action='store_true', help='just print the best results')
+
+
+def main(args, print_output):
+ ret = {}
+
+ entries = []
+
+ def extract_metric(s, metric):
+ try:
+ j = json.loads(s)
+ except:
+ return None
+ if args.epoch is not None and ('epoch' not in j or j['epoch'] != args.epoch):
+ return None
+ return j[metric] if metric in j else None
+
+
+ def extract_params(s):
+ s = s.replace(args.base, '', 1)
+ if args.path_contains is not None:
+ s = s.replace(args.path_contains, '', 1)
+
+ if args.hydra:
+ num_matches = re.findall(r'(?:/|__)([^/:]+):(\d+\.?\d*)', s)
+ # str_matches = re.findall(r'(?:/|__)([^/:]+):([^\.]*[^\d\.]+)(?:/|__)', s)
+ str_matches = re.findall(r'(?:/|__)?((?:(?!(?:\:|__)).)+):([^\.]*[^\d\.]+\d*)(?:/|__)', s)
+ lr_matches = re.findall(r'optimization.(lr):\[([\d\.,]+)\]', s)
+ task_matches = re.findall(r'.*/(\d+)$', s)
+ else:
+ num_matches = re.findall(r'\.?([^\.]+?)(\d+(e\-\d+)?(?:\.\d+)?)(\.|$)', s)
+ str_matches = re.findall(r'[/\.]([^\.]*[^\d\.]+\d*)(?=\.)', s)
+ lr_matches = []
+ task_matches = []
+
+ cp_matches = re.findall(r'checkpoint(?:_\d+)?_(\d+).pt', s)
+
+ items = OrderedDict()
+ for m in str_matches:
+ if isinstance(m, tuple):
+ if 'checkpoint' not in m[0]:
+ items[m[0]] = m[1]
+ else:
+ items[m] = ''
+
+ for m in num_matches:
+ items[m[0]] = m[1]
+
+ for m in lr_matches:
+ items[m[0]] = m[1]
+
+ for m in task_matches:
+ items["hydra_task"] = m
+
+ for m in cp_matches:
+ items['checkpoint'] = m
+
+ return items
+
+ abs_best = None
+
+ sources = []
+ for root, _, files in os.walk(args.base):
+ if args.path_contains is not None and not args.path_contains in root:
+ continue
+ for f in files:
+ if f.endswith(args.file_name):
+ sources.append((root, f))
+
+ if args.last_files is not None:
+ sources = sources[-args.last_files:]
+
+ for root, file in sources:
+ with open(os.path.join(root, file), 'r') as fin:
+ found = []
+ avg = {}
+ prev = None
+ for line in fin:
+ line = line.rstrip()
+ if line.find(args.target) != -1 and (
+ args.skip_containing is None or line.find(args.skip_containing) == -1):
+ try:
+ idx = line.index("{")
+ line = line[idx:]
+ line_json = json.loads(line)
+ except:
+ continue
+ if prev is not None:
+ try:
+ prev.update(line_json)
+ line_json = prev
+ except:
+ pass
+ if args.target in line_json:
+ found.append(line_json)
+ if args.avg_params:
+ avg_params = args.avg_params.split(',')
+ for p in avg_params:
+ m = extract_metric(line, p)
+ if m is not None:
+ prev_v, prev_c = avg.get(p, (0, 0))
+ avg[p] = prev_v + float(m), prev_c + 1
+ if args.extract_prev:
+ try:
+ prev = json.loads(line)
+ except:
+ pass
+ best = None
+ if args.best:
+ curr_best = None
+ for i in range(len(found)):
+ cand_best = found[i][args.target] if args.target in found[i] else None
+
+ def cmp(a, b):
+ a = float(a)
+ b = float(b)
+ if args.best_biggest:
+ return a > b
+ return a < b
+
+ if cand_best is not None and not math.isnan(float(cand_best)) and (
+ curr_best is None or cmp(cand_best, curr_best)):
+ curr_best = cand_best
+ if abs_best is None or cmp(curr_best, abs_best):
+ abs_best = curr_best
+ best = found[i]
+ if args.unique_epochs or args.epoch:
+ last_found = []
+ last_epoch = None
+ for i in reversed(range(len(found))):
+ epoch = found[i]['epoch']
+ if args.epoch and args.epoch != epoch:
+ continue
+ if epoch != last_epoch:
+ last_epoch = epoch
+ last_found.append(found[i])
+ found = list(reversed(last_found))
+
+ if len(found) == 0:
+ if print_output and (args.last_files is not None or not args.skip_empty):
+ # print(root.split('/')[-1])
+ print(root[len(args.base):])
+ print('Nothing')
+ else:
+ if not print_output:
+ ret[root[len(args.base):]] = best
+ continue
+
+ if args.compact:
+ # print('{}\t{}'.format(root.split('/')[-1], curr_best))
+ print('{}\t{}'.format(root[len(args.base)+1:], curr_best))
+ continue
+
+ if args.group_on is None and not args.best_only:
+ # print(root.split('/')[-1])
+ print(root[len(args.base):])
+ if not args.everything:
+ if best is not None and args.group_on is None and not args.best_only and not args.flat:
+ print(best, '(best)')
+ if args.group_on is None and args.last and not args.best_only and not args.flat:
+ for f in found[-args.last:]:
+ if args.extract_prev is not None:
+ try:
+ print('{}\t{}'.format(f[args.extract_prev], f[args.target]))
+ except Exception as e:
+ print('Exception!', e)
+ else:
+ print(f)
+ try:
+ metric = found[-1][args.target] if not args.best or best is None else best[args.target]
+ except:
+ print(found[-1])
+ raise
+ if metric is not None:
+ entries.append((extract_params(root), metric))
+ else:
+ for f in found:
+ print(f)
+ if not args.group_on and print_output:
+ print()
+
+ if len(avg) > 0:
+ for k, (v, c) in avg.items():
+ print(f'{k}: {v/c}')
+
+ if args.best_only:
+ print(abs_best)
+
+ if args.flat:
+ print("\t".join(m for _, m in entries))
+
+ if args.group_on is not None:
+ by_val = OrderedDict()
+ for e, m in entries:
+ k = args.group_on
+ if k not in e:
+ m_keys = [x for x in e.keys() if x.startswith(k)]
+ if len(m_keys) == 0:
+ val = "False"
+ else:
+ assert len(m_keys) == 1
+ k = m_keys[0]
+ val = m_keys[0]
+ else:
+ val = e[args.group_on]
+ if val == "":
+ val = "True"
+ scrubbed_entry = copy.deepcopy(e)
+ if k in scrubbed_entry:
+ del scrubbed_entry[k]
+ if args.remove_metric and args.remove_metric in scrubbed_entry:
+ val += '_' + scrubbed_entry[args.remove_metric]
+ del scrubbed_entry[args.remove_metric]
+ by_val.setdefault(tuple(scrubbed_entry.items()), dict())[val] = m
+ distinct_vals = set()
+ for v in by_val.values():
+ distinct_vals.update(v.keys())
+ try:
+ distinct_vals = {int(d) for d in distinct_vals}
+ except:
+ print(distinct_vals)
+ print()
+ print("by_val", len(by_val))
+ for k,v in by_val.items():
+ print(k, '=>', v)
+ print()
+
+ raise
+ from natsort import natsorted
+ svals = list(map(str, natsorted(distinct_vals)))
+ print('{}\t{}'.format(args.group_on, '\t'.join(svals)))
+ sums = OrderedDict({n:[] for n in svals})
+ for k, v in by_val.items():
+ kstr = '.'.join(':'.join(x) for x in k)
+ vstr = ''
+ for mv in svals:
+ x = v[mv] if mv in v else ''
+ vstr += '\t{}'.format(round(x, 5) if isinstance(x, float) else x)
+ try:
+ sums[mv].append(float(x))
+ except:
+ pass
+ print('{}{}'.format(kstr[:args.key_len], vstr))
+ if any(len(x) > 0 for x in sums.values()):
+ print('min:', end='')
+ for v in sums.values():
+ min = np.min(v)
+ print(f'\t{round(min, 5)}', end='')
+ print()
+ print('max:', end='')
+ for v in sums.values():
+ max = np.max(v)
+ print(f'\t{round(max, 5)}', end='')
+ print()
+ print('avg:', end='')
+ for v in sums.values():
+ mean = np.mean(v)
+ print(f'\t{round(mean, 5)}', end='')
+ print()
+ print('median:', end='')
+ for v in sums.values():
+ median = np.median(v)
+ print(f'\t{round(median, 5)}', end='')
+ print()
+
+ return ret
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ main(args, print_output=True)
\ No newline at end of file
diff --git a/examples/data2vec/tasks/__init__.py b/examples/data2vec/tasks/__init__.py
new file mode 100644
index 0000000000..a7422e4b30
--- /dev/null
+++ b/examples/data2vec/tasks/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .image_pretraining import ImagePretrainingTask, ImagePretrainingConfig
+from .image_classification import ImageClassificationTask, ImageClassificationConfig
+from .mae_image_pretraining import MaeImagePretrainingTask, MaeImagePretrainingConfig
+
+
+__all__ = [
+ "ImageClassificationTask",
+ "ImageClassificationConfig",
+ "ImagePretrainingTask",
+ "ImagePretrainingConfig",
+ "MaeImagePretrainingTask",
+ "MaeImagePretrainingConfig",
+]
\ No newline at end of file
diff --git a/examples/data2vec/tasks/audio_classification.py b/examples/data2vec/tasks/audio_classification.py
new file mode 100644
index 0000000000..2925a04cf9
--- /dev/null
+++ b/examples/data2vec/tasks/audio_classification.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import os
+import numpy as np
+import math
+import torch
+
+from sklearn import metrics as sklearn_metrics
+from dataclasses import dataclass
+
+from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
+from fairseq.tasks import register_task
+from fairseq.logging import metrics
+
+from ..data.add_class_target_dataset import AddClassTargetDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AudioClassificationConfig(AudioPretrainingConfig):
+ label_descriptors: str = "label_descriptors.csv"
+ labels: str = "lbl"
+
+
+@register_task("audio_classification", dataclass=AudioClassificationConfig)
+class AudioClassificationTask(AudioPretrainingTask):
+    """Multi-label audio classification task (e.g. AudioSet) on top of audio pretraining."""
+
+ cfg: AudioClassificationConfig
+
+ def __init__(
+ self,
+ cfg: AudioClassificationConfig,
+ ):
+ super().__init__(cfg)
+
+ self.state.add_factory("labels", self.load_labels)
+
+ def load_labels(self):
+ labels = {}
+ path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
+ with open(path, "r") as ldf:
+ for line in ldf:
+ if line.strip() == "":
+ continue
+ items = line.split(",")
+ idx = items[0]
+ lbl = items[1]
+ assert lbl not in labels, lbl
+ labels[lbl] = idx
+ return labels
+
+ @property
+ def labels(self):
+ return self.state.labels
+
+ def load_dataset(
+ self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
+ ):
+ super().load_dataset(split, task_cfg, **kwargs)
+
+ task_cfg = task_cfg or self.cfg
+
+ data_path = self.cfg.data
+ label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
+ skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
+ labels = []
+ with open(label_path, "r") as f:
+ for i, line in enumerate(f):
+ if i not in skipped_indices:
+ lbl_items = line.rstrip().split("\t")
+ labels.append([int(x) for x in lbl_items[2].split(",")])
+
+ assert len(labels) == len(self.datasets[split]), (
+ f"labels length ({len(labels)}) and dataset length "
+ f"({len(self.datasets[split])}) do not match"
+ )
+
+ self.datasets[split] = AddClassTargetDataset(
+ self.datasets[split],
+ labels,
+ multi_class=True,
+ add_to_input=True,
+ num_classes=len(self.labels),
+ )
+
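+    # Computes per-class average precision over the accumulated predictions;
+    # mAP (reported in reduce_metrics below) is the mean of these APs.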
+ def calculate_stats(self, output, target):
+
+ classes_num = target.shape[-1]
+ stats = []
+
+        # Accuracy is only meaningful for single-label classification (e.g. ESC-50), not for multi-label datasets such as AudioSet
+ # acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
+
+ # Class-wise statistics
+ for k in range(classes_num):
+ # Average precision
+ avg_precision = sklearn_metrics.average_precision_score(
+ target[:, k], output[:, k], average=None
+ )
+
+ dict = {
+ "AP": avg_precision,
+ }
+
+ # # AUC
+ # try:
+ # auc = sklearn_metrics.roc_auc_score(target[:, k], output[:, k], average=None)
+ # except:
+ # auc = 0
+ #
+ # # Precisions, recalls
+ # (precisions, recalls, thresholds) = sklearn_metrics.precision_recall_curve(
+ # target[:, k], output[:, k]
+ # )
+ #
+ # # FPR, TPR
+ # (fpr, tpr, thresholds) = sklearn_metrics.roc_curve(target[:, k], output[:, k])
+ #
+ # save_every_steps = 1000 # Sample statistics to reduce size
+ # dict = {
+ # "precisions": precisions[0::save_every_steps],
+ # "recalls": recalls[0::save_every_steps],
+ # "AP": avg_precision,
+ # "fpr": fpr[0::save_every_steps],
+ # "fnr": 1.0 - tpr[0::save_every_steps],
+ # "auc": auc,
+ # # note acc is not class-wise, this is just to keep consistent with other metrics
+ # "acc": acc,
+ # }
+ stats.append(dict)
+
+ return stats
+
+ def valid_step(self, sample, model, criterion):
+ loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+ return loss, sample_size, logging_output
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+ if "_predictions" in logging_outputs[0]:
+ metrics.log_concat_tensor(
+ "_predictions",
+ torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
+ )
+ metrics.log_concat_tensor(
+ "_targets",
+ torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
+ )
+
+ def compute_stats(meters):
+ if meters["_predictions"].tensor.shape[0] < 100:
+ return 0
+ stats = self.calculate_stats(
+ meters["_predictions"].tensor, meters["_targets"].tensor
+ )
+ return np.nanmean([stat["AP"] for stat in stats])
+
+ metrics.log_derived("mAP", compute_stats)
diff --git a/examples/data2vec/tasks/image_classification.py b/examples/data2vec/tasks/image_classification.py
new file mode 100644
index 0000000000..1ea4c2afee
--- /dev/null
+++ b/examples/data2vec/tasks/image_classification.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import os.path as osp
+import logging
+
+from dataclasses import dataclass
+import torch
+from torchvision import transforms
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import register_task
+from fairseq.logging import metrics
+
+try:
+ from ..data import ImageDataset
+except ImportError:
+ import sys
+
+ sys.path.append("..")
+ from data import ImageDataset
+
+from .image_pretraining import (
+ ImagePretrainingConfig,
+ ImagePretrainingTask,
+ IMG_EXTENSIONS,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ImageClassificationConfig(ImagePretrainingConfig):
+ pass
+
+
+@register_task("image_classification", dataclass=ImageClassificationConfig)
+class ImageClassificationTask(ImagePretrainingTask):
+
+ cfg: ImageClassificationConfig
+
+ @classmethod
+ def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ path_with_split = osp.join(data_path, split)
+ if osp.exists(path_with_split):
+ data_path = path_with_split
+
+ from timm.data import create_transform
+
+ if split == "train":
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=cfg.input_size,
+ is_training=True,
+ auto_augment="rand-m9-mstd0.5-inc1",
+ interpolation="bicubic",
+ re_prob=0.25,
+ re_mode="pixel",
+ re_count=1,
+ mean=cfg.normalization_mean,
+ std=cfg.normalization_std,
+ )
+            if cfg.input_size <= 32:
+ transform.transforms[0] = transforms.RandomCrop(
+ cfg.input_size, padding=4
+ )
+ else:
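+            # evaluation transform: for inputs larger than 32 px, resize (keeping the
+            # 224/256 crop ratio below 384 px) and center crop to input_size; then
+            # convert to tensor and normalize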
+ t = []
+ if cfg.input_size > 32:
+ crop_pct = 1
+ if cfg.input_size < 384:
+ crop_pct = 224 / 256
+ size = int(cfg.input_size / crop_pct)
+ t.append(
+ transforms.Resize(
+ size, interpolation=3
+ ), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(cfg.input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(
+ transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
+ )
+ transform = transforms.Compose(t)
+ logger.info(transform)
+
+ self.datasets[split] = ImageDataset(
+ root=data_path,
+ extensions=IMG_EXTENSIONS,
+ load_classes=True,
+ transform=transform,
+ )
+ for k in self.datasets.keys():
+ if k != split:
+ assert self.datasets[k].classes == self.datasets[split].classes
+
+ def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
+ model = super().build_model(model_cfg, from_checkpoint)
+
+ actualized_cfg = getattr(model, "cfg", None)
+ if actualized_cfg is not None:
+ if hasattr(actualized_cfg, "pretrained_model_args"):
+ model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
+
+ return model
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+
+ if "correct" in logging_outputs[0]:
+ zero = torch.scalar_tensor(0.0)
+ correct = sum(log.get("correct", zero) for log in logging_outputs)
+ metrics.log_scalar_sum("_correct", correct)
+
+ metrics.log_derived(
+ "accuracy",
+ lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
+ )
diff --git a/examples/data2vec/tasks/image_pretraining.py b/examples/data2vec/tasks/image_pretraining.py
new file mode 100644
index 0000000000..cd688fd136
--- /dev/null
+++ b/examples/data2vec/tasks/image_pretraining.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+import os.path as osp
+
+from dataclasses import dataclass, field
+from typing import List
+from omegaconf import MISSING
+
+import torch
+from torchvision import transforms
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+try:
+ from ..data import ImageDataset
+except ImportError:
+ sys.path.append("..")
+ from data import ImageDataset
+
+logger = logging.getLogger(__name__)
+
+IMG_EXTENSIONS = {
+ ".jpg",
+ ".jpeg",
+ ".png",
+ ".ppm",
+ ".bmp",
+ ".pgm",
+ ".tif",
+ ".tiff",
+ ".webp",
+}
+
+
+@dataclass
+class ImagePretrainingConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ input_size: int = 224
+ normalization_mean: List[float] = (0.485, 0.456, 0.406)
+ normalization_std: List[float] = (0.229, 0.224, 0.225)
+
+
+@register_task("image_pretraining", dataclass=ImagePretrainingConfig)
+class ImagePretrainingTask(FairseqTask):
+ """ """
+
+ cfg: ImagePretrainingConfig
+
+ @classmethod
+ def setup_task(cls, cfg: ImagePretrainingConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (ImagePretrainingConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ path_with_split = osp.join(data_path, split)
+ if osp.exists(path_with_split):
+ data_path = path_with_split
+
+ transform = transforms.Compose(
+ [
+ transforms.ColorJitter(0.4, 0.4, 0.4),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomResizedCrop(
+ size=cfg.input_size,
+ interpolation=transforms.InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(cfg.normalization_mean),
+ std=torch.tensor(cfg.normalization_std),
+ ),
+ ]
+ )
+
+ logger.info(transform)
+
+ self.datasets[split] = ImageDataset(
+ root=data_path,
+ extensions=IMG_EXTENSIONS,
+ load_classes=False,
+ transform=transform,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/data2vec/tasks/mae_image_classification.py b/examples/data2vec/tasks/mae_image_classification.py
new file mode 100644
index 0000000000..1bf935879f
--- /dev/null
+++ b/examples/data2vec/tasks/mae_image_classification.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+import torch
+
+from typing import Optional
+from dataclasses import dataclass, field
+from omegaconf import MISSING
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from fairseq.logging import metrics
+
+try:
+ from ..data import MaeFinetuningImageDataset
+except ImportError:
+ sys.path.append("..")
+ from data import MaeFinetuningImageDataset
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MaeImageClassificationConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ input_size: int = 224
+ local_cache_path: Optional[str] = None
+
+ rebuild_batches: bool = True
+
+
+@register_task("mae_image_classification", dataclass=MaeImageClassificationConfig)
+class MaeImageClassificationTask(FairseqTask):
+ """ """
+
+ cfg: MaeImageClassificationConfig
+
+ @classmethod
+ def setup_task(cls, cfg: MaeImageClassificationConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (MaeImageClassificationConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ self.datasets[split] = MaeFinetuningImageDataset(
+ root=data_path,
+ split=split,
+ is_train=split == "train",
+ input_size=cfg.input_size,
+ local_cache_path=cfg.local_cache_path,
+ shuffle=split == "train",
+ )
+
+ def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
+ model = super().build_model(model_cfg, from_checkpoint)
+
+ actualized_cfg = getattr(model, "cfg", None)
+ if actualized_cfg is not None:
+ if hasattr(actualized_cfg, "pretrained_model_args"):
+ model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
+
+ return model
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+
+ if "correct" in logging_outputs[0]:
+ zero = torch.scalar_tensor(0.0)
+ correct = sum(log.get("correct", zero) for log in logging_outputs)
+ metrics.log_scalar_sum("_correct", correct)
+
+ metrics.log_derived(
+ "accuracy",
+ lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/data2vec/tasks/mae_image_pretraining.py b/examples/data2vec/tasks/mae_image_pretraining.py
new file mode 100644
index 0000000000..35a14891ca
--- /dev/null
+++ b/examples/data2vec/tasks/mae_image_pretraining.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+
+from typing import Optional, List
+from dataclasses import dataclass, field
+from omegaconf import MISSING, II
+
+from fairseq.data import SubsampleDataset
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+try:
+ from ..data import MaeImageDataset
+except ImportError:
+ sys.path.append("..")
+ from data import MaeImageDataset
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ImageMaskingConfig:
+ patch_size: int = II("model.modalities.image.patch_size")
+ mask_prob: float = II("model.modalities.image.mask_prob")
+ mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust")
+ mask_length: int = II("model.modalities.image.mask_length")
+ inverse_mask: bool = II("model.modalities.image.inverse_mask")
+ mask_dropout: float = II("model.modalities.image.mask_dropout")
+ clone_batch: int = II("model.clone_batch")
+ expand_adjacent: bool = False
+ non_overlapping: bool = False
+
+
+@dataclass
+class MaeImagePretrainingConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ multi_data: Optional[List[str]] = None
+ input_size: int = 224
+ local_cache_path: Optional[str] = None
+ key: str = "imgs"
+
+ beit_transforms: bool = False
+ target_transform: bool = False
+ no_transform: bool = False
+
+ rebuild_batches: bool = True
+
+ precompute_mask_config: Optional[ImageMaskingConfig] = None
+
+ subsample: float = 1
+ seed: int = II("common.seed")
+ dataset_type: str = "imagefolder"
+
+
+@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
+class MaeImagePretrainingTask(FairseqTask):
+ """ """
+
+ cfg: MaeImagePretrainingConfig
+
+ @classmethod
+ def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (MaeImagePretrainingConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ compute_mask = cfg.precompute_mask_config is not None
+ mask_args = {}
+ if compute_mask:
+ mask_args = cfg.precompute_mask_config
+
+ self.datasets[split] = MaeImageDataset(
+ root=data_path if cfg.multi_data is None else cfg.multi_data,
+ split=split,
+ input_size=cfg.input_size,
+ local_cache_path=cfg.local_cache_path,
+ key=cfg.key,
+ beit_transforms=cfg.beit_transforms,
+ target_transform=cfg.target_transform,
+ no_transform=cfg.no_transform,
+ compute_mask=compute_mask,
+ dataset_type=cfg.dataset_type,
+ **mask_args,
+ )
+
+ if cfg.subsample < 1:
+ self.datasets[split] = SubsampleDataset(
+ self.datasets[split],
+ cfg.subsample,
+ shuffle=True,
+ seed=cfg.seed,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/data2vec/tasks/multimodal.py b/examples/data2vec/tasks/multimodal.py
new file mode 100644
index 0000000000..74648e918f
--- /dev/null
+++ b/examples/data2vec/tasks/multimodal.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import sys
+
+from dataclasses import dataclass
+from typing import Optional, List
+from omegaconf import II
+
+from fairseq.data.iterators import GroupedEpochBatchIterator
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
+from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
+from .mae_image_pretraining import MaeImagePretrainingConfig, MaeImagePretrainingTask
+from examples.data2vec.data.modality import Modality
+
+from fairseq.data.audio.multi_modality_dataset import (
+ MultiModalityDataset,
+ ModalityDatasetItem,
+)
+
+
+@dataclass
+class MultimodalPretrainingConfig(FairseqDataclass):
+ audio: Optional[AudioPretrainingConfig] = None
+ image: Optional[MaeImagePretrainingConfig] = None
+ text: Optional[MaskedLMConfig] = None
+
+ audio_ratio: float = 1
+ image_ratio: float = 1
+ text_ratio: float = 1
+
+ max_tokens: Optional[int] = II("dataset.max_tokens")
+ batch_size: Optional[int] = II("dataset.batch_size")
+ update_freq: List[int] = II("optimization.update_freq")
+
+ rebuild_batches: bool = True
+
+
+@register_task("multimodal_pretraining", dataclass=MultimodalPretrainingConfig)
+class MultimodalPretrainingTask(FairseqTask):
+ """ """
+
+ cfg: MultimodalPretrainingConfig
+
+ def __init__(self, cfg: MultimodalPretrainingConfig):
+ super().__init__(cfg)
+ self.audio_task = (
+ AudioPretrainingTask(cfg.audio) if cfg.audio is not None else None
+ )
+ self.image_task = (
+ MaeImagePretrainingTask(cfg.image) if cfg.image is not None else None
+ )
+ self.text_task = MaskedLMTask(cfg.text) if cfg.text is not None else None
+
+ self.mult_ratios = []
+
+ @classmethod
+ def setup_task(cls, cfg: MultimodalPretrainingConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (MultimodalPretrainingConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ datasets = []
+ self.mult_ratios = []
+
+ def load_ds(task, name, ratio):
+ if task is not None:
+ task.load_dataset(split)
+ ds = ModalityDatasetItem(
+ datasetname=name,
+ dataset=task.dataset(split),
+ max_positions=task.max_positions(),
+ max_tokens=self.cfg.max_tokens,
+ max_sentences=self.cfg.batch_size,
+ )
+ datasets.append(ds)
+ self.mult_ratios.append(ratio)
+
+ load_ds(self.audio_task, Modality.AUDIO, self.cfg.audio_ratio)
+ load_ds(self.image_task, Modality.IMAGE, self.cfg.image_ratio)
+ load_ds(self.text_task, Modality.TEXT, self.cfg.text_ratio)
+
+ assert len(datasets) > 0
+
+ self.datasets[split] = MultiModalityDataset(datasets)
+
+ @property
+ def supported_modalities(self):
+ modalities = []
+ if self.cfg.text is not None:
+ modalities.append(Modality.TEXT)
+ if self.cfg.audio is not None:
+ modalities.append(Modality.AUDIO)
+ if self.cfg.image is not None:
+ modalities.append(Modality.IMAGE)
+
+ return modalities
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=0,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ skip_remainder_batch=False,
+ grouped_shuffling=False,
+ update_epoch_batch_itr=False,
+ ):
+
+ # initialize the dataset with the correct starting epoch
+ dataset.set_epoch(epoch)
+
+ batch_samplers = dataset.get_batch_samplers(
+ self.mult_ratios, required_batch_size_multiple, seed
+ )
+
+ # return a reusable, sharded iterator
+ epoch_iter = GroupedEpochBatchIterator(
+ dataset=dataset,
+ collate_fn=dataset.collater,
+ batch_samplers=batch_samplers,
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=num_workers,
+ epoch=epoch,
+ mult_rate=max(self.cfg.update_freq),
+ buffer_size=data_buffer_size,
+ skip_remainder_batch=skip_remainder_batch,
+ )
+ self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
+ return epoch_iter
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/discriminative_reranking_nmt/README.md b/examples/discriminative_reranking_nmt/README.md
new file mode 100644
index 0000000000..b155e855f2
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/README.md
@@ -0,0 +1,202 @@
+# Discriminative Reranking for Neural Machine Translation
+https://aclanthology.org/2021.acl-long.563/
+
+This folder contains source code for training DrNMT, a discriminatively trained reranker for neural machine translation.
+
+## Data preparation
+1. Follow the instructions under `examples/translation` to build a base MT model. Prepare three files: one with source sentences, one with ground-truth target sentences, and one with hypotheses generated by the base MT model. Each line in each file contains one sentence in raw text (i.e. no sentencepiece, etc.). Below is an example of the files with _N_ hypotheses for each source sentence.
+
+```
+# Example of the source sentence file: (The file should contain L lines.)
+
+source_sentence_1
+source_sentence_2
+source_sentence_3
+...
+source_sentence_L
+
+# Example of the target sentence file: (The file should contain L lines.)
+
+target_sentence_1
+target_sentence_2
+target_sentence_3
+...
+target_sentence_L
+
+# Example of the hypotheses file: (The file should contain L*N lines.)
+
+source_sentence_1_hypo_1
+source_sentence_1_hypo_2
+...
+source_sentence_1_hypo_N
+source_sentence_2_hypo_1
+...
+source_sentence_2_hypo_N
+...
+source_sentence_L_hypo_1
+...
+source_sentence_L_hypo_N
+```
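+
+As a quick sanity check of the layout above, you can verify the line counts (a minimal sketch; the file names and `N` are placeholders for your own data):
+
+```
+N=50  # hypotheses per source sentence
+
+[ "$(wc -l < source.txt)" -eq "$(wc -l < target.txt)" ] || echo "source/target line counts differ"
+[ "$(wc -l < hypo.txt)" -eq "$(( $(wc -l < source.txt) * N ))" ] || echo "hypotheses file should have L*N lines"
+```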
+
+2. Download the [XLMR model](https://github.com/fairinternal/fairseq-py/tree/main/examples/xlmr#pre-trained-models).
+```
+wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
+tar zxvf xlmr.base.tar.gz
+
+# The folder should contain dict.txt, model.pt and sentencepiece.bpe.model.
+```
+
+3. Prepare scores and BPE data.
+* `N`: Number of hypotheses per source sentence. We use 50 in the paper.
+* `SPLIT`: Name of the data split, i.e. train, valid, test. If there are multiple datasets for a split, use split_name, split_name1, split_name2, ... (e.g. train, train1, valid, valid1).
+* `NUM_SHARDS`: Number of shards. Set this to 1 for non-train splits.
+* `METRIC`: The metric for DrNMT to optimize for. We support either `bleu` or `ter`.
+```
+# For each data split, e.g. train, valid, test, etc., run the following:
+
+SOURCE_FILE=/path/to/source_sentence_file
+TARGET_FILE=/path/to/target_sentence_file
+HYPO_FILE=/path/to/hypo_file
+XLMR_DIR=/path/to/xlmr
+OUTPUT_DIR=/path/to/output
+
+python scripts/prep_data.py \
+ --input-source ${SOURCE_FILE} \
+ --input-target ${TARGET_FILE} \
+ --input-hypo ${HYPO_FILE} \
+ --output-dir ${OUTPUT_DIR} \
+    --split $SPLIT \
+ --beam $N \
+ --sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
+ --metric $METRIC \
+ --num-shards ${NUM_SHARDS}
+
+# The script will create ${OUTPUT_DIR}/$METRIC with ${NUM_SHARDS} splits.
+# Under split*/input_src, split*/input_tgt and split*/$METRIC, there will be $SPLIT.bpe and $SPLIT.$METRIC files, respectively.
+
+```
+
+4. Pre-process the data into fairseq format.
+```
+# use commas to separate paths if there is more than one train or valid set
+for suffix in src tgt ; do
+ fairseq-preprocess --only-source \
+ --trainpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/train.bpe \
+ --validpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid.bpe \
+ --destdir ${OUTPUT_DIR}/$METRIC/split1/input_${suffix} \
+ --workers 60 \
+ --srcdict ${XLMR_DIR}/dict.txt
+done
+
+for i in `seq 2 ${NUM_SHARDS}`; do
+ for suffix in src tgt ; do
+ fairseq-preprocess --only-source \
+ --trainpref ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/train.bpe \
+ --destdir ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix} \
+ --workers 60 \
+ --srcdict ${XLMR_DIR}/dict.txt
+
+ ln -s ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid* ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/.
+ done
+
+ ln -s ${OUTPUT_DIR}/$METRIC/split1/$METRIC/valid* ${OUTPUT_DIR}/$METRIC/split${i}/$METRIC/.
+done
+```
+
+## Training
+
+```
+EXP_DIR=/path/to/exp
+
+# An example of training the model with the config for De-En experiment in the paper.
+# The config uses 16 GPUs and 50 hypotheses.
+# For training with fewer GPUs, set
+# distributed_training.distributed_world_size=k +optimization.update_freq='[x]' where x = 16/k
+# For training with fewer hypotheses, set
+# task.mt_beam=N dataset.batch_size=N dataset.required_batch_size_multiple=N
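+# For example, with 8 GPUs: distributed_training.distributed_world_size=8 +optimization.update_freq='[2]'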
+
+fairseq-hydra-train -m \
+ --config-dir config/ --config-name deen \
+ task.data=${OUTPUT_DIR}/$METRIC/split1/ \
+ task.num_data_splits=${NUM_SHARDS} \
+ model.pretrained_model=${XLMR_DIR}/model.pt \
+ common.user_dir=${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
+ checkpoint.save_dir=${EXP_DIR}
+
+```
+
+## Inference & scoring
+Perform DrNMT reranking (combining the forward (fw) MT score with the reranker score).
+1. Tune weights on valid sets.
+```
+# generate N hypotheses with the base MT model (fw score)
+VALID_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model
+VALID_TARGET_FILE=/path/to/target_sentences # one sentence per line in raw text, i.e. no sentencepiece or tokenization
+MT_MODEL=/path/to/mt_model
+MT_DATA_PATH=/path/to/mt_data
+
+cat ${VALID_SOURCE_FILE} | \
+ fairseq-interactive ${MT_DATA_PATH} \
+ --max-tokens 4000 --buffer-size 16 \
+ --num-workers 32 --path ${MT_MODEL} \
+ --beam $N --nbest $N \
+ --post-process sentencepiece &> valid-hypo.out
+
+# replace "bleu" with "ter" to optimize for TER
+python drnmt_rerank.py \
+ ${OUTPUT_DIR}/$METRIC/split1/ \
+ --path ${EXP_DIR}/checkpoint_best.pt \
+ --in-text valid-hypo.out \
+ --results-path ${EXP_DIR} \
+ --gen-subset valid \
+ --target-text ${VALID_TARGET_FILE} \
+ --user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
+ --bpe sentencepiece \
+ --sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
+ --beam $N \
+ --batch-size $N \
+ --metric bleu \
+ --tune
+
+```
+
+2. Apply best weights on test sets
+```
+# generate N hypotheses with the base MT model (fw score)
+TEST_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model
+
+cat ${TEST_SOURCE_FILE} | \
+ fairseq-interactive ${MT_DATA_PATH} \
+ --max-tokens 4000 --buffer-size 16 \
+ --num-workers 32 --path ${MT_MODEL} \
+ --beam $N --nbest $N \
+ --post-process sentencepiece &> test-hypo.out
+
+# replace "bleu" with "ter" to evaluate TER
+# Add --target-text to evaluate BLEU/TER;
+# otherwise the script will only output the hypotheses with the highest scores.
+python drnmt_rerank.py \
+ ${OUTPUT_DIR}/$METRIC/split1/ \
+ --path ${EXP_DIR}/checkpoint_best.pt \
+ --in-text test-hypo.out \
+ --results-path ${EXP_DIR} \
+ --gen-subset test \
+ --user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
+ --bpe sentencepiece \
+ --sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
+ --beam $N \
+ --batch-size $N \
+ --metric bleu \
+ --fw-weight ${BEST_FW_WEIGHT} \
+ --lenpen ${BEST_LENPEN}
+```
+
+## Citation
+```bibtex
+@inproceedings{lee2021discriminative,
+ title={Discriminative Reranking for Neural Machine Translation},
+ author={Lee, Ann and Auli, Michael and Ranzato, Marc'Aurelio},
+ booktitle={ACL},
+ year={2021}
+}
+```
diff --git a/examples/discriminative_reranking_nmt/__init__.py b/examples/discriminative_reranking_nmt/__init__.py
new file mode 100644
index 0000000000..0278f6a273
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/__init__.py
@@ -0,0 +1 @@
+from . import criterions, models, tasks # noqa
diff --git a/examples/discriminative_reranking_nmt/config/deen.yaml b/examples/discriminative_reranking_nmt/config/deen.yaml
new file mode 100644
index 0000000000..3fc2d5fcf5
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/config/deen.yaml
@@ -0,0 +1,56 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 50
+ seed: 2
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: bleu
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: discriminative_reranking_nmt
+ data: ???
+ num_data_splits: ???
+ include_src: true
+ mt_beam: 50
+ eval_target_metric: true
+ target_metric: bleu
+
+dataset:
+ batch_size: 50
+ num_workers: 6
+ required_batch_size_multiple: 50
+ valid_subset: ???
+
+criterion:
+ _name: kl_divergence_rereanking
+ target_dist_norm: minmax
+ temperature: 0.5
+
+optimization:
+ max_epoch: 200
+ lr: [0.00005]
+ update_freq: [32]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 8000
+ total_num_update: 320000
+
+model:
+ _name: discriminative_nmt_reranker
+ pretrained_model: ???
+ classifier_dropout: 0.2
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_world_size: 16
diff --git a/examples/discriminative_reranking_nmt/criterions/__init__.py b/examples/discriminative_reranking_nmt/criterions/__init__.py
new file mode 100644
index 0000000000..7c257c2700
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/criterions/__init__.py
@@ -0,0 +1,6 @@
+from .discriminative_reranking_criterion import KLDivergenceRerankingCriterion
+
+
+__all__ = [
+ "KLDivergenceRerankingCriterion",
+]
diff --git a/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py b/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py
new file mode 100644
index 0000000000..c8f19e3858
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py
@@ -0,0 +1,139 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+
+from fairseq import utils
+from fairseq.logging import metrics
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+
+
+_EPSILON = torch.finfo(torch.float32).eps
+TARGET_DIST_NORM_CHOICES = ChoiceEnum(["none", "minmax"])
+
+
+@dataclass
+class KLDivergenceRerankingCriterionConfig(FairseqDataclass):
+ target_dist_norm: TARGET_DIST_NORM_CHOICES = field(
+ default="none",
+ metadata={"help": "method to normalize the range of target scores"},
+ )
+ temperature: float = field(
+ default=1.0,
+ metadata={"help": "temperature in softmax for target distributions"},
+ )
+ forward_batch_size: int = field(
+ default=32,
+ metadata={
+ "help": "number of hypotheses per batch for model forward (set a value smaller than --mt-beam to avoid OOM when training with a large beam size)"
+ },
+ )
+
+
+@register_criterion(
+ "kl_divergence_rereanking", dataclass=KLDivergenceRerankingCriterionConfig
+)
+class KLDivergenceRerankingCriterion(FairseqCriterion):
+ def __init__(
+ self, task, target_dist_norm, temperature, forward_batch_size,
+ ):
+ super().__init__(task)
+ self.target_dist_norm = target_dist_norm
+ self.temperature = temperature
+ self.forward_batch_size = forward_batch_size
+
+ def forward(self, model, sample, reduce=True):
+ """Compute the loss for the given sample.
+
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+
+ sample_size = sample["id"].numel()
+ assert sample_size % self.task.cfg.mt_beam == 0, (
+ f"sample_size ({sample_size}) cannot be divided by beam size ({self.task.cfg.mt_beam})."
+ f"Please set --required-batch-size-multiple={self.task.cfg.mt_beam}."
+ )
+
+ # split into smaller batches for model forward
+ batch_out = []
+ for i in range(0, sample_size, self.forward_batch_size):
+ j = min(i + self.forward_batch_size, sample_size)
+
+ out = model(
+ src_tokens=sample["net_input"]["src_tokens"][i:j, :],
+ src_lengths=sample["net_input"]["src_lengths"][i:j],
+ )
+
+ batch_out.append(
+ model.sentence_forward(out, sample["net_input"]["src_tokens"][i:j, :])
+ )
+
+ batch_out = torch.cat(batch_out, dim=0).view(
+ self.task.cfg.mt_beam, sample_size // self.task.cfg.mt_beam, -1
+ ) # T x B x C
+ if model.joint_classification == "sent":
+ batch_out = model.joint_forward(batch_out)
+ scores = model.classification_forward(batch_out.view(sample_size, 1, -1)).view(
+ -1, self.task.cfg.mt_beam
+ ) # input: B x T x C
+
+ loss = self.compute_kl_loss(
+ scores, sample["target"][:, 0].view(-1, self.task.cfg.mt_beam)
+ )
+
+ sample_size = sample_size // self.task.cfg.mt_beam
+
+ logging_output = {
+ "loss": loss.detach(),
+ "ntokens": sample["ntokens"],
+ "nsentences": sample_size * self.task.cfg.mt_beam,
+ "sample_size": sample_size,
+ "scores": scores.detach(),
+ }
+
+ return loss, sample_size, logging_output
+
+ def compute_kl_loss(self, logits, target):
+ norm_target = target
+ if self.target_dist_norm == "minmax":
+ min_v = torch.min(target, 1, keepdim=True).values
+ max_v = torch.max(target, 1, keepdim=True).values
+ norm_target = (target - min_v) / (max_v - min_v + _EPSILON)
+
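+        # target distribution: softmax over the beam of the (optionally min-max normalized)
+        # metric scores divided by the temperature; the loss below is KL(target || model)
+        # summed over all elements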
+ target_dist = F.softmax(
+ norm_target / self.temperature, dim=-1, dtype=torch.float32
+ )
+ model_dist = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+ loss = -(target_dist * model_dist - target_dist * target_dist.log()).sum()
+ return loss
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+
+ sample_size = utils.item(
+ sum(log.get("sample_size", 0) for log in logging_outputs)
+ )
+
+ loss = loss_sum / sample_size / math.log(2)
+ metrics.log_scalar("loss", loss, sample_size, round=3)
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+ """
+ return True
diff --git a/examples/discriminative_reranking_nmt/drnmt_rerank.py b/examples/discriminative_reranking_nmt/drnmt_rerank.py
new file mode 100644
index 0000000000..2e0fc2bd29
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/drnmt_rerank.py
@@ -0,0 +1,364 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Score raw text with a trained model.
+"""
+
+from collections import namedtuple
+import logging
+from multiprocessing import Pool
+import sys
+import os
+import random
+
+import numpy as np
+import sacrebleu
+import torch
+
+from fairseq import checkpoint_utils, options, utils
+
+
+logger = logging.getLogger("fairseq_cli.drnmt_rerank")
+logger.setLevel(logging.INFO)
+
+Batch = namedtuple("Batch", "ids src_tokens src_lengths")
+
+
+pool_init_variables = {}
+
+
+def init_loaded_scores(mt_scores, model_scores, hyp, ref):
+ global pool_init_variables
+ pool_init_variables["mt_scores"] = mt_scores
+ pool_init_variables["model_scores"] = model_scores
+ pool_init_variables["hyp"] = hyp
+ pool_init_variables["ref"] = ref
+
+
+def parse_fairseq_gen(filename, task):
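+    # parse fairseq-interactive output: "S-<id>\t<source>" lines carry source sentences,
+    # "D-<id>\t<score>\t<text>" lines carry detokenized hypotheses and their model scores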
+ source = {}
+ hypos = {}
+ scores = {}
+ with open(filename, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line.startswith("S-"): # source
+ uid, text = line.split("\t", 1)
+ uid = int(uid[2:])
+ source[uid] = text
+ elif line.startswith("D-"): # hypo
+ uid, score, text = line.split("\t", 2)
+ uid = int(uid[2:])
+ if uid not in hypos:
+ hypos[uid] = []
+ scores[uid] = []
+ hypos[uid].append(text)
+ scores[uid].append(float(score))
+ else:
+ continue
+
+ source_out = [source[i] for i in range(len(hypos))]
+ hypos_out = [h for i in range(len(hypos)) for h in hypos[i]]
+ scores_out = [s for i in range(len(scores)) for s in scores[i]]
+
+ return source_out, hypos_out, scores_out
+
+
+def read_target(filename):
+ with open(filename, "r", encoding="utf-8") as f:
+ output = [line.strip() for line in f]
+ return output
+
+
+def make_batches(args, src, hyp, task, max_positions, encode_fn):
+ assert len(src) * args.beam == len(
+ hyp
+    ), f"Expected {len(src) * args.beam} hypotheses for {len(src)} source sentences with beam size {args.beam}. Got {len(hyp)} hypotheses instead."
+ hyp_encode = [
+ task.source_dictionary.encode_line(encode_fn(h), add_if_not_exist=False).long()
+ for h in hyp
+ ]
+ if task.cfg.include_src:
+ src_encode = [
+ task.source_dictionary.encode_line(
+ encode_fn(s), add_if_not_exist=False
+ ).long()
+ for s in src
+ ]
+ tokens = [(src_encode[i // args.beam], h) for i, h in enumerate(hyp_encode)]
+ lengths = [(t1.numel(), t2.numel()) for t1, t2 in tokens]
+ else:
+ tokens = [(h,) for h in hyp_encode]
+ lengths = [(h.numel(),) for h in hyp_encode]
+
+ itr = task.get_batch_iterator(
+ dataset=task.build_dataset_for_inference(tokens, lengths),
+ max_tokens=args.max_tokens,
+ max_sentences=args.batch_size,
+ max_positions=max_positions,
+ ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
+ ).next_epoch_itr(shuffle=False)
+
+ for batch in itr:
+ yield Batch(
+ ids=batch["id"],
+ src_tokens=batch["net_input"]["src_tokens"],
+ src_lengths=batch["net_input"]["src_lengths"],
+ )
+
+
+def decode_rerank_scores(args):
+ if args.max_tokens is None and args.batch_size is None:
+ args.batch_size = 1
+
+ logger.info(args)
+
+ use_cuda = torch.cuda.is_available() and not args.cpu
+
+ # Load ensemble
+ logger.info("loading model(s) from {}".format(args.path))
+ models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
+ [args.path], arg_overrides=eval(args.model_overrides),
+ )
+
+ for model in models:
+ if args.fp16:
+ model.half()
+ if use_cuda:
+ model.cuda()
+
+ # Initialize generator
+ generator = task.build_generator(args)
+
+ # Handle tokenization and BPE
+ tokenizer = task.build_tokenizer(args)
+ bpe = task.build_bpe(args)
+
+ def encode_fn(x):
+ if tokenizer is not None:
+ x = tokenizer.encode(x)
+ if bpe is not None:
+ x = bpe.encode(x)
+ return x
+
+ max_positions = utils.resolve_max_positions(
+ task.max_positions(), *[model.max_positions() for model in models]
+ )
+
+ src, hyp, mt_scores = parse_fairseq_gen(args.in_text, task)
+ model_scores = {}
+ logger.info("decode reranker score")
+ for batch in make_batches(args, src, hyp, task, max_positions, encode_fn):
+ src_tokens = batch.src_tokens
+ src_lengths = batch.src_lengths
+ if use_cuda:
+ src_tokens = src_tokens.cuda()
+ src_lengths = src_lengths.cuda()
+
+ sample = {
+ "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths},
+ }
+ scores = task.inference_step(generator, models, sample)
+
+ for id, sc in zip(batch.ids.tolist(), scores.tolist()):
+ model_scores[id] = sc[0]
+
+ model_scores = [model_scores[i] for i in range(len(model_scores))]
+
+ return src, hyp, mt_scores, model_scores
+
+
+def get_score(mt_s, md_s, w1, lp, tgt_len):
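+    # length-normalize the forward MT score by tgt_len**lp, weight it by w1, then add the reranker score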
+ return mt_s / (tgt_len ** lp) * w1 + md_s
+
+
+def get_best_hyps(mt_scores, md_scores, hypos, fw_weight, lenpen, beam):
+ assert len(mt_scores) == len(md_scores) and len(mt_scores) == len(hypos)
+ hypo_scores = []
+ best_hypos = []
+ best_scores = []
+ offset = 0
+ for i in range(len(hypos)):
+ tgt_len = len(hypos[i].split())
+ hypo_scores.append(
+ get_score(mt_scores[i], md_scores[i], fw_weight, lenpen, tgt_len)
+ )
+
+ if (i + 1) % beam == 0:
+ max_i = np.argmax(hypo_scores)
+ best_hypos.append(hypos[offset + max_i])
+ best_scores.append(hypo_scores[max_i])
+ hypo_scores = []
+ offset += beam
+ return best_hypos, best_scores
+
+
+def eval_metric(args, hypos, ref):
+ if args.metric == "bleu":
+ score = sacrebleu.corpus_bleu(hypos, [ref]).score
+ else:
+ score = sacrebleu.corpus_ter(hypos, [ref]).score
+
+ return score
+
+
+def score_target_hypo(args, fw_weight, lp):
+ mt_scores = pool_init_variables["mt_scores"]
+ model_scores = pool_init_variables["model_scores"]
+ hyp = pool_init_variables["hyp"]
+ ref = pool_init_variables["ref"]
+ best_hypos, _ = get_best_hyps(
+ mt_scores, model_scores, hyp, fw_weight, lp, args.beam
+ )
+ rerank_eval = None
+ if ref:
+ rerank_eval = eval_metric(args, best_hypos, ref)
+ print(f"fw_weight {fw_weight}, lenpen {lp}, eval {rerank_eval}")
+
+ return rerank_eval
+
+
+def print_result(best_scores, best_hypos, output_file):
+ for i, (s, h) in enumerate(zip(best_scores, best_hypos)):
+ print(f"{i}\t{s}\t{h}", file=output_file)
+
+
+def main(args):
+ utils.import_user_module(args)
+
+ src, hyp, mt_scores, model_scores = decode_rerank_scores(args)
+
+ assert (
+ not args.tune or args.target_text is not None
+ ), "--target-text has to be set when tuning weights"
+ if args.target_text:
+ ref = read_target(args.target_text)
+ assert len(src) == len(
+ ref
+ ), f"different numbers of source and target sentences ({len(src)} vs. {len(ref)})"
+
+ orig_best_hypos = [hyp[i] for i in range(0, len(hyp), args.beam)]
+ orig_eval = eval_metric(args, orig_best_hypos, ref)
+
+ if args.tune:
+ logger.info("tune weights for reranking")
+
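+        # random search: sample (fw_weight, lenpen) pairs uniformly within the given bounds,
+        # rerank with each pair in a worker pool, and keep the pair with the best metric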
+ random_params = np.array(
+ [
+ [
+ random.uniform(
+ args.lower_bound_fw_weight, args.upper_bound_fw_weight
+ ),
+ random.uniform(args.lower_bound_lenpen, args.upper_bound_lenpen),
+ ]
+ for k in range(args.num_trials)
+ ]
+ )
+
+ logger.info("launching pool")
+ with Pool(
+ 32,
+ initializer=init_loaded_scores,
+ initargs=(mt_scores, model_scores, hyp, ref),
+ ) as p:
+ rerank_scores = p.starmap(
+ score_target_hypo,
+ [
+ (args, random_params[i][0], random_params[i][1],)
+ for i in range(args.num_trials)
+ ],
+ )
+ if args.metric == "bleu":
+ best_index = np.argmax(rerank_scores)
+ else:
+ best_index = np.argmin(rerank_scores)
+ best_fw_weight = random_params[best_index][0]
+ best_lenpen = random_params[best_index][1]
+ else:
+ assert (
+ args.lenpen is not None and args.fw_weight is not None
+ ), "--lenpen and --fw-weight should be set"
+ best_fw_weight, best_lenpen = args.fw_weight, args.lenpen
+
+ best_hypos, best_scores = get_best_hyps(
+ mt_scores, model_scores, hyp, best_fw_weight, best_lenpen, args.beam
+ )
+
+ if args.results_path is not None:
+ os.makedirs(args.results_path, exist_ok=True)
+ output_path = os.path.join(
+ args.results_path, "generate-{}.txt".format(args.gen_subset),
+ )
+ with open(output_path, "w", buffering=1, encoding="utf-8") as o:
+ print_result(best_scores, best_hypos, o)
+ else:
+ print_result(best_scores, best_hypos, sys.stdout)
+
+ if args.target_text:
+ rerank_eval = eval_metric(args, best_hypos, ref)
+ print(f"before reranking, {args.metric.upper()}:", orig_eval)
+ print(
+ f"after reranking with fw_weight={best_fw_weight}, lenpen={best_lenpen}, {args.metric.upper()}:",
+ rerank_eval,
+ )
+
+
+def cli_main():
+ parser = options.get_generation_parser(interactive=True)
+
+ parser.add_argument(
+ "--in-text",
+ default=None,
+ required=True,
+ help="text from fairseq-interactive output, containing source sentences and hypotheses",
+ )
+ parser.add_argument("--target-text", default=None, help="reference text")
+ parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
+ parser.add_argument(
+ "--tune",
+ action="store_true",
+ help="if set, tune weights on fw scores and lenpen instead of applying fixed weights for reranking",
+ )
+ parser.add_argument(
+ "--lower-bound-fw-weight",
+ default=0.0,
+ type=float,
+ help="lower bound of search space",
+ )
+ parser.add_argument(
+ "--upper-bound-fw-weight",
+ default=3,
+ type=float,
+ help="upper bound of search space",
+ )
+ parser.add_argument(
+ "--lower-bound-lenpen",
+ default=0.0,
+ type=float,
+ help="lower bound of search space",
+ )
+ parser.add_argument(
+ "--upper-bound-lenpen",
+ default=3,
+ type=float,
+ help="upper bound of search space",
+ )
+ parser.add_argument(
+ "--fw-weight", type=float, default=None, help="weight on the fw model score"
+ )
+ parser.add_argument(
+ "--num-trials",
+ default=1000,
+ type=int,
+ help="number of trials to do for random search",
+ )
+
+ args = options.parse_args_and_arch(parser)
+ main(args)
+
+
+if __name__ == "__main__":
+ cli_main()
diff --git a/examples/discriminative_reranking_nmt/models/__init__.py b/examples/discriminative_reranking_nmt/models/__init__.py
new file mode 100644
index 0000000000..c593ea5f18
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/models/__init__.py
@@ -0,0 +1,6 @@
+from .discriminative_reranking_model import DiscriminativeNMTReranker
+
+
+__all__ = [
+ "DiscriminativeNMTReranker",
+]
diff --git a/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py b/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py
new file mode 100644
index 0000000000..e4b5887f82
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py
@@ -0,0 +1,365 @@
+from dataclasses import dataclass, field
+import os
+
+import torch
+import torch.nn as nn
+
+from fairseq import utils
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.models import (
+ BaseFairseqModel,
+ register_model,
+)
+
+from fairseq.models.roberta.model import RobertaClassificationHead
+
+from fairseq.modules import (
+ LayerNorm,
+ TransformerSentenceEncoder,
+ TransformerSentenceEncoderLayer,
+)
+
+
+ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns())
+JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"])
+SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"])
+
+
+def update_init_roberta_model_state(state):
+ """
+ update the state_dict of a Roberta model for initializing
+ weights of the BertRanker
+ """
+ for k in list(state.keys()):
+ if ".lm_head." in k or "version" in k:
+ del state[k]
+ continue
+ # remove 'encoder/decoder.sentence_encoder.' from the key
+ assert k.startswith("encoder.sentence_encoder.") or k.startswith(
+ "decoder.sentence_encoder."
+ ), f"Cannot recognize parameter name {k}"
+ if "layernorm_embedding" in k:
+ new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.")
+ state[new_k[25:]] = state[k]
+ else:
+ state[k[25:]] = state[k]
+ del state[k]
+
+
+class BaseRanker(nn.Module):
+ def __init__(self, args, task):
+ super().__init__()
+
+ self.separator_token = task.dictionary.eos()
+ self.padding_idx = task.dictionary.pad()
+
+ def forward(self, src_tokens):
+ raise NotImplementedError
+
+ def get_segment_labels(self, src_tokens):
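+        # segment label 0 covers everything up to and including the first separator; tokens
+        # after each separator get the next label (separators keep the label of the segment
+        # they close), and padding is adjusted so it does not start a new segment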
+ segment_boundary = (src_tokens == self.separator_token).long()
+ segment_labels = (
+ segment_boundary.cumsum(dim=1)
+ - segment_boundary
+ - (src_tokens == self.padding_idx).long()
+ )
+
+ return segment_labels
+
+ def get_positions(self, src_tokens, segment_labels):
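+        # build positional indices that restart after each separator token, so the source and
+        # hypothesis segments each get their own positions; following RoBERTa, positions are
+        # offset by padding_idx and padding tokens are assigned padding_idx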
+ segment_positions = (
+ torch.arange(src_tokens.shape[1])
+ .to(src_tokens.device)
+ .repeat(src_tokens.shape[0], 1)
+ )
+ segment_boundary = (src_tokens == self.separator_token).long()
+ _, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True)
+ col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx])
+ offset = torch.cat(
+ [
+ torch.zeros(1).type_as(segment_boundary),
+ segment_boundary.sum(dim=1).cumsum(dim=0)[:-1],
+ ]
+ )
+ segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * (
+ segment_labels != 0
+ )
+
+ padding_mask = src_tokens.ne(self.padding_idx)
+ segment_positions = (segment_positions + 1) * padding_mask.type_as(
+ segment_positions
+ ) + self.padding_idx
+
+ return segment_positions
+
+
+class BertRanker(BaseRanker):
+ def __init__(self, args, task):
+ super(BertRanker, self).__init__(args, task)
+
+ init_model = getattr(args, "pretrained_model", "")
+ self.joint_layers = nn.ModuleList()
+ if os.path.isfile(init_model):
+ print(f"initialize weight from {init_model}")
+
+ from fairseq import hub_utils
+
+ x = hub_utils.from_pretrained(
+ os.path.dirname(init_model),
+ checkpoint_file=os.path.basename(init_model),
+ )
+
+ in_state_dict = x["models"][0].state_dict()
+ init_args = x["args"].model
+
+ num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1
+
+ # follow the setup in roberta
+ self.model = TransformerSentenceEncoder(
+ padding_idx=task.dictionary.pad(),
+ vocab_size=len(task.dictionary),
+ num_encoder_layers=getattr(
+ args, "encoder_layers", init_args.encoder_layers
+ ),
+ embedding_dim=init_args.encoder_embed_dim,
+ ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
+ num_attention_heads=init_args.encoder_attention_heads,
+ dropout=init_args.dropout,
+ attention_dropout=init_args.attention_dropout,
+ activation_dropout=init_args.activation_dropout,
+ num_segments=2, # add language embeddings
+ max_seq_len=num_positional_emb,
+ offset_positions_by_padding=False,
+ encoder_normalize_before=True,
+ apply_bert_init=True,
+ activation_fn=init_args.activation_fn,
+ freeze_embeddings=args.freeze_embeddings,
+ n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
+ )
+
+ # still need to learn segment embeddings as we added a second language embedding
+ if args.freeze_embeddings:
+ for p in self.model.segment_embeddings.parameters():
+ p.requires_grad = False
+
+ update_init_roberta_model_state(in_state_dict)
+ print("loading weights from the pretrained model")
+ self.model.load_state_dict(
+ in_state_dict, strict=False
+ ) # ignore mismatch in language embeddings
+
+ ffn_embedding_dim = init_args.encoder_ffn_embed_dim
+ num_attention_heads = init_args.encoder_attention_heads
+ dropout = init_args.dropout
+ attention_dropout = init_args.attention_dropout
+ activation_dropout = init_args.activation_dropout
+ activation_fn = init_args.activation_fn
+
+ classifier_embed_dim = getattr(
+ args, "embed_dim", init_args.encoder_embed_dim
+ )
+ if classifier_embed_dim != init_args.encoder_embed_dim:
+ self.transform_layer = nn.Linear(
+ init_args.encoder_embed_dim, classifier_embed_dim
+ )
+ else:
+ self.model = TransformerSentenceEncoder(
+ padding_idx=task.dictionary.pad(),
+ vocab_size=len(task.dictionary),
+ num_encoder_layers=args.encoder_layers,
+ embedding_dim=args.embed_dim,
+ ffn_embedding_dim=args.ffn_embed_dim,
+ num_attention_heads=args.attention_heads,
+ dropout=args.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ max_seq_len=task.max_positions()
+ if task.max_positions()
+ else args.tokens_per_sample,
+ num_segments=2,
+ offset_positions_by_padding=False,
+ encoder_normalize_before=args.encoder_normalize_before,
+ apply_bert_init=args.apply_bert_init,
+ activation_fn=args.activation_fn,
+ )
+
+ classifier_embed_dim = args.embed_dim
+ ffn_embedding_dim = args.ffn_embed_dim
+ num_attention_heads = args.attention_heads
+ dropout = args.dropout
+ attention_dropout = args.attention_dropout
+ activation_dropout = args.activation_dropout
+ activation_fn = args.activation_fn
+
+ self.joint_classification = args.joint_classification
+ if args.joint_classification == "sent":
+ if args.joint_normalize_before:
+ self.joint_layer_norm = LayerNorm(classifier_embed_dim)
+ else:
+ self.joint_layer_norm = None
+
+ self.joint_layers = nn.ModuleList(
+ [
+ TransformerSentenceEncoderLayer(
+ embedding_dim=classifier_embed_dim,
+ ffn_embedding_dim=ffn_embedding_dim,
+ num_attention_heads=num_attention_heads,
+ dropout=dropout,
+ attention_dropout=attention_dropout,
+ activation_dropout=activation_dropout,
+ activation_fn=activation_fn,
+ )
+ for _ in range(args.num_joint_layers)
+ ]
+ )
+
+ self.classifier = RobertaClassificationHead(
+ classifier_embed_dim,
+ classifier_embed_dim,
+ 1, # num_classes
+ "tanh",
+ args.classifier_dropout,
+ )
+
+ def forward(self, src_tokens, src_lengths):
+ segment_labels = self.get_segment_labels(src_tokens)
+ positions = self.get_positions(src_tokens, segment_labels)
+
+ inner_states, _ = self.model(
+ tokens=src_tokens,
+ segment_labels=segment_labels,
+ last_state_only=True,
+ positions=positions,
+ )
+
+ return inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
+
+ def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"):
+ # encoder_out: B x T x C
+ if sentence_rep == "head":
+ x = encoder_out[:, :1, :]
+ else: # 'meanpool', 'maxpool'
+ assert src_tokens is not None, "meanpool requires src_tokens input"
+ segment_labels = self.get_segment_labels(src_tokens)
+ padding_mask = src_tokens.ne(self.padding_idx)
+ encoder_mask = segment_labels * padding_mask.type_as(segment_labels)
+
+ if sentence_rep == "meanpool":
+ ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
+ x = torch.sum(
+ encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True
+ ) / ntokens.unsqueeze(2).type_as(encoder_out)
+ else: # 'maxpool'
+ encoder_out[
+ (encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1])
+ ] = -float("inf")
+ x, _ = torch.max(encoder_out, dim=1, keepdim=True)
+
+ if hasattr(self, "transform_layer"):
+ x = self.transform_layer(x)
+
+ return x # B x 1 x C
+
+ def joint_forward(self, x):
+ # x: T x B x C
+ if self.joint_layer_norm:
+ x = self.joint_layer_norm(x.transpose(0, 1))
+ x = x.transpose(0, 1)
+
+ for layer in self.joint_layers:
+ x, _ = layer(x, self_attn_padding_mask=None)
+ return x
+
+ def classification_forward(self, x):
+ # x: B x T x C
+ return self.classifier(x)
+
+
+@dataclass
+class DiscriminativeNMTRerankerConfig(FairseqDataclass):
+ pretrained_model: str = field(
+ default="", metadata={"help": "pretrained model to load"}
+ )
+ sentence_rep: SENTENCE_REP_CHOICES = field(
+ default="head",
+ metadata={
+ "help": "method to transform the output of the transformer stack to a sentence-level representation"
+ },
+ )
+
+ dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
+ attention_dropout: float = field(
+ default=0.0, metadata={"help": "dropout probability for attention weights"}
+ )
+ activation_dropout: float = field(
+ default=0.0, metadata={"help": "dropout probability after activation in FFN"}
+ )
+ classifier_dropout: float = field(
+ default=0.0, metadata={"help": "classifier dropout probability"}
+ )
+ embed_dim: int = field(default=768, metadata={"help": "embedding dimension"})
+ ffn_embed_dim: int = field(
+ default=2048, metadata={"help": "embedding dimension for FFN"}
+ )
+ encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"})
+ attention_heads: int = field(default=8, metadata={"help": "num attention heads"})
+ encoder_normalize_before: bool = field(
+ default=False, metadata={"help": "apply layernorm before each encoder block"}
+ )
+ apply_bert_init: bool = field(
+ default=False, metadata={"help": "use custom param initialization for BERT"}
+ )
+ activation_fn: ACTIVATION_FN_CHOICES = field(
+ default="relu", metadata={"help": "activation function to use"}
+ )
+ freeze_embeddings: bool = field(
+ default=False, metadata={"help": "freeze embeddings in the pretrained model"}
+ )
+ n_trans_layers_to_freeze: int = field(
+ default=0,
+ metadata={
+ "help": "number of layers to freeze in the pretrained transformer model"
+ },
+ )
+
+    # joint classification
+ joint_classification: JOINT_CLASSIFICATION_CHOICES = field(
+ default="none",
+ metadata={"help": "method to compute joint features for classification"},
+ )
+ num_joint_layers: int = field(
+ default=1, metadata={"help": "number of joint layers"}
+ )
+ joint_normalize_before: bool = field(
+ default=False,
+ metadata={"help": "apply layer norm on the input to the joint layer"},
+ )
+
+
+@register_model(
+ "discriminative_nmt_reranker", dataclass=DiscriminativeNMTRerankerConfig
+)
+class DiscriminativeNMTReranker(BaseFairseqModel):
+ @classmethod
+ def build_model(cls, args, task):
+ model = BertRanker(args, task)
+ return DiscriminativeNMTReranker(args, model)
+
+ def __init__(self, args, model):
+ super().__init__()
+
+ self.model = model
+ self.sentence_rep = args.sentence_rep
+ self.joint_classification = args.joint_classification
+
+ def forward(self, src_tokens, src_lengths, **kwargs):
+ return self.model(src_tokens, src_lengths)
+
+ def sentence_forward(self, encoder_out, src_tokens):
+ return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep)
+
+ def joint_forward(self, x):
+ return self.model.joint_forward(x)
+
+ def classification_forward(self, x):
+ return self.model.classification_forward(x)
diff --git a/examples/discriminative_reranking_nmt/scripts/prep_data.py b/examples/discriminative_reranking_nmt/scripts/prep_data.py
new file mode 100755
index 0000000000..7aa7d37edc
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/scripts/prep_data.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+
+import argparse
+from multiprocessing import Pool
+from pathlib import Path
+
+import sacrebleu
+import sentencepiece as spm
+
+
+def read_text_file(filename):
+ with open(filename, "r") as f:
+ output = [line.strip() for line in f]
+
+ return output
+
+
+def get_bleu(in_sent, target_sent):
+ bleu = sacrebleu.corpus_bleu([in_sent], [[target_sent]])
+ out = " ".join(
+ map(str, [bleu.score, bleu.sys_len, bleu.ref_len] + bleu.counts + bleu.totals)
+ )
+ return out
+
+
+def get_ter(in_sent, target_sent):
+ ter = sacrebleu.corpus_ter([in_sent], [[target_sent]])
+ out = " ".join(map(str, [ter.score, ter.num_edits, ter.ref_length]))
+ return out
+
+
+def init(sp_model):
+ global sp
+ sp = spm.SentencePieceProcessor()
+ sp.Load(sp_model)
+
+
+def process(source_sent, target_sent, hypo_sent, metric):
+ source_bpe = " ".join(sp.EncodeAsPieces(source_sent))
+ hypo_bpe = [" ".join(sp.EncodeAsPieces(h)) for h in hypo_sent]
+
+ if metric == "bleu":
+ score_str = [get_bleu(h, target_sent) for h in hypo_sent]
+ else: # ter
+ score_str = [get_ter(h, target_sent) for h in hypo_sent]
+
+ return source_bpe, hypo_bpe, score_str
+
+
+def main(args):
+ assert (
+ args.split.startswith("train") or args.num_shards == 1
+ ), "--num-shards should be set to 1 for valid and test sets"
+ assert (
+ args.split.startswith("train")
+ or args.split.startswith("valid")
+ or args.split.startswith("test")
+ ), "--split should be set to train[n]/valid[n]/test[n]"
+
+ source_sents = read_text_file(args.input_source)
+ target_sents = read_text_file(args.input_target)
+
+ num_sents = len(source_sents)
+ assert num_sents == len(
+ target_sents
+ ), f"{args.input_source} and {args.input_target} should have the same number of sentences."
+
+ hypo_sents = read_text_file(args.input_hypo)
+ assert (
+ len(hypo_sents) % args.beam == 0
+ ), f"Number of hypotheses ({len(hypo_sents)}) cannot be divided by beam size ({args.beam})."
+
+ hypo_sents = [
+ hypo_sents[i : i + args.beam] for i in range(0, len(hypo_sents), args.beam)
+ ]
+ assert num_sents == len(
+ hypo_sents
+ ), f"{args.input_hypo} should contain {num_sents * args.beam} hypotheses but only has {len(hypo_sents) * args.beam}. (--beam={args.beam})"
+
+ output_dir = args.output_dir / args.metric
+ for ns in range(args.num_shards):
+ print(f"processing shard {ns+1}/{args.num_shards}")
+ shard_output_dir = output_dir / f"split{ns+1}"
+ source_output_dir = shard_output_dir / "input_src"
+ hypo_output_dir = shard_output_dir / "input_tgt"
+ metric_output_dir = shard_output_dir / args.metric
+
+ source_output_dir.mkdir(parents=True, exist_ok=True)
+ hypo_output_dir.mkdir(parents=True, exist_ok=True)
+ metric_output_dir.mkdir(parents=True, exist_ok=True)
+
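+        # sentences are assigned to shards round-robin: shard ns processes indices
+        # ns, ns + num_shards, ns + 2 * num_shards, ...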
+ if args.n_proc > 1:
+ with Pool(
+ args.n_proc, initializer=init, initargs=(args.sentencepiece_model,)
+ ) as p:
+ output = p.starmap(
+ process,
+ [
+ (source_sents[i], target_sents[i], hypo_sents[i], args.metric)
+ for i in range(ns, num_sents, args.num_shards)
+ ],
+ )
+ else:
+ init(args.sentencepiece_model)
+ output = [
+ process(source_sents[i], target_sents[i], hypo_sents[i], args.metric)
+ for i in range(ns, num_sents, args.num_shards)
+ ]
+
+ with open(source_output_dir / f"{args.split}.bpe", "w") as s_o, open(
+ hypo_output_dir / f"{args.split}.bpe", "w"
+ ) as h_o, open(metric_output_dir / f"{args.split}.{args.metric}", "w") as m_o:
+ for source_bpe, hypo_bpe, score_str in output:
+ assert len(hypo_bpe) == len(score_str)
+ for h, m in zip(hypo_bpe, score_str):
+ s_o.write(f"{source_bpe}\n")
+ h_o.write(f"{h}\n")
+ m_o.write(f"{m}\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input-source", type=Path, required=True)
+ parser.add_argument("--input-target", type=Path, required=True)
+ parser.add_argument("--input-hypo", type=Path, required=True)
+ parser.add_argument("--output-dir", type=Path, required=True)
+ parser.add_argument("--split", type=str, required=True)
+ parser.add_argument("--beam", type=int, required=True)
+ parser.add_argument("--sentencepiece-model", type=str, required=True)
+ parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
+ parser.add_argument("--num-shards", type=int, default=1)
+ parser.add_argument("--n-proc", type=int, default=8)
+
+ args = parser.parse_args()
+
+ main(args)
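+
+# Example invocation (hypothetical file names; assumes an n-best list of size
+# --beam produced by fairseq-generate) -- a sketch rather than a required
+# command line:
+#
+#   python prep_data.py \
+#       --input-source source.txt --input-target target.txt \
+#       --input-hypo hypo.txt --output-dir processed --split train \
+#       --beam 50 --sentencepiece-model sentencepiece.bpe.model --metric bleu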
diff --git a/examples/discriminative_reranking_nmt/tasks/__init__.py b/examples/discriminative_reranking_nmt/tasks/__init__.py
new file mode 100644
index 0000000000..2d78ca9870
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/tasks/__init__.py
@@ -0,0 +1,6 @@
+from .discriminative_reranking_task import DiscriminativeRerankingNMTTask
+
+
+__all__ = [
+ "DiscriminativeRerankingNMTTask",
+]
diff --git a/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py b/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py
new file mode 100644
index 0000000000..b4ed2a69aa
--- /dev/null
+++ b/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py
@@ -0,0 +1,490 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+import itertools
+import logging
+import os
+
+import numpy as np
+import torch
+
+from fairseq.logging import metrics
+from fairseq.data import (
+ ConcatDataset,
+ ConcatSentencesDataset,
+ data_utils,
+ Dictionary,
+ IdDataset,
+ indexed_dataset,
+ NestedDictionaryDataset,
+ NumSamplesDataset,
+ NumelDataset,
+ PrependTokenDataset,
+ RawLabelDataset,
+ RightPadDataset,
+ SortDataset,
+ TruncateDataset,
+ TokenBlockDataset,
+)
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from omegaconf import II, MISSING
+
+
+EVAL_BLEU_ORDER = 4
+TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"])
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DiscriminativeRerankingNMTConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ num_data_splits: int = field(
+ default=1, metadata={"help": "total number of data splits"}
+ )
+ no_shuffle: bool = field(
+ default=False, metadata={"help": "do not shuffle training data"}
+ )
+ max_positions: int = field(
+ default=512, metadata={"help": "number of positional embeddings to learn"}
+ )
+ include_src: bool = field(
+ default=False, metadata={"help": "include source sentence"}
+ )
+ mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"})
+ eval_target_metric: bool = field(
+ default=False,
+ metadata={"help": "evaluation with the target metric during validation"},
+ )
+ target_metric: TARGET_METRIC_CHOICES = field(
+ default="bleu", metadata={"help": "name of the target metric to optimize for"}
+ )
+ train_subset: str = field(
+ default=II("dataset.train_subset"),
+ metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
+ )
+ seed: int = field(
+ default=II("common.seed"),
+ metadata={"help": "pseudo random number generator seed"},
+ )
+
+
+class RerankerScorer(object):
+ """Scores the target for a given (source (optional), target) input."""
+
+ def __init__(self, args, mt_beam):
+ self.mt_beam = mt_beam
+
+ @torch.no_grad()
+ def generate(self, models, sample, **kwargs):
+ """Score a batch of translations."""
+ net_input = sample["net_input"]
+
+ assert len(models) == 1, "does not support model ensemble"
+ model = models[0]
+
+ bs = net_input["src_tokens"].shape[0]
+ assert (
+ model.joint_classification == "none" or bs % self.mt_beam == 0
+ ), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})"
+
+ model.eval()
+ logits = model(**net_input)
+
+ batch_out = model.sentence_forward(logits, net_input["src_tokens"])
+ if model.joint_classification == "sent":
+ batch_out = model.joint_forward(
+ batch_out.view(self.mt_beam, bs // self.mt_beam, -1)
+ )
+ scores = model.classification_forward(
+ batch_out.view(bs, 1, -1)
+ ) # input: B x T x C
+
+ return scores
+
+
+@register_task(
+ "discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig
+)
+class DiscriminativeRerankingNMTTask(FairseqTask):
+ """
+ Translation rerank task.
+ The input can be either (src, tgt) sentence pairs or tgt sentence only.
+ """
+
+ cfg: DiscriminativeRerankingNMTConfig
+
+ def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None):
+ super().__init__(cfg)
+ self.dictionary = data_dictionary
+ self._max_positions = cfg.max_positions
+ # args.tokens_per_sample = self._max_positions
+ # self.num_classes = 1 # for model
+
+ @classmethod
+ def load_dictionary(cls, cfg, filename):
+ """Load the dictionary from the filename"""
+ dictionary = Dictionary.load(filename)
+ dictionary.add_symbol("") # for loading pretrained XLMR model
+
+ return dictionary
+
+ @classmethod
+ def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs):
+ # load data dictionary (assume joint dictionary)
+ data_path = cfg.data
+ data_dict = cls.load_dictionary(
+ cfg, os.path.join(data_path, "input_src/dict.txt")
+ )
+
+ logger.info("[input] src dictionary: {} types".format(len(data_dict)))
+
+ return DiscriminativeRerankingNMTTask(cfg, data_dict)
+
+ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+ """Load a given dataset split (e.g., train, valid, test)."""
+ if self.cfg.data.endswith("1"):
+ data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
+ data_path = self.cfg.data[:-1] + str(data_shard)
+ else:
+ data_path = self.cfg.data
+
+ def get_path(type, data_split):
+ return os.path.join(data_path, str(type), data_split)
+
+ def make_dataset(type, dictionary, data_split, combine):
+ split_path = get_path(type, data_split)
+
+ dataset = data_utils.load_indexed_dataset(
+ split_path,
+ dictionary,
+ combine=combine,
+ )
+ return dataset
+
+ def load_split(data_split, metric):
+ input_src = None
+ if self.cfg.include_src:
+ input_src = make_dataset(
+ "input_src", self.dictionary, data_split, combine=False
+ )
+ assert input_src is not None, "could not find dataset: {}".format(
+ get_path("input_src", data_split)
+ )
+
+ input_tgt = make_dataset(
+ "input_tgt", self.dictionary, data_split, combine=False
+ )
+ assert input_tgt is not None, "could not find dataset: {}".format(
+ get_path("input_tgt", data_split)
+ )
+
+ label_path = f"{get_path(metric, data_split)}.{metric}"
+ assert os.path.exists(label_path), f"could not find dataset: {label_path}"
+
+ np_labels = np.loadtxt(label_path)
+ if self.cfg.target_metric == "ter":
+ np_labels = -np_labels
+ label = RawLabelDataset(np_labels)
+
+ return input_src, input_tgt, label
+
+ src_datasets = []
+ tgt_datasets = []
+ label_datasets = []
+
+ if split == self.cfg.train_subset:
+ for k in itertools.count():
+ split_k = "train" + (str(k) if k > 0 else "")
+ prefix = os.path.join(data_path, "input_tgt", split_k)
+ if not indexed_dataset.dataset_exists(prefix, impl=None):
+ if k > 0:
+ break
+ else:
+ raise FileNotFoundError(f"Dataset not found: {prefix}")
+ input_src, input_tgt, label = load_split(
+ split_k, self.cfg.target_metric
+ )
+ src_datasets.append(input_src)
+ tgt_datasets.append(input_tgt)
+ label_datasets.append(label)
+ else:
+ input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
+ src_datasets.append(input_src)
+ tgt_datasets.append(input_tgt)
+ label_datasets.append(label)
+
+ if len(tgt_datasets) == 1:
+ input_tgt, label = tgt_datasets[0], label_datasets[0]
+ if self.cfg.include_src:
+ input_src = src_datasets[0]
+ else:
+ input_tgt = ConcatDataset(tgt_datasets)
+ label = ConcatDataset(label_datasets)
+ if self.cfg.include_src:
+ input_src = ConcatDataset(src_datasets)
+
+ input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
+ if self.cfg.include_src:
+ input_src = PrependTokenDataset(input_src, self.dictionary.bos())
+ input_src = TruncateDataset(input_src, self.cfg.max_positions)
+ src_lengths = NumelDataset(input_src, reduce=False)
+ src_tokens = ConcatSentencesDataset(input_src, input_tgt)
+ else:
+ src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
+ src_lengths = NumelDataset(src_tokens, reduce=False)
+
+ dataset = {
+ "id": IdDataset(),
+ "net_input": {
+ "src_tokens": RightPadDataset(
+ src_tokens,
+ pad_idx=self.source_dictionary.pad(),
+ ),
+ "src_lengths": src_lengths,
+ },
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_tokens, reduce=True),
+ "target": label,
+ }
+
+ dataset = NestedDictionaryDataset(
+ dataset,
+ sizes=[src_tokens.sizes],
+ )
+
+ assert (
+ len(dataset) % self.cfg.mt_beam == 0
+ ), "dataset size (%d) is not a multiple of beam size (%d)" % (
+ len(dataset),
+ self.cfg.mt_beam,
+ )
+
+ # no need to shuffle valid/test sets
+ if not self.cfg.no_shuffle and split == self.cfg.train_subset:
+
+            # keep all hypotheses of the same source sentence together:
+            # shuffle beam-sized blocks of indices rather than individual examples
+ start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
+ with data_utils.numpy_seed(self.cfg.seed + epoch):
+ np.random.shuffle(start_idx)
+
+ idx = np.arange(0, self.cfg.mt_beam)
+ shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
+ start_idx, (self.cfg.mt_beam, 1)
+ ).transpose().reshape(-1)
+
+ dataset = SortDataset(
+ dataset,
+ sort_order=[shuffle],
+ )
+
+ logger.info(f"Loaded {split} with #samples: {len(dataset)}")
+
+ self.datasets[split] = dataset
+ return self.datasets[split]
+
+ def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
+ assert not self.cfg.include_src or len(src_tokens[0]) == 2
+ input_src = None
+ if self.cfg.include_src:
+ input_src = TokenBlockDataset(
+ [t[0] for t in src_tokens],
+ [l[0] for l in src_lengths],
+ block_size=None, # ignored for "eos" break mode
+ pad=self.source_dictionary.pad(),
+ eos=self.source_dictionary.eos(),
+ break_mode="eos",
+ )
+ input_src = PrependTokenDataset(input_src, self.dictionary.bos())
+ input_src = TruncateDataset(input_src, self.cfg.max_positions)
+
+ input_tgt = TokenBlockDataset(
+ [t[-1] for t in src_tokens],
+ [l[-1] for l in src_lengths],
+ block_size=None, # ignored for "eos" break mode
+ pad=self.source_dictionary.pad(),
+ eos=self.source_dictionary.eos(),
+ break_mode="eos",
+ )
+ input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
+ if self.cfg.include_src:
+ src_tokens = ConcatSentencesDataset(input_src, input_tgt)
+ src_lengths = NumelDataset(input_src, reduce=False)
+ else:
+ input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos())
+ src_tokens = input_tgt
+ src_lengths = NumelDataset(src_tokens, reduce=False)
+
+ dataset = {
+ "id": IdDataset(),
+ "net_input": {
+ "src_tokens": RightPadDataset(
+ src_tokens,
+ pad_idx=self.source_dictionary.pad(),
+ ),
+ "src_lengths": src_lengths,
+ },
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_tokens, reduce=True),
+ }
+
+ return NestedDictionaryDataset(
+ dataset,
+ sizes=[src_tokens.sizes],
+ )
+
+ def build_model(self, cfg: FairseqDataclass, from_checkpoint: bool = False):
+ return super().build_model(cfg)
+
+ def build_generator(self, args):
+ return RerankerScorer(args, mt_beam=self.cfg.mt_beam)
+
+ def max_positions(self):
+ return self._max_positions
+
+ @property
+ def source_dictionary(self):
+ return self.dictionary
+
+ @property
+ def target_dictionary(self):
+ return self.dictionary
+
+ def create_dummy_batch(self, device):
+ dummy_target = (
+ torch.zeros(self.cfg.mt_beam, EVAL_BLEU_ORDER * 2 + 3).long().to(device)
+            if self.cfg.target_metric == "bleu"
+ else torch.zeros(self.cfg.mt_beam, 3).long().to(device)
+ )
+
+ return {
+ "id": torch.zeros(self.cfg.mt_beam, 1).long().to(device),
+ "net_input": {
+ "src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device),
+ "src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device),
+ },
+ "nsentences": 0,
+ "ntokens": 0,
+ "target": dummy_target,
+ }
+
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
+ if ignore_grad and sample is None:
+ sample = self.create_dummy_batch(model.device)
+
+ return super().train_step(
+ sample, model, criterion, optimizer, update_num, ignore_grad
+ )
+
+ def valid_step(self, sample, model, criterion):
+ if sample is None:
+ sample = self.create_dummy_batch(model.device)
+
+ loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+
+ if not self.cfg.eval_target_metric:
+ return loss, sample_size, logging_output
+
+ scores = logging_output["scores"]
+
+ if self.cfg.target_metric == "bleu":
+ assert sample["target"].shape[1] == EVAL_BLEU_ORDER * 2 + 3, (
+ "target does not contain enough information ("
+ + str(sample["target"].shape[1])
+ + "for evaluating BLEU"
+ )
+
+ max_id = torch.argmax(scores, dim=1)
+ select_id = max_id + torch.arange(
+ 0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
+ ).to(max_id.device)
+ bleu_data = sample["target"][select_id, 1:].sum(0).data
+
+ logging_output["_bleu_sys_len"] = bleu_data[0]
+ logging_output["_bleu_ref_len"] = bleu_data[1]
+
+ for i in range(EVAL_BLEU_ORDER):
+ logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i]
+ logging_output["_bleu_totals_" + str(i)] = bleu_data[
+ 2 + EVAL_BLEU_ORDER + i
+ ]
+
+ elif self.cfg.target_metric == "ter":
+ assert sample["target"].shape[1] == 3, (
+ "target does not contain enough information ("
+ + str(sample["target"].shape[1])
+ + "for evaluating TER"
+ )
+
+ max_id = torch.argmax(scores, dim=1)
+ select_id = max_id + torch.arange(
+ 0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
+ ).to(max_id.device)
+ ter_data = sample["target"][select_id, 1:].sum(0).data
+
+ logging_output["_ter_num_edits"] = -ter_data[0]
+ logging_output["_ter_ref_len"] = -ter_data[1]
+
+ return loss, sample_size, logging_output
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+
+ if not self.cfg.eval_target_metric:
+ return
+
+ def sum_logs(key):
+ return sum(log.get(key, 0) for log in logging_outputs)
+
+ if self.cfg.target_metric == "bleu":
+ counts, totals = [], []
+ for i in range(EVAL_BLEU_ORDER):
+ counts.append(sum_logs("_bleu_counts_" + str(i)))
+ totals.append(sum_logs("_bleu_totals_" + str(i)))
+
+ if max(totals) > 0:
+ # log counts as numpy arrays -- log_scalar will sum them correctly
+ metrics.log_scalar("_bleu_counts", np.array(counts))
+ metrics.log_scalar("_bleu_totals", np.array(totals))
+ metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
+ metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
+
+ def compute_bleu(meters):
+ import inspect
+ import sacrebleu
+
+ fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
+ if "smooth_method" in fn_sig:
+ smooth = {"smooth_method": "exp"}
+ else:
+ smooth = {"smooth": "exp"}
+ bleu = sacrebleu.compute_bleu(
+ correct=meters["_bleu_counts"].sum,
+ total=meters["_bleu_totals"].sum,
+ sys_len=meters["_bleu_sys_len"].sum,
+ ref_len=meters["_bleu_ref_len"].sum,
+ **smooth,
+ )
+ return round(bleu.score, 2)
+
+ metrics.log_derived("bleu", compute_bleu)
+ elif self.cfg.target_metric == "ter":
+ num_edits = sum_logs("_ter_num_edits")
+ ref_len = sum_logs("_ter_ref_len")
+
+ if ref_len > 0:
+ metrics.log_scalar("_ter_num_edits", num_edits)
+ metrics.log_scalar("_ter_ref_len", ref_len)
+
+ def compute_ter(meters):
+ score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum
+ return round(score.item(), 2)
+
+ metrics.log_derived("ter", compute_ter)
diff --git a/examples/emotion_conversion/README.md b/examples/emotion_conversion/README.md
new file mode 100644
index 0000000000..caf22befe4
--- /dev/null
+++ b/examples/emotion_conversion/README.md
@@ -0,0 +1,214 @@
+# Textless speech emotion conversion using decomposed and discrete representations
+[Felix Kreuk](https://felixkreuk.github.io), Adam Polyak, Jade Copet, Eugene Kharitonov, Tu-Anh Nguyen, Morgane Rivière, Wei-Ning Hsu, Abdelrahman Mohamed, Emmanuel Dupoux, [Yossi Adi](https://adiyoss.github.io)
+
+_abstract_: Speech emotion conversion is the task of modifying the perceived emotion of a speech utterance while preserving the lexical content and speaker identity. In this study, we cast the problem of emotion conversion as a spoken language translation task. We decompose speech into discrete and disentangled learned representations, consisting of content units, F0, speaker, and emotion. First, we modify the speech content by translating the content units to a target emotion, and then predict the prosodic features based on these units. Finally, the speech waveform is generated by feeding the predicted representations into a neural vocoder. Such a paradigm allows us to go beyond spectral and parametric changes of the signal, and model non-verbal vocalizations, such as laughter insertion, yawning removal, etc. We demonstrate objectively and subjectively that the proposed method is superior to the baselines in terms of perceived emotion and audio quality. We rigorously evaluate all components of such a complex system and conclude with an extensive model analysis and ablation study to better emphasize the architectural choices, strengths and weaknesses of the proposed method. Samples and code will be publicly available under the following link: https://speechbot.github.io/emotion.
+
+## Installation
+First, create a conda virtual environment and activate it:
+```
+conda create -n emotion python=3.8 -y
+conda activate emotion
+```
+
+Then, clone this repository:
+```
+git clone https://github.com/facebookresearch/fairseq.git
+cd fairseq/examples/emotion_conversion
+git clone https://github.com/felixkreuk/speech-resynthesis
+```
+
+Next, download the EmoV discrete tokens:
+```
+wget https://dl.fbaipublicfiles.com/textless_nlp/emotion_conversion/data.tar.gz # (still in fairseq/examples/emotion_conversion)
+tar -xzvf data.tar.gz
+```
+
+Your `fairseq/examples/emotion_conversion` directory should look like this:
+```
+drwxrwxr-x 3 felixkreuk felixkreuk 0 Feb 6 2022 data
+drwxrwxr-x 3 felixkreuk felixkreuk 0 Sep 28 10:41 emotion_models
+drwxr-xr-x 3 felixkreuk felixkreuk 0 Jun 29 05:43 fairseq_models
+drwxr-xr-x 3 felixkreuk felixkreuk 0 Sep 28 10:41 preprocess
+-rw-rw-r-- 1 felixkreuk felixkreuk 11K Dec 5 09:00 README.md
+-rw-rw-r-- 1 felixkreuk felixkreuk 88 Mar 6 2022 requirements.txt
+-rw-rw-r-- 1 felixkreuk felixkreuk 13K Jun 29 06:26 synthesize.py
+```
+
+Lastly, install fairseq and the other packages:
+```
+pip install --editable ./
+pip install -r examples/emotion_conversion/requirements.txt
+```
+
+## Data preprocessing
+
+### Convert your audio to discrete representations
+Please follow the steps described [here](https://github.com/pytorch/fairseq/tree/main/examples/hubert/simple_kmeans).
+To generate the same discrete representations, please use the following (a short quantization sketch is given after this list):
+1. [HuBERT checkpoint](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt)
+2. k-means model at `data/hubert_base_ls960_layer9_clusters200/data_hubert_base_ls960_layer9_clusters200.bin`
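+
+If you prefer to script this step directly, the snippet below is a minimal sketch of quantizing a single utterance. It assumes the HuBERT checkpoint and k-means model listed above, mono 16 kHz input audio, fairseq installed as described, and `soundfile`/`joblib` available; all file paths are placeholders.
+```
+import joblib
+import soundfile as sf
+import torch
+from fairseq import checkpoint_utils
+
+# load HuBERT (placeholder path) and the layer-9 k-means model
+models, _, _ = checkpoint_utils.load_model_ensemble_and_task(["hubert_base_ls960.pt"])
+hubert = models[0].eval()
+km = joblib.load("data/hubert_base_ls960_layer9_clusters200/data_hubert_base_ls960_layer9_clusters200.bin")
+
+wav, sr = sf.read("example.wav")  # placeholder utterance
+assert sr == 16000, "HuBERT expects mono 16 kHz audio"
+source = torch.from_numpy(wav).float().unsqueeze(0)
+
+with torch.no_grad():
+    # layer-9 features, one vector per 20 ms frame
+    feats, _ = hubert.extract_features(source=source, padding_mask=None, output_layer=9)
+units = km.predict(feats.squeeze(0).numpy())  # one cluster id (0..199) per frame
+print(" ".join(map(str, units)))
+```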
+
+### Construct data splits
+This step will use the discrete representations from the previous step and split them into train/valid/test sets for 3 tasks:
+1. Translation model pre-training (BART language denoising)
+2. Translation model training (content units emotion translation mechanism)
+3. HiFiGAN model training (for synthesizing audio from discrete representations)
+
+Your processed data should be at `data/`:
+1. `hubert_base_ls960_layer9_clusters200` - discrete representations extracted using HuBERT layer 9, clustered into 200 clusters.
+2. `data.tsv` - a tsv file pointing to the EmoV dataset in your environment (please edit the first line of this file so it points to your local copy of the dataset).
+
+The following command will create the above splits:
+```
+python examples/emotion_conversion/preprocess/create_core_manifest.py \
+ --tsv data/data.tsv \
+ --emov-km data/hubert_base_ls960_layer9_clusters200/data.km \
+ --km data/hubert_base_ls960_layer9_clusters200/vctk.km \
+ --dict data/hubert_base_ls960_layer9_clusters200/dict.txt \
+ --manifests-dir $DATA
+```
+* Set `$DATA` as the directory that will contain the processed data.
+
+### Extract F0
+To train the HiFiGAN vocoder we need to first extract the F0 curves:
+```
+python examples/emotion_conversion/preprocess/extract_f0.py \
+ --tsv data/data.tsv \
+ --extractor pyaapt \
+```
+
+## HiFiGAN training
+Now we are all set to train the HiFiGAN vocoder:
+```
+python examples/emotion_conversion/speech-resynthesis/train.py \
+    --checkpoint_path <hifigan-checkpoint-dir> \
+ --config examples/emotion_conversion/speech-resynthesis/configs/EmoV/emov_hubert-layer9-cluster200_fixed-spkr-embedder_f0-raw_gst.json
+```
+
+## Translation Pre-training
+Before translating emotions, we first need to pre-train the translation model as a denoising autoencoder (similarly to BART).
+```
+python train.py \
+ $DATA/fairseq-data/emov_multilingual_denoising_cross-speaker_dedup_nonzeroshot/tokenized \
+    --save-dir <denoising-checkpoint-dir> \
+    --tensorboard-logdir <tensorboard-logdir> \
+ --langs neutral,amused,angry,sleepy,disgusted,vctk.km \
+ --dataset-impl mmap \
+ --task multilingual_denoising \
+ --arch transformer_small --criterion cross_entropy \
+ --multilang-sampling-alpha 1.0 --sample-break-mode eos --max-tokens 16384 \
+ --update-freq 1 --max-update 3000000 \
+ --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.0 \
+ --optimizer adam --weight-decay 0.01 --adam-eps 1e-06 \
+ --clip-norm 0.1 --lr-scheduler polynomial_decay --lr 0.0003 \
+ --total-num-update 3000000 --warmup-updates 10000 --fp16 \
+ --poisson-lambda 3.5 --mask 0.3 --mask-length span-poisson --replace-length 1 --rotate 0 --mask-random 0.1 --insert 0 --permute-sentences 1.0 \
+ --skip-invalid-size-inputs-valid-test \
+ --user-dir examples/emotion_conversion/fairseq_models
+```
+
+## Translation Training
+Now we are ready to train our emotion translation model:
+```
+python train.py \
+ --distributed-world-size 1 \
+ $DATA/fairseq-data/emov_multilingual_translation_cross-speaker_dedup/tokenized/ \
+    --save-dir <translation-checkpoint-dir> \
+    --tensorboard-logdir <tensorboard-logdir> \
+ --arch multilingual_small --task multilingual_translation \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
+ --lang-pairs neutral-amused,neutral-sleepy,neutral-disgusted,neutral-angry,amused-sleepy,amused-disgusted,amused-neutral,amused-angry,angry-amused,angry-sleepy,angry-disgusted,angry-neutral,disgusted-amused,disgusted-sleepy,disgusted-neutral,disgusted-angry,sleepy-amused,sleepy-neutral,sleepy-disgusted,sleepy-angry \
+ --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
+ --lr 1e-05 --clip-norm 0 --dropout 0.1 --attention-dropout 0.1 \
+ --weight-decay 0.01 --warmup-updates 2000 --lr-scheduler inverse_sqrt \
+ --max-tokens 4096 --update-freq 1 --max-update 100000 \
+ --required-batch-size-multiple 8 --fp16 --num-workers 4 \
+ --seed 2 --log-format json --log-interval 25 --save-interval-updates 1000 \
+ --no-epoch-checkpoints --keep-best-checkpoints 1 --keep-interval-updates 1 \
+    --finetune-from-model <pretrained-denoising-checkpoint> \
+ --user-dir examples/emotion_conversion/fairseq_models
+```
+* To share encoders/decoders use the `--share-encoders` and `--share-decoders` flags.
+* To add source/target emotion tokens use the `--encoder-langtok {'src'|'tgt'}` and `--decoder-langtok` flags.
+
+## F0-predictor Training
+The following command trains the F0 prediction module:
+```
+cd examples/emotion_conversion
+python -m emotion_models.pitch_predictor n_tokens=200 \
+ train_tsv="$DATA/denoising/emov/train.tsv" \
+ train_km="$DATA/denoising/emov/train.km" \
+ valid_tsv="$DATA/denoising/emov/valid.tsv" \
+ valid_km="$DATA/denoising/emov/valid.km"
+```
+* See `hydra.run.dir` to configure the directory for saving models; a short inference sketch follows below.
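+
+At synthesis time the trained checkpoint is consumed through `load_ckpt`/`inference` in `emotion_models/pitch_predictor.py`. The snippet below is a minimal usage sketch (run from `examples/emotion_conversion` with `speech-resynthesis` cloned as above); the checkpoint path, unit values, and speaker/style ids are placeholders that follow the default `spk2id`/`gst2id` maps in `pitch_predictor.yaml`:
+```
+import torch
+from emotion_models.pitch_predictor import load_ckpt
+
+model = load_ckpt("pitch_predictor/pitch_predictor.ckpt").eval()
+
+units = torch.LongTensor([[17, 17, 4, 4, 4, 151, 151, 23]])  # frame-level content units
+spk_id = torch.LongTensor([0])  # e.g. "bea" under the default spk2id map
+gst = torch.LongTensor([0])     # e.g. "amused" under the default gst2id map
+
+with torch.no_grad():
+    f0 = model.inference(units, spk_id=spk_id, gst=gst)  # one F0 value per input unit
+print(f0)
+```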
+
+## Duration-predictor Training
+The following command trains the duration prediction modules:
+```
+cd examples/emotion_conversion
+for emotion in "neutral" "amused" "angry" "disgusted" "sleepy"; do
+ python -m emotion_models.duration_predictor n_tokens=200 substring=$emotion \
+ train_tsv="$DATA/denoising/emov/train.tsv" \
+ train_km="$DATA/denoising/emov/train.km" \
+ valid_tsv="$DATA/denoising/emov/valid.tsv" \
+ valid_km="$DATA/denoising/emov/valid.km"
+done
+```
+* See `hydra.run.dir` to configure the directory for saving models.
+* After the above command you should have 5 duration models in your checkpoint directory (a usage sketch follows the listing below):
+```
+❯ ll duration_predictor/
+total 21M
+-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 amused.ckpt
+-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 angry.ckpt
+-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 disgusted.ckpt
+-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 neutral.ckpt
+-rw-rw-r-- 1 felixkreuk felixkreuk 4.1M Nov 15 2021 sleepy.ckpt
+```
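+
+These checkpoints are consumed through `load_ckpt`/`inflate_input` in `emotion_models/duration_predictor.py`. A minimal usage sketch (run from `examples/emotion_conversion`; the checkpoint path and unit values are placeholders):
+```
+import torch
+from emotion_models.duration_predictor import load_ckpt
+
+dur_model = load_ckpt("duration_predictor/amused.ckpt")  # load_ckpt already puts the model in eval mode on CPU
+
+deduped_units = torch.LongTensor([[17, 4, 151, 23]])  # run-length-collapsed content units
+with torch.no_grad():
+    inflated = dur_model.inflate_input(deduped_units)  # each unit repeated by its predicted duration
+print(inflated.tolist())
+```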
+
+## Token Generation
+The following command uses `fairseq-generate` to generate the token sequences based on the source and target emotions.
+```
+fairseq-generate \
+ $DATA/fairseq-data/emov_multilingual_translation_cross-speaker_dedup/tokenized/ \
+ --task multilingual_translation \
+ --gen-subset test \
+    --path <translation-checkpoint> \
+ --beam 5 \
+ --batch-size 4 --max-len-a 1.8 --max-len-b 10 --lenpen 1 --min-len 1 \
+ --skip-invalid-size-inputs-valid-test --distributed-world-size 1 \
+ --source-lang neutral --target-lang amused \
+ --lang-pairs neutral-amused,neutral-sleepy,neutral-disgusted,neutral-angry,amused-sleepy,amused-disgusted,amused-neutral,amused-angry,angry-amused,angry-sleepy,angry-disgusted,angry-neutral,disgusted-amused,disgusted-sleepy,disgusted-neutral,disgusted-angry,sleepy-amused,sleepy-neutral,sleepy-disgusted,sleepy-angry \
+    --results-path <results-dir> \
+ --user-dir examples/emotion_conversion/fairseq_models
+```
+* Modify `--source-lang` and `--target-lang` to control for the source and target emotions.
+* See [fairseq documentation](https://fairseq.readthedocs.io/en/latest/command_line_tools.html#fairseq-generate) for a full overview of generation parameters (e.g., top-k/top-p sampling); a small sketch of reading the generated file follows below.
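+
+The command above writes its output in the standard `fairseq-generate` format, including a `generate-test.txt` file under `--results-path`; `synthesize.py` below consumes that file directly. For reference, a minimal sketch of pulling the detokenized hypotheses (`D-` lines) out of it, with a placeholder path:
+```
+hyps = {}
+with open("results/generate-test.txt") as f:  # placeholder results directory
+    for line in f:
+        if line.startswith("D-"):
+            key, _score, text = line.rstrip("\n").split("\t")
+            hyps[int(key[2:])] = text
+print(len(hyps), "hypotheses")
+```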
+
+## Waveform Synthesis
+Using the output of the above command, the HiFiGAN vocoder, and the prosody prediction modules (F0 and duration), we can now generate the output waveforms:
+```
+python examples/emotion_conversion/synthesize.py \
+    --result-path <results-dir>/generate-test.txt \
+ --data $DATA/fairseq-data/emov_multilingual_translation_cross-speaker_dedup/neutral-amused \
+ --orig-tsv examples/emotion_conversion/data/data.tsv \
+ --orig-km examples/emotion_conversion/data/hubert_base_ls960_layer9_clusters200/data.km \
+    --checkpoint-file <hifigan-checkpoint-dir>/g_00400000 \
+ --dur-model duration_predictor/ \
+ --f0-model pitch_predictor/pitch_predictor.ckpt \
+ -s neutral -t amused \
+ --outdir ~/tmp/emotion_results/wavs/neutral-amused
+```
+* Please make sure the source and target emotions here match those of the previous command.
+
+# Citation
+If you find this useful in your research, please use the following BibTeX entry for citation.
+```
+@article{kreuk2021textless,
+ title={Textless speech emotion conversion using decomposed and discrete representations},
+ author={Kreuk, Felix and Polyak, Adam and Copet, Jade and Kharitonov, Eugene and Nguyen, Tu-Anh and Rivi{\`e}re, Morgane and Hsu, Wei-Ning and Mohamed, Abdelrahman and Dupoux, Emmanuel and Adi, Yossi},
+ journal={Conference on Empirical Methods in Natural Language Processing (EMNLP)},
+ year={2022}
+}
+```
diff --git a/examples/emotion_conversion/emotion_models/__init__.py b/examples/emotion_conversion/emotion_models/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/emotion_conversion/emotion_models/duration_predictor.py b/examples/emotion_conversion/emotion_models/duration_predictor.py
new file mode 100644
index 0000000000..eb47df0a21
--- /dev/null
+++ b/examples/emotion_conversion/emotion_models/duration_predictor.py
@@ -0,0 +1,243 @@
+import logging
+import os
+
+import hydra
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.layers.torch import Rearrange
+from torch.utils.data import DataLoader, Dataset
+
+from .utils import Accuracy
+
+logger = logging.getLogger(__name__)
+
+
+def save_ckpt(model, path, model_class):
+ ckpt = {
+ "state_dict": model.state_dict(),
+ "padding_token": model.padding_token,
+ "model_class": model_class,
+ }
+ torch.save(ckpt, path)
+
+
+def load_ckpt(path):
+ ckpt = torch.load(path)
+ ckpt["model_class"]["_target_"] = "emotion_models.duration_predictor.CnnPredictor"
+ model = hydra.utils.instantiate(ckpt["model_class"])
+ model.load_state_dict(ckpt["state_dict"])
+ model.padding_token = ckpt["padding_token"]
+ model = model.cpu()
+ model.eval()
+ return model
+
+
+class Collator:
+ def __init__(self, padding_idx):
+ self.padding_idx = padding_idx
+
+ def __call__(self, batch):
+ x = [item[0] for item in batch]
+ lengths = [len(item) for item in x]
+ x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.padding_idx)
+ y = [item[1] for item in batch]
+ y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=self.padding_idx)
+ mask = (x != self.padding_idx)
+ return x, y, mask, lengths
+
+
+class Predictor(nn.Module):
+ def __init__(self, n_tokens, emb_dim):
+ super(Predictor, self).__init__()
+ self.n_tokens = n_tokens
+ self.emb_dim = emb_dim
+ self.padding_token = n_tokens
+ # add 1 extra embedding for padding token, set the padding index to be the last token
+ # (tokens from the clustering start at index 0)
+ self.emb = nn.Embedding(n_tokens + 1, emb_dim, padding_idx=self.padding_token)
+
+ def inflate_input(self, batch):
+ """ get a sequence of tokens, predict their durations
+ and inflate them accordingly """
+ batch_durs = self.forward(batch)
+ batch_durs = torch.exp(batch_durs) - 1
+ batch_durs = batch_durs.round()
+ output = []
+ for seq, durs in zip(batch, batch_durs):
+ inflated_seq = []
+ for token, n in zip(seq, durs):
+ if token == self.padding_token:
+ break
+ n = int(n.item())
+ token = int(token.item())
+ inflated_seq.extend([token for _ in range(n)])
+ output.append(inflated_seq)
+ output = torch.LongTensor(output)
+ return output
+
+
+class CnnPredictor(Predictor):
+ def __init__(self, n_tokens, emb_dim, channels, kernel, output_dim, dropout, n_layers):
+ super(CnnPredictor, self).__init__(n_tokens=n_tokens, emb_dim=emb_dim)
+ layers = [
+ Rearrange("b t c -> b c t"),
+ nn.Conv1d(emb_dim, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
+ Rearrange("b c t -> b t c"),
+ nn.ReLU(),
+ nn.LayerNorm(channels),
+ nn.Dropout(dropout),
+ ]
+ for _ in range(n_layers-1):
+ layers += [
+ Rearrange("b t c -> b c t"),
+ nn.Conv1d(channels, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
+ Rearrange("b c t -> b t c"),
+ nn.ReLU(),
+ nn.LayerNorm(channels),
+ nn.Dropout(dropout),
+ ]
+ self.conv_layer = nn.Sequential(*layers)
+ self.proj = nn.Linear(channels, output_dim)
+
+ def forward(self, x):
+ x = self.emb(x)
+ x = self.conv_layer(x)
+ x = self.proj(x)
+ x = x.squeeze(-1)
+ return x
+
+
+def l2_log_loss(input, target):
+ return F.mse_loss(
+ input=input.float(),
+ target=torch.log(target.float() + 1),
+ reduce=False
+ )
+
+
+class DurationDataset(Dataset):
+ def __init__(self, tsv_path, km_path, substring=""):
+ lines = open(tsv_path, "r").readlines()
+ self.root, self.tsv = lines[0], lines[1:]
+ self.km = open(km_path, "r").readlines()
+ logger.info(f"loaded {len(self.km)} files")
+
+ if substring != "":
+ tsv, km = [], []
+ for tsv_line, km_line in zip(self.tsv, self.km):
+ if substring.lower() in tsv_line.lower():
+ tsv.append(tsv_line)
+ km.append(km_line)
+ self.tsv, self.km = tsv, km
+ logger.info(f"after filtering: {len(self.km)} files")
+
+ def __len__(self):
+ return len(self.km)
+
+ def __getitem__(self, i):
+ x = self.km[i]
+ x = x.split(" ")
+ x = list(map(int, x))
+
+ y = []
+ xd = []
+ count = 1
+ for x1, x2 in zip(x[:-1], x[1:]):
+ if x1 == x2:
+ count += 1
+ continue
+ else:
+ y.append(count)
+ xd.append(x1)
+ count = 1
+
+ xd = torch.LongTensor(xd)
+ y = torch.LongTensor(y)
+ return xd, y
+
+
+def train(cfg):
+ device = "cuda:0"
+ model = hydra.utils.instantiate(cfg[cfg.model]).to(device)
+ optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
+ # add 1 extra embedding for padding token, set the padding index to be the last token
+ # (tokens from the clustering start at index 0)
+ collate_fn = Collator(padding_idx=model.padding_token)
+ logger.info(f"data: {cfg.train_tsv}")
+ train_ds = DurationDataset(cfg.train_tsv, cfg.train_km, substring=cfg.substring)
+ valid_ds = DurationDataset(cfg.valid_tsv, cfg.valid_km, substring=cfg.substring)
+ train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
+ valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)
+
+ best_loss = float("inf")
+ for epoch in range(cfg.epochs):
+ train_loss, train_loss_scaled = train_epoch(model, train_dl, l2_log_loss, optimizer, device)
+ valid_loss, valid_loss_scaled, *acc = valid_epoch(model, valid_dl, l2_log_loss, device)
+ acc0, acc1, acc2, acc3 = acc
+ if valid_loss_scaled < best_loss:
+ path = f"{os.getcwd()}/{cfg.substring}.ckpt"
+ save_ckpt(model, path, cfg[cfg.model])
+ best_loss = valid_loss_scaled
+ logger.info(f"saved checkpoint: {path}")
+ logger.info(f"[epoch {epoch}] train loss: {train_loss:.3f}, train scaled: {train_loss_scaled:.3f}")
+ logger.info(f"[epoch {epoch}] valid loss: {valid_loss:.3f}, valid scaled: {valid_loss_scaled:.3f}")
+ logger.info(f"acc: {acc0,acc1,acc2,acc3}")
+
+
+def train_epoch(model, loader, criterion, optimizer, device):
+ model.train()
+ epoch_loss = 0
+ epoch_loss_scaled = 0
+ for x, y, mask, _ in loader:
+ x, y, mask = x.to(device), y.to(device), mask.to(device)
+ yhat = model(x)
+ loss = criterion(yhat, y) * mask
+ loss = torch.mean(loss)
+        optimizer.zero_grad()
+        loss.backward()
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ epoch_loss += loss.item()
+ # get normal scale loss
+ yhat_scaled = torch.exp(yhat) - 1
+ yhat_scaled = torch.round(yhat_scaled)
+ scaled_loss = torch.mean(torch.abs(yhat_scaled - y) * mask)
+ epoch_loss_scaled += scaled_loss.item()
+ return epoch_loss / len(loader), epoch_loss_scaled / len(loader)
+
+
+def valid_epoch(model, loader, criterion, device):
+ model.eval()
+ epoch_loss = 0
+ epoch_loss_scaled = 0
+ acc = Accuracy()
+ for x, y, mask, _ in loader:
+ x, y, mask = x.to(device), y.to(device), mask.to(device)
+ yhat = model(x)
+ loss = criterion(yhat, y) * mask
+ loss = torch.mean(loss)
+ epoch_loss += loss.item()
+ # get normal scale loss
+ yhat_scaled = torch.exp(yhat) - 1
+ yhat_scaled = torch.round(yhat_scaled)
+ scaled_loss = torch.sum(torch.abs(yhat_scaled - y) * mask) / mask.sum()
+ acc.update(yhat_scaled[mask].view(-1).float(), y[mask].view(-1).float())
+ epoch_loss_scaled += scaled_loss.item()
+ logger.info(f"example y: {y[0, :10].tolist()}")
+ logger.info(f"example yhat: {yhat_scaled[0, :10].tolist()}")
+ acc0 = acc.acc(tol=0)
+ acc1 = acc.acc(tol=1)
+ acc2 = acc.acc(tol=2)
+ acc3 = acc.acc(tol=3)
+ logger.info(f"accs: {acc0,acc1,acc2,acc3}")
+ return epoch_loss / len(loader), epoch_loss_scaled / len(loader), acc0, acc1, acc2, acc3
+
+
+@hydra.main(config_path=".", config_name="duration_predictor.yaml")
+def main(cfg):
+ logger.info(f"{cfg}")
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/emotion_conversion/emotion_models/duration_predictor.yaml b/examples/emotion_conversion/emotion_models/duration_predictor.yaml
new file mode 100644
index 0000000000..0e976f4843
--- /dev/null
+++ b/examples/emotion_conversion/emotion_models/duration_predictor.yaml
@@ -0,0 +1,48 @@
+train_tsv: "/denoising/emov/train.tsv"
+train_km: "/denoising/emov/train.km"
+valid_tsv: "/denoising/emov/valid.tsv"
+valid_km: "/denoising/emov/valid.km"
+
+n_tokens: 200
+batch_size: 32
+lr: 0.0001
+epochs: 300
+model: "cnn"
+substring: ""
+
+rnn:
+ _target_: emotion_models.duration_predictor.RnnPredictor
+ n_tokens: ${n_tokens}
+ emb_dim: 128
+ rnn_hidden: 128
+ output_dim: 1
+ dropout: 0
+ n_layers: 1
+
+optimizer:
+ _target_: torch.optim.Adam
+ lr: ${lr}
+ betas: [0.9, 0.98]
+ eps: 0.000000001
+ weight_decay: 0
+
+cnn:
+ _target_: emotion_models.duration_predictor.CnnPredictor
+ n_tokens: ${n_tokens}
+ emb_dim: 128
+ channels: 256
+ kernel: 3
+ output_dim: 1
+ dropout: 0.5
+ n_layers: 1
+
+hydra:
+ run:
+ dir: /checkpoint/felixkreuk/experiments/duration_predictor/${hydra.job.override_dirname}
+ job:
+ config:
+ # configuration for the ${hydra.job.override_dirname} runtime variable
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: ['train_tsv', 'train_km', 'valid_tsv', 'valid_km']
diff --git a/examples/emotion_conversion/emotion_models/pitch_predictor.py b/examples/emotion_conversion/emotion_models/pitch_predictor.py
new file mode 100644
index 0000000000..431446996c
--- /dev/null
+++ b/examples/emotion_conversion/emotion_models/pitch_predictor.py
@@ -0,0 +1,559 @@
+import logging
+import os
+import random
+import sys
+from collections import defaultdict
+
+import hydra
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from scipy.io.wavfile import read
+from scipy.ndimage import gaussian_filter1d
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+dir_path = os.path.dirname(__file__)
+resynth_path = os.path.dirname(dir_path) + "/speech-resynthesis"
+sys.path.append(resynth_path)
+from dataset import parse_speaker, parse_style
+from .utils import F0Stat
+
+MAX_WAV_VALUE = 32768.0
+logger = logging.getLogger(__name__)
+
+
+def quantize_f0(speaker_to_f0, nbins, normalize, log):
+ f0_all = []
+ for speaker, f0 in speaker_to_f0.items():
+ f0 = f0.raw_data
+ if log:
+ f0 = f0.log()
+ mean = speaker_to_f0[speaker].mean_log if log else speaker_to_f0[speaker].mean
+ std = speaker_to_f0[speaker].std_log if log else speaker_to_f0[speaker].std
+ if normalize == "mean":
+ f0 = f0 - mean
+ elif normalize == "meanstd":
+ f0 = (f0 - mean) / std
+ f0_all.extend(f0.tolist())
+
+ hist, bin_x = np.histogram(f0_all, 100000)
+ cum_hist = np.cumsum(hist) / len(f0_all) * 100
+
+ bin_offset = []
+ bin_size = 100 / nbins
+ threshold = bin_size
+ for i in range(nbins - 1):
+ index = (np.abs(cum_hist - threshold)).argmin()
+ bin_offset.append(bin_x[index])
+ threshold += bin_size
+ bins = np.array(bin_offset)
+ bins = torch.FloatTensor(bins)
+
+ return bins
+
+
+def save_ckpt(model, path, model_class, f0_min, f0_max, f0_bins, speaker_stats):
+ ckpt = {
+ "state_dict": model.state_dict(),
+ "padding_token": model.padding_token,
+ "model_class": model_class,
+ "speaker_stats": speaker_stats,
+ "f0_min": f0_min,
+ "f0_max": f0_max,
+ "f0_bins": f0_bins,
+ }
+ torch.save(ckpt, path)
+
+
+def load_ckpt(path):
+ ckpt = torch.load(path)
+ ckpt["model_class"]["_target_"] = "emotion_models.pitch_predictor.CnnPredictor"
+ model = hydra.utils.instantiate(ckpt["model_class"])
+ model.load_state_dict(ckpt["state_dict"])
+ model.setup_f0_stats(
+ ckpt["f0_min"],
+ ckpt["f0_max"],
+ ckpt["f0_bins"],
+ ckpt["speaker_stats"],
+ )
+ return model
+
+
+def freq2bin(f0, f0_min, f0_max, bins):
+ f0 = f0.clone()
+ f0[f0 < f0_min] = f0_min
+ f0[f0 > f0_max] = f0_max
+ f0 = torch.bucketize(f0, bins)
+ return f0
+
+
+def bin2freq(x, f0_min, f0_max, bins, mode):
+ n_bins = len(bins) + 1
+ assert x.shape[-1] == n_bins
+ bins = torch.cat([torch.tensor([f0_min]), bins]).to(x.device)
+ if mode == "mean":
+ f0 = (x * bins).sum(-1, keepdims=True) / x.sum(-1, keepdims=True)
+ elif mode == "argmax":
+ idx = F.one_hot(x.argmax(-1), num_classes=n_bins)
+ f0 = (idx * bins).sum(-1, keepdims=True)
+ else:
+ raise NotImplementedError()
+ return f0[..., 0]
+
+
+def load_wav(full_path):
+ sampling_rate, data = read(full_path)
+ return data, sampling_rate
+
+
+def l1_loss(input, target):
+ return F.l1_loss(input=input.float(), target=target.float(), reduce=False)
+
+
+def l2_loss(input, target):
+ return F.mse_loss(input=input.float(), target=target.float(), reduce=False)
+
+
+class Collator:
+ def __init__(self, padding_idx):
+ self.padding_idx = padding_idx
+
+ def __call__(self, batch):
+ tokens = [item[0] for item in batch]
+ lengths = [len(item) for item in tokens]
+ tokens = torch.nn.utils.rnn.pad_sequence(
+ tokens, batch_first=True, padding_value=self.padding_idx
+ )
+ f0 = [item[1] for item in batch]
+ f0 = torch.nn.utils.rnn.pad_sequence(
+ f0, batch_first=True, padding_value=self.padding_idx
+ )
+ f0_raw = [item[2] for item in batch]
+ f0_raw = torch.nn.utils.rnn.pad_sequence(
+ f0_raw, batch_first=True, padding_value=self.padding_idx
+ )
+ spk = [item[3] for item in batch]
+ spk = torch.LongTensor(spk)
+ gst = [item[4] for item in batch]
+ gst = torch.LongTensor(gst)
+ mask = tokens != self.padding_idx
+ return tokens, f0, f0_raw, spk, gst, mask, lengths
+
+
+class CnnPredictor(nn.Module):
+ def __init__(
+ self,
+ n_tokens,
+ emb_dim,
+ channels,
+ kernel,
+ dropout,
+ n_layers,
+ spk_emb,
+ gst_emb,
+ n_bins,
+ f0_pred,
+ f0_log,
+ f0_norm,
+ ):
+ super(CnnPredictor, self).__init__()
+ self.n_tokens = n_tokens
+ self.emb_dim = emb_dim
+ self.f0_log = f0_log
+ self.f0_pred = f0_pred
+ self.padding_token = n_tokens
+ self.f0_norm = f0_norm
+ # add 1 extra embedding for padding token, set the padding index to be the last token
+ # (tokens from the clustering start at index 0)
+ self.token_emb = nn.Embedding(
+ n_tokens + 1, emb_dim, padding_idx=self.padding_token
+ )
+
+ self.spk_emb = spk_emb
+ self.gst_emb = nn.Embedding(20, gst_emb)
+ self.setup = False
+
+ feats = emb_dim + gst_emb
+ # feats = emb_dim + gst_emb + (256 if spk_emb else 0)
+ layers = [
+ nn.Sequential(
+ Rearrange("b t c -> b c t"),
+ nn.Conv1d(
+ feats, channels, kernel_size=kernel, padding=(kernel - 1) // 2
+ ),
+ Rearrange("b c t -> b t c"),
+ nn.ReLU(),
+ nn.LayerNorm(channels),
+ nn.Dropout(dropout),
+ )
+ ]
+ for _ in range(n_layers - 1):
+ layers += [
+ nn.Sequential(
+ Rearrange("b t c -> b c t"),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=kernel,
+ padding=(kernel - 1) // 2,
+ ),
+ Rearrange("b c t -> b t c"),
+ nn.ReLU(),
+ nn.LayerNorm(channels),
+ nn.Dropout(dropout),
+ )
+ ]
+ self.conv_layer = nn.ModuleList(layers)
+ self.proj = nn.Linear(channels, n_bins)
+
+ def forward(self, x, gst=None):
+ x = self.token_emb(x)
+ feats = [x]
+
+ if gst is not None:
+ gst = self.gst_emb(gst)
+ gst = rearrange(gst, "b c -> b c 1")
+ gst = F.interpolate(gst, x.shape[1])
+ gst = rearrange(gst, "b c t -> b t c")
+ feats.append(gst)
+
+ x = torch.cat(feats, dim=-1)
+
+ for i, conv in enumerate(self.conv_layer):
+ if i != 0:
+ x = conv(x) + x
+ else:
+ x = conv(x)
+
+ x = self.proj(x)
+ x = x.squeeze(-1)
+
+ if self.f0_pred == "mean":
+ x = torch.sigmoid(x)
+ elif self.f0_pred == "argmax":
+ x = torch.softmax(x, dim=-1)
+ else:
+ raise NotImplementedError
+ return x
+
+ def setup_f0_stats(self, f0_min, f0_max, f0_bins, speaker_stats):
+ self.f0_min = f0_min
+ self.f0_max = f0_max
+ self.f0_bins = f0_bins
+ self.speaker_stats = speaker_stats
+ self.setup = True
+
+ def inference(self, x, spk_id=None, gst=None):
+ assert (
+ self.setup == True
+ ), "make sure that `setup_f0_stats` was called before inference!"
+ probs = self(x, gst)
+ f0 = bin2freq(probs, self.f0_min, self.f0_max, self.f0_bins, self.f0_pred)
+ for i in range(f0.shape[0]):
+ mean = (
+ self.speaker_stats[spk_id[i].item()].mean_log
+ if self.f0_log
+ else self.speaker_stats[spk_id[i].item()].mean
+ )
+ std = (
+ self.speaker_stats[spk_id[i].item()].std_log
+ if self.f0_log
+ else self.speaker_stats[spk_id[i].item()].std
+ )
+ if self.f0_norm == "mean":
+ f0[i] = f0[i] + mean
+ if self.f0_norm == "meanstd":
+ f0[i] = (f0[i] * std) + mean
+ if self.f0_log:
+ f0 = f0.exp()
+ return f0
+
+
+class PitchDataset(Dataset):
+ def __init__(
+ self,
+ tsv_path,
+ km_path,
+ substring,
+ spk,
+ spk2id,
+ gst,
+ gst2id,
+ f0_bins,
+ f0_bin_type,
+ f0_smoothing,
+ f0_norm,
+ f0_log,
+ ):
+ lines = open(tsv_path, "r").readlines()
+ self.root, self.tsv = lines[0], lines[1:]
+ self.root = self.root.strip()
+ self.km = open(km_path, "r").readlines()
+ print(f"loaded {len(self.km)} files")
+
+ self.spk = spk
+ self.spk2id = spk2id
+ self.gst = gst
+ self.gst2id = gst2id
+
+ self.f0_bins = f0_bins
+ self.f0_smoothing = f0_smoothing
+ self.f0_norm = f0_norm
+ self.f0_log = f0_log
+
+ if substring != "":
+ tsv, km = [], []
+ for tsv_line, km_line in zip(self.tsv, self.km):
+ if substring.lower() in tsv_line.lower():
+ tsv.append(tsv_line)
+ km.append(km_line)
+ self.tsv, self.km = tsv, km
+ print(f"after filtering: {len(self.km)} files")
+
+ self.speaker_stats = self._compute_f0_stats()
+ self.f0_min, self.f0_max = self._compute_f0_minmax()
+ if f0_bin_type == "adaptive":
+ self.f0_bins = quantize_f0(
+ self.speaker_stats, self.f0_bins, self.f0_norm, self.f0_log
+ )
+ elif f0_bin_type == "uniform":
+ self.f0_bins = torch.linspace(self.f0_min, self.f0_max, self.f0_bins + 1)[
+ 1:-1
+ ]
+ else:
+ raise NotImplementedError
+ print(f"f0 min: {self.f0_min}, f0 max: {self.f0_max}")
+ print(f"bins: {self.f0_bins} (shape: {self.f0_bins.shape})")
+
+ def __len__(self):
+ return len(self.km)
+
+ def _load_f0(self, tsv_line):
+ tsv_line = tsv_line.split("\t")[0]
+ f0 = self.root + "/" + tsv_line.replace(".wav", ".yaapt.f0.npy")
+ f0 = np.load(f0)
+ f0 = torch.FloatTensor(f0)
+ return f0
+
+ def _preprocess_f0(self, f0, spk):
+ mask = f0 != -999999 # process all frames
+ # mask = (f0 != 0) # only process voiced frames
+ mean = (
+ self.speaker_stats[spk].mean_log
+ if self.f0_log
+ else self.speaker_stats[spk].mean
+ )
+ std = (
+ self.speaker_stats[spk].std_log
+ if self.f0_log
+ else self.speaker_stats[spk].std
+ )
+ if self.f0_log:
+ f0[f0 == 0] = 1e-5
+ f0[mask] = f0[mask].log()
+ if self.f0_norm == "mean":
+ f0[mask] = f0[mask] - mean
+ if self.f0_norm == "meanstd":
+ f0[mask] = (f0[mask] - mean) / std
+ return f0
+
+ def _compute_f0_minmax(self):
+ f0_min, f0_max = float("inf"), -float("inf")
+ for tsv_line in tqdm(self.tsv, desc="computing f0 minmax"):
+ spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
+ f0 = self._load_f0(tsv_line)
+ f0 = self._preprocess_f0(f0, spk)
+ f0_min = min(f0_min, f0.min().item())
+ f0_max = max(f0_max, f0.max().item())
+ return f0_min, f0_max
+
+ def _compute_f0_stats(self):
+ from functools import partial
+
+ speaker_stats = defaultdict(partial(F0Stat, True))
+ for tsv_line in tqdm(self.tsv, desc="computing speaker stats"):
+ spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
+ f0 = self._load_f0(tsv_line)
+ mask = f0 != 0
+ f0 = f0[mask] # compute stats only on voiced parts
+ speaker_stats[spk].update(f0)
+ return speaker_stats
+
+ def __getitem__(self, i):
+ x = self.km[i]
+ x = x.split(" ")
+ x = list(map(int, x))
+ x = torch.LongTensor(x)
+
+ gst = parse_style(self.tsv[i], self.gst)
+ gst = self.gst2id[gst]
+ spk = parse_speaker(self.tsv[i], self.spk)
+ spk = self.spk2id[spk]
+
+ f0_raw = self._load_f0(self.tsv[i])
+ f0 = self._preprocess_f0(f0_raw.clone(), spk)
+
+ f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
+ f0_raw = F.interpolate(f0_raw.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
+
+ f0 = freq2bin(f0, f0_min=self.f0_min, f0_max=self.f0_max, bins=self.f0_bins)
+ f0 = F.one_hot(f0.long(), num_classes=len(self.f0_bins) + 1).float()
+ if self.f0_smoothing > 0:
+ f0 = torch.tensor(
+ gaussian_filter1d(f0.float().numpy(), sigma=self.f0_smoothing)
+ )
+ return x, f0, f0_raw, spk, gst
+
+
+def train(cfg):
+ device = "cuda:0"
+ # add 1 extra embedding for padding token, set the padding index to be the last token
+ # (tokens from the clustering start at index 0)
+ padding_token = cfg.n_tokens
+ collate_fn = Collator(padding_idx=padding_token)
+ train_ds = PitchDataset(
+ cfg.train_tsv,
+ cfg.train_km,
+ substring=cfg.substring,
+ spk=cfg.spk,
+ spk2id=cfg.spk2id,
+ gst=cfg.gst,
+ gst2id=cfg.gst2id,
+ f0_bins=cfg.f0_bins,
+ f0_bin_type=cfg.f0_bin_type,
+ f0_smoothing=cfg.f0_smoothing,
+ f0_norm=cfg.f0_norm,
+ f0_log=cfg.f0_log,
+ )
+ valid_ds = PitchDataset(
+ cfg.valid_tsv,
+ cfg.valid_km,
+ substring=cfg.substring,
+ spk=cfg.spk,
+ spk2id=cfg.spk2id,
+ gst=cfg.gst,
+ gst2id=cfg.gst2id,
+ f0_bins=cfg.f0_bins,
+ f0_bin_type=cfg.f0_bin_type,
+ f0_smoothing=cfg.f0_smoothing,
+ f0_norm=cfg.f0_norm,
+ f0_log=cfg.f0_log,
+ )
+ train_dl = DataLoader(
+ train_ds,
+ num_workers=0,
+ batch_size=cfg.batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ )
+ valid_dl = DataLoader(
+ valid_ds, num_workers=0, batch_size=16, shuffle=False, collate_fn=collate_fn
+ )
+
+ f0_min = train_ds.f0_min
+ f0_max = train_ds.f0_max
+ f0_bins = train_ds.f0_bins
+ speaker_stats = train_ds.speaker_stats
+
+ model = hydra.utils.instantiate(cfg["model"]).to(device)
+ model.setup_f0_stats(f0_min, f0_max, f0_bins, speaker_stats)
+
+ optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
+
+ best_loss = float("inf")
+ for epoch in range(cfg.epochs):
+ train_loss, train_l2_loss, train_l2_voiced_loss = run_epoch(
+ model, train_dl, optimizer, device, cfg, mode="train"
+ )
+ valid_loss, valid_l2_loss, valid_l2_voiced_loss = run_epoch(
+ model, valid_dl, None, device, cfg, mode="valid"
+ )
+ print(
+ f"[epoch {epoch}] train loss: {train_loss:.3f}, l2 loss: {train_l2_loss:.3f}, l2 voiced loss: {train_l2_voiced_loss:.3f}"
+ )
+ print(
+ f"[epoch {epoch}] valid loss: {valid_loss:.3f}, l2 loss: {valid_l2_loss:.3f}, l2 voiced loss: {valid_l2_voiced_loss:.3f}"
+ )
+ if valid_l2_voiced_loss < best_loss:
+ path = f"{os.getcwd()}/pitch_predictor.ckpt"
+ save_ckpt(model, path, cfg["model"], f0_min, f0_max, f0_bins, speaker_stats)
+ best_loss = valid_l2_voiced_loss
+ print(f"saved checkpoint: {path}")
+ print(f"[epoch {epoch}] best loss: {best_loss:.3f}")
+
+
+def run_epoch(model, loader, optimizer, device, cfg, mode):
+ if mode == "train":
+ model.train()
+ else:
+ model.eval()
+
+ epoch_loss = 0
+ l1 = 0
+ l1_voiced = 0
+ for x, f0_bin, f0_raw, spk_id, gst, mask, _ in tqdm(loader):
+ x, f0_bin, f0_raw, spk_id, gst, mask = (
+ x.to(device),
+ f0_bin.to(device),
+ f0_raw.to(device),
+ spk_id.to(device),
+ gst.to(device),
+ mask.to(device),
+ )
+ b, t, n_bins = f0_bin.shape
+ yhat = model(x, gst)
+ nonzero_mask = (f0_raw != 0).logical_and(mask)
+ yhat_raw = model.inference(x, spk_id, gst)
+ expanded_mask = mask.unsqueeze(-1).expand(-1, -1, n_bins)
+ if cfg.f0_pred == "mean":
+ loss = F.binary_cross_entropy(
+ yhat[expanded_mask], f0_bin[expanded_mask]
+ ).mean()
+ elif cfg.f0_pred == "argmax":
+ loss = F.cross_entropy(
+ rearrange(yhat, "b t d -> (b t) d"),
+ rearrange(f0_bin.argmax(-1), "b t -> (b t)"),
+ reduce=False,
+ )
+ loss = rearrange(loss, "(b t) -> b t", b=b, t=t)
+ loss = (loss * mask).sum() / mask.float().sum()
+ else:
+ raise NotImplementedError
+ l1 += F.l1_loss(yhat_raw[mask], f0_raw[mask]).item()
+ l1_voiced += F.l1_loss(yhat_raw[nonzero_mask], f0_raw[nonzero_mask]).item()
+ epoch_loss += loss.item()
+
+ if mode == "train":
+            optimizer.zero_grad()
+            loss.backward()
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+
+ print(f"{mode} example y: {f0_bin.argmax(-1)[0, 50:60].tolist()}")
+ print(f"{mode} example yhat: {yhat.argmax(-1)[0, 50:60].tolist()}")
+ print(f"{mode} example y: {f0_raw[0, 50:60].round().tolist()}")
+ print(f"{mode} example yhat: {yhat_raw[0, 50:60].round().tolist()}")
+ return epoch_loss / len(loader), l1 / len(loader), l1_voiced / len(loader)
+
+
+@hydra.main(config_path=dir_path, config_name="pitch_predictor.yaml")
+def main(cfg):
+ np.random.seed(1)
+ random.seed(1)
+ torch.manual_seed(1)
+ from hydra.core.hydra_config import HydraConfig
+
+ overrides = {
+ x.split("=")[0]: x.split("=")[1]
+ for x in HydraConfig.get().overrides.task
+ if "/" not in x
+ }
+ print(f"{cfg}")
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/emotion_conversion/emotion_models/pitch_predictor.yaml b/examples/emotion_conversion/emotion_models/pitch_predictor.yaml
new file mode 100644
index 0000000000..d2dbb862c3
--- /dev/null
+++ b/examples/emotion_conversion/emotion_models/pitch_predictor.yaml
@@ -0,0 +1,64 @@
+train_tsv: "/denoising/emov/train.tsv"
+train_km: "/denoising/emov/train.km"
+valid_tsv: "/denoising/emov/valid.tsv"
+valid_km: "/denoising/emov/valid.km"
+
+n_tokens: 200
+batch_size: 64
+lr: 0.0001
+epochs: 1000
+
+substring: ""
+loss: "l2"
+spk: "parent_parent_name"
+gst: "emotion"
+
+f0_bins: 50
+f0_pred: "mean" # [argmax, mean]
+f0_smoothing: 0.1
+f0_norm: "mean"
+f0_log: false
+f0_bin_type: "adaptive" # [uniform, adaptive]
+
+spk2id:
+ bea: 0
+ jenie: 1
+ josh: 2
+ sam: 3
+
+gst2id:
+ amused: 0
+ angry: 1
+ disgusted: 2
+ neutral: 3
+ sleepy: 4
+
+optimizer:
+ _target_: torch.optim.Adam
+ lr: ${lr}
+
+model:
+ _target_: emotion_models.pitch_predictor.CnnPredictor
+ n_tokens: ${n_tokens}
+ emb_dim: 256
+ channels: 256
+ kernel: 5
+ dropout: 0.1
+ n_layers: 6
+ spk_emb: true
+ gst_emb: 8
+ n_bins: ${f0_bins}
+ f0_pred: ${f0_pred}
+ f0_log: ${f0_log}
+ f0_norm: ${f0_norm}
+
+hydra:
+ run:
+ dir: /checkpoint/felixkreuk/experiments/pitch_predictor/${hydra.job.override_dirname}
+ job:
+ config:
+ # configuration for the ${hydra.job.override_dirname} runtime variable
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: ['train_tsv', 'train_km', 'valid_tsv', 'valid_km']
diff --git a/examples/emotion_conversion/emotion_models/utils.py b/examples/emotion_conversion/emotion_models/utils.py
new file mode 100644
index 0000000000..4199c310f8
--- /dev/null
+++ b/examples/emotion_conversion/emotion_models/utils.py
@@ -0,0 +1,78 @@
+import torch
+
+
+class Stat:
+ def __init__(self, keep_raw=False):
+ self.x = 0.0
+ self.x2 = 0.0
+ self.z = 0.0 # z = logx
+ self.z2 = 0.0
+ self.n = 0.0
+ self.u = 0.0
+ self.keep_raw = keep_raw
+ self.raw = []
+
+ def update(self, new_x):
+ new_z = new_x.log()
+
+ self.x += new_x.sum()
+ self.x2 += (new_x**2).sum()
+ self.z += new_z.sum()
+ self.z2 += (new_z**2).sum()
+ self.n += len(new_x)
+ self.u += 1
+
+ if self.keep_raw:
+ self.raw.append(new_x)
+
+ @property
+ def mean(self):
+ return self.x / self.n
+
+ @property
+ def std(self):
+ return (self.x2 / self.n - self.mean**2) ** 0.5
+
+ @property
+ def mean_log(self):
+ return self.z / self.n
+
+ @property
+ def std_log(self):
+ return (self.z2 / self.n - self.mean_log**2) ** 0.5
+
+ @property
+ def n_frms(self):
+ return self.n
+
+ @property
+ def n_utts(self):
+ return self.u
+
+ @property
+ def raw_data(self):
+ assert self.keep_raw, "does not support storing raw data!"
+ return torch.cat(self.raw)
+
+
+class F0Stat(Stat):
+ def update(self, new_x):
+ # assume unvoiced frames are 0 and consider only voiced frames
+ if new_x is not None:
+ super().update(new_x[new_x != 0])
+
+
+class Accuracy:
+ def __init__(self):
+ self.y, self.yhat = [], []
+
+ def update(self, yhat, y):
+ self.yhat.append(yhat)
+ self.y.append(y)
+
+ def acc(self, tol):
+ yhat = torch.cat(self.yhat)
+ y = torch.cat(self.y)
+ acc = torch.abs(yhat - y) <= tol
+ acc = acc.float().mean().item()
+ return acc
diff --git a/examples/emotion_conversion/fairseq_models/__init__.py b/examples/emotion_conversion/fairseq_models/__init__.py
new file mode 100644
index 0000000000..441bc03db4
--- /dev/null
+++ b/examples/emotion_conversion/fairseq_models/__init__.py
@@ -0,0 +1,226 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq import utils
+from fairseq.models import (
+ FairseqMultiModel,
+ register_model,
+ register_model_architecture,
+)
+from fairseq.models.transformer import (
+ Embedding,
+ base_architecture,
+)
+from fairseq.models.multilingual_transformer import (
+ MultilingualTransformerModel,
+ base_multilingual_architecture,
+)
+from fairseq.utils import safe_hasattr
+from collections import OrderedDict
+
+
+@register_model("multilingual_transformer_from_mbart")
+class MultilingualTransformerModelFromMbart(MultilingualTransformerModel):
+ @classmethod
+ def build_model(cls, args, task):
+ """Build a new model instance."""
+ from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
+
+ assert isinstance(task, MultilingualTranslationTask)
+
+ # make sure all arguments are present in older models
+ base_multilingual_architecture(args)
+
+ if not safe_hasattr(args, "max_source_positions"):
+ args.max_source_positions = 1024
+ if not safe_hasattr(args, "max_target_positions"):
+ args.max_target_positions = 1024
+
+ src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
+ tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]
+
+ if args.share_encoders:
+ args.share_encoder_embeddings = True
+ if args.share_decoders:
+ args.share_decoder_embeddings = True
+
+ def build_embedding(dictionary, embed_dim, path=None):
+ num_embeddings = len(dictionary)
+ padding_idx = dictionary.pad()
+ emb = Embedding(num_embeddings, embed_dim, padding_idx)
+ # if provided, load from preloaded dictionaries
+ if path:
+ embed_dict = utils.parse_embedding(path)
+ utils.load_embedding(embed_dict, dictionary, emb)
+ return emb
+
+ # build shared embeddings (if applicable)
+ shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
+ if args.share_all_embeddings:
+ if args.encoder_embed_dim != args.decoder_embed_dim:
+ raise ValueError(
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
+ )
+ if args.decoder_embed_path and (
+ args.decoder_embed_path != args.encoder_embed_path
+ ):
+ raise ValueError(
+ "--share-all-embeddings not compatible with --decoder-embed-path"
+ )
+ shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+ dicts=task.dicts,
+ langs=task.langs,
+ embed_dim=args.encoder_embed_dim,
+ build_embedding=build_embedding,
+ pretrained_embed_path=args.encoder_embed_path,
+ )
+ shared_decoder_embed_tokens = shared_encoder_embed_tokens
+ args.share_decoder_input_output_embed = True
+ else:
+ if args.share_encoder_embeddings:
+ shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+ dicts=task.dicts,
+ langs=src_langs,
+ embed_dim=args.encoder_embed_dim,
+ build_embedding=build_embedding,
+ pretrained_embed_path=args.encoder_embed_path,
+ )
+ if args.share_decoder_embeddings:
+ shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+ dicts=task.dicts,
+ langs=tgt_langs,
+ embed_dim=args.decoder_embed_dim,
+ build_embedding=build_embedding,
+ pretrained_embed_path=args.decoder_embed_path,
+ )
+
+ # encoders/decoders for each language
+ lang_encoders, lang_decoders = {}, {}
+
+ def get_encoder(lang):
+ if lang not in lang_encoders:
+ if shared_encoder_embed_tokens is not None:
+ encoder_embed_tokens = shared_encoder_embed_tokens
+ else:
+ encoder_embed_tokens = build_embedding(
+ task.dicts[lang],
+ args.encoder_embed_dim,
+ args.encoder_embed_path,
+ )
+ lang_encoders[lang] = MultilingualTransformerModel._get_module_class(
+ True, args, task.dicts[lang], encoder_embed_tokens, src_langs
+ )
+ return lang_encoders[lang]
+
+ def get_decoder(lang):
+ if lang not in lang_decoders:
+ if shared_decoder_embed_tokens is not None:
+ decoder_embed_tokens = shared_decoder_embed_tokens
+ else:
+ decoder_embed_tokens = build_embedding(
+ task.dicts[lang],
+ args.decoder_embed_dim,
+ args.decoder_embed_path,
+ )
+ lang_decoders[lang] = MultilingualTransformerModel._get_module_class(
+ False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
+ )
+ return lang_decoders[lang]
+
+ # shared encoders/decoders (if applicable)
+ shared_encoder, shared_decoder = None, None
+ if args.share_encoders:
+ shared_encoder = get_encoder(src_langs[0])
+ if args.share_decoders:
+ shared_decoder = get_decoder(tgt_langs[0])
+
+ encoders, decoders = OrderedDict(), OrderedDict()
+ for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
+ encoders[lang_pair] = (
+ shared_encoder if shared_encoder is not None else get_encoder(src)
+ )
+ decoders[lang_pair] = (
+ shared_decoder if shared_decoder is not None else get_decoder(tgt)
+ )
+
+ return MultilingualTransformerModelFromMbart(encoders, decoders)
+
+ def load_state_dict(self, state_dict, strict=True, model_cfg=None):
+ state_dict_subset = state_dict.copy()
+ lang_pairs = set([x.split(".")[1] for x in state_dict.keys()])
+ finetune_mode = not any("neutral" in lp for lp in lang_pairs)
+
+ if finetune_mode:
+ # load a pre-trained mBART/BART model
+ # we need this code because mBART/BART are not of type FairseqMultiModel but FairseqModel
+ # so we hackishly load the weights by replicating them for all lang pairs
+ print("loading pre-trained BART")
+ self_state_dict = self.state_dict()
+ for k, v in state_dict.items():
+ for lang_pair in self.models:
+ new_key = k if "models." in k else f"models.{lang_pair}.{k}"
+ # print(new_key)
+ if self_state_dict[new_key].shape == v.shape:
+ state_dict_subset[new_key] = v
+ elif any(
+ w in k
+ for w in [
+ "encoder.embed_tokens.weight",
+ "decoder.embed_tokens.weight",
+ "decoder.output_projection.weight",
+ ]
+ ):
+ # why vocab_size - 5? because there are `vocab_size` tokens from the language
+ # and 5 additional tokens in the denoising task: eos,bos,pad,unk,mask.
+ # but in the translation task there are only `vocab_size` + 4 (no mask).
+ print(
+ f"{k}: {self_state_dict[new_key].shape} != {v.shape}",
+ end="",
+ flush=True,
+ )
+ vocab_size = v.shape[0] - 5
+ state_dict_subset[new_key] = self_state_dict[new_key]
+ state_dict_subset[new_key] = v[: vocab_size + 4]
+ print(f" => fixed by using first {vocab_size + 4} dims")
+ else:
+ raise ValueError("unable to load model due to mimatched dims!")
+ del state_dict_subset[k]
+ else:
+ print("loading pre-trained emotion translation model")
+ for k, _ in state_dict.items():
+ assert k.startswith("models.")
+ lang_pair = k.split(".")[1]
+ if lang_pair not in self.models:
+ del state_dict_subset[k]
+
+ super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
+
+
+@register_model_architecture("transformer", "transformer_small")
+def transformer_small(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.encoder_layers = getattr(args, "encoder_layers", 3)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.decoder_layers = getattr(args, "decoder_layers", 3)
+ base_architecture(args)
+
+
+@register_model_architecture(
+ "multilingual_transformer_from_mbart", "multilingual_small"
+)
+def multilingual_small(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.encoder_layers = getattr(args, "encoder_layers", 3)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.decoder_layers = getattr(args, "decoder_layers", 3)
+ base_multilingual_architecture(args)
diff --git a/examples/emotion_conversion/preprocess/__init__.py b/examples/emotion_conversion/preprocess/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/emotion_conversion/preprocess/build_hifigan_manifest.py b/examples/emotion_conversion/preprocess/build_hifigan_manifest.py
new file mode 100644
index 0000000000..29c0d79cee
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/build_hifigan_manifest.py
@@ -0,0 +1,38 @@
+import torchaudio
+import argparse
+import json
+
+def main():
+    parser = argparse.ArgumentParser(description="example: python build_hifigan_manifest.py --tsv /checkpoint/felixkreuk/datasets/vctk/splits/vctk_16khz/train.tsv --km /checkpoint/felixkreuk/experiments/hubert/hubert_feats/vctk_16khz_km_100/train.km --km_type hubert_100km > ~/tmp/tmp_mani.txt")
+ parser.add_argument("--tsv", required=True, help="path to fairseq tsv file")
+ parser.add_argument("--km", required=True, help="path to a km file generated by HuBERT clustering")
+ parser.add_argument("--km_type", required=True, help="name of the codes in the output json (for example: 'cpc_100km')")
+ args = parser.parse_args()
+
+ km_lines = open(args.km, "r").readlines()
+ tsv_lines = open(args.tsv, "r").readlines()
+ assert len(km_lines) == len(tsv_lines) - 1, "tsv and km files are not of the same length!"
+
+ wav_root = tsv_lines[0].strip()
+ tsv_lines = tsv_lines[1:]
+
+ for tsv_line, km_line in zip(tsv_lines, km_lines):
+ tsv_line, km_line = tsv_line.strip(), km_line.strip()
+ wav_basename, wav_num_frames = tsv_line.split("\t")
+ wav_path = wav_root + "/" + wav_basename
+ wav_info = torchaudio.info(wav_path)
+ assert int(wav_num_frames) == wav_info.num_frames, "tsv duration and actual duration don't match!"
+ wav_duration = wav_info.num_frames / wav_info.sample_rate
+ manifest_line = {"audio": wav_path, "duration": wav_duration, args.km_type: km_line}
+ print(json.dumps(manifest_line))
+
+if __name__ == "__main__":
+ """
+ usage:
+    python build_hifigan_manifest.py \
+ --tsv /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/valid.tsv \
+ --km /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/valid.km \
+ --km_type hubert \
+ > /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/hifigan_valid_manifest.txt
+ """
+ main()
diff --git a/examples/emotion_conversion/preprocess/build_translation_manifests.py b/examples/emotion_conversion/preprocess/build_translation_manifests.py
new file mode 100644
index 0000000000..d38454a713
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/build_translation_manifests.py
@@ -0,0 +1,258 @@
+from glob import glob
+import argparse
+from collections import defaultdict, Counter
+from itertools import combinations, product, groupby
+from pathlib import Path
+import os
+from sklearn.utils import shuffle
+import numpy as np
+import random
+from shutil import copy
+from subprocess import check_call
+
+np.random.seed(42)
+random.seed(42)
+
+
+def get_fname(s):
+ return s.split("\t")[0]
+
+def get_emotion(s):
+ return get_fname(s).split("_")[0].split("/")[1].lower()
+
+def get_utt_id(s):
+ return get_fname(s).split(".")[0].split("_")[-1]
+
+def dedup(seq):
+ """ >> remove_repetitions("1 2 2 3 100 2 2 1")
+ '1 2 3 100 2 1' """
+ seq = seq.strip().split(" ")
+ result = seq[:1]
+ reps = []
+ rep_counter = 1
+ for k in seq[1:]:
+ if k != result[-1]:
+ result += [k]
+ reps += [rep_counter]
+ rep_counter = 1
+ else:
+ rep_counter += 1
+ reps += [rep_counter]
+ assert len(reps) == len(result) and sum(reps) == len(seq)
+ return " ".join(result) + "\n" #, reps
+
+def remove_under_k(seq, k):
+ """ remove tokens that repeat less then k times in a row
+ >> remove_under_k("a a a a b c c c", 1) ==> a a a a c c c """
+ seq = seq.strip().split(" ")
+ result = []
+
+    freqs = [(c, len(list(g))) for c, g in groupby(seq)]  # (token, run length) pairs
+ for c, f in freqs:
+ if f > k:
+ result += [c for _ in range(f)]
+ return " ".join(result) + "\n" #, reps
+
+
+def call(cmd):
+ print(cmd)
+ check_call(cmd, shell=True)
+
+
+def denoising_preprocess(path, lang, dict):
+ bin = 'fairseq-preprocess'
+ cmd = [
+ bin,
+ f'--trainpref {path}/train.{lang} --validpref {path}/valid.{lang} --testpref {path}/test.{lang}',
+ f'--destdir {path}/tokenized/{lang}',
+ '--only-source',
+ '--task multilingual_denoising',
+ '--workers 40',
+ ]
+ if dict != "":
+ cmd += [f'--srcdict {dict}']
+ cmd = " ".join(cmd)
+ call(cmd)
+
+
+def translation_preprocess(path, src_lang, trg_lang, dict, only_train=False):
+ bin = 'fairseq-preprocess'
+ cmd = [
+ bin,
+ f'--source-lang {src_lang} --target-lang {trg_lang}',
+ f'--trainpref {path}/train',
+ f'--destdir {path}/tokenized',
+ '--workers 40',
+ ]
+ if not only_train:
+ cmd += [f'--validpref {path}/valid --testpref {path}/test']
+ if dict != "":
+ cmd += [
+ f'--srcdict {dict}',
+ f'--tgtdict {dict}',
+ ]
+ cmd = " ".join(cmd)
+ call(cmd)
+
+
+def load_tsv_km(tsv_path, km_path):
+ assert tsv_path.exists() and km_path.exists()
+ tsv_lines = open(tsv_path, "r").readlines()
+ root, tsv_lines = tsv_lines[0], tsv_lines[1:]
+ km_lines = open(km_path, "r").readlines()
+ assert len(tsv_lines) == len(km_lines), ".tsv and .km should be the same length!"
+ return root, tsv_lines, km_lines
+
+
+def main():
+ desc = """
+    this script takes as input .tsv and .km files for the EMOV dataset, and a pair of emotions.
+    it generates parallel .tsv and .km files for these emotions. for example:
+    ❯ python build_translation_manifests.py \
+ /checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/train.tsv \
+ /checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/emov_16khz_km_100/train.km \
+ ~/tmp/emov_pairs \
+ --src-emotion amused --trg-emotion neutral \
+ --dedup --shuffle --cross-speaker --dry-run
+ """
+ parser = argparse.ArgumentParser(description=desc)
+ parser.add_argument("data", type=Path, help="path to a dir containing .tsv and .km files containing emov dataset")
+ parser.add_argument("output_path", type=Path, help="output directory with the manifests will be created")
+ parser.add_argument("-cs", "--cross-speaker", action='store_true', help="if set then translation will occur also between speakers, meaning the same sentence can be translated between different speakers (default: false)")
+ parser.add_argument("-dd", "--dedup", action='store_true', help="remove repeated tokens (example: 'aaabc=>abc')")
+ parser.add_argument("-sh", "--shuffle", action='store_true', help="shuffle the data")
+ parser.add_argument("-ae", "--autoencode", action='store_true', help="include training pairs from the same emotion (this includes examples of the same sentence uttered by different people and examples where the src and trg are the exact same seq)")
+ parser.add_argument("-dr", "--dry-run", action='store_true', help="don't write anything to disk")
+ parser.add_argument("-zs", "--zero-shot", action='store_true', help="if true, the denoising task will train on the same splits as the translation task (split by utterance id). if false, the denoising task will train on randomly sampled splits (not split by utterance id)")
+ parser.add_argument("--km-ext", default="km", help="")
+ parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt", help="")
+ args = parser.parse_args()
+ SPEAKERS = ["bea", "jenie", "josh", "sam", "SAME"]
+ EMOTIONS = ['neutral', 'amused', 'angry', 'disgusted', 'sleepy']
+
+ suffix = ""
+ if args.cross_speaker: suffix += "_cross-speaker"
+ if args.dedup: suffix += "_dedup"
+ translation_suffix = ""
+ if args.autoencode: translation_suffix += "_autoencode"
+ denoising_suffix = ""
+ denoising_suffix += "_zeroshot" if args.zero_shot else "_nonzeroshot"
+
+ translation_dir = Path(args.output_path) / ("emov_multilingual_translation" + suffix + translation_suffix)
+ os.makedirs(translation_dir, exist_ok=True)
+ denoising_dir = Path(args.output_path) / ("emov_multilingual_denoising" + suffix + denoising_suffix)
+ os.makedirs(denoising_dir, exist_ok=True)
+
+ denoising_data = [p.name for p in (args.data / "denoising").glob("*") if "emov" not in p.name]
+
+ for split in ["train", "valid", "test"]:
+ root, tsv_lines, km_lines = load_tsv_km(
+ tsv_path = args.data / "denoising" / "emov" / f"{split}.tsv",
+ km_path = args.data / "denoising" / "emov" / f"{split}.{args.km_ext}"
+ )
+
+ # generate data for the multilingual denoising task
+ for EMOTION in EMOTIONS:
+ print("---")
+ print(split)
+ print(f"denoising: {EMOTION}")
+ emotion_tsv, emotion_km = [], []
+ for tsv_line, km_line in zip(tsv_lines, km_lines):
+ if EMOTION.lower() in tsv_line.lower():
+ km_line = km_line if not args.dedup else dedup(km_line)
+ emotion_tsv.append(tsv_line)
+ emotion_km.append(km_line)
+ print(f"{len(emotion_km)} samples")
+ open(denoising_dir / f"files.{split}.{EMOTION}", "w").writelines([root] + emotion_tsv)
+ open(denoising_dir / f"{split}.{EMOTION}", "w").writelines(emotion_km)
+
+ for data in denoising_data:
+ with open(args.data / "denoising" / data / f"{split}.{args.km_ext}", "r") as f1:
+ with open(denoising_dir / f"{split}.{data}", "w") as f2:
+ f2.writelines([l if not args.dedup else dedup(l) for l in f1.readlines()])
+
+ # start of translation preprocessing
+ root, tsv_lines, km_lines = load_tsv_km(
+ tsv_path = args.data / "translation" / f"{split}.tsv",
+ km_path = args.data / "translation" / f"{split}.{args.km_ext}"
+ )
+
+ # generate data for the multilingual translation task
+ for SRC_EMOTION in EMOTIONS:
+ TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION])
+ for TRG_EMOTION in TRG_EMOTIONS:
+                # when translating back to the same emotion - we don't want these emotion
+                # pairs to be part of the validation/test sets (because it's not really emotion conversion)
+ # if SRC_EMOTION == TRG_EMOTION and split in ["valid", "test"]: continue
+ print("---")
+ print(split)
+ print(f"src emotions: {SRC_EMOTION}\ntrg emotions: {TRG_EMOTION}")
+
+ # create a dictionary with the following structure:
+ # output[SPEAKER][UTT_ID] = list with indexes of line from the tsv file
+                # that match the speaker and utterance id. for example:
+ # output = {'sam': {'0493': [875, 1608, 1822], ...}, ...}
+ # meaning, for speaker 'sam', utterance id '0493', the indexes in tsv_lines
+ # are 875, 1608, 1822
+ spkr2utts = defaultdict(lambda: defaultdict(list))
+ for i, tsv_line in enumerate(tsv_lines):
+ speaker = tsv_line.split("/")[0]
+ if args.cross_speaker: speaker = "SAME"
+ assert speaker in SPEAKERS, "unknown speaker! make sure the .tsv contains EMOV data"
+ utt_id = get_utt_id(tsv_line)
+ spkr2utts[speaker][utt_id].append(i)
+
+ # create a tsv and km files with all the combinations for translation
+ src_tsv, trg_tsv, src_km, trg_km = [], [], [], []
+ for speaker, utt_ids in spkr2utts.items():
+ for utt_id, indices in utt_ids.items():
+ # generate all pairs
+ pairs = [(x,y) for x in indices for y in indices]
+ # self-translation
+ if SRC_EMOTION == TRG_EMOTION:
+ pairs = [(x,y) for (x,y) in pairs if x == y]
+ # filter according to src and trg emotions
+ pairs = [(x,y) for (x,y) in pairs
+ if get_emotion(tsv_lines[x]) == SRC_EMOTION and get_emotion(tsv_lines[y]) == TRG_EMOTION]
+
+ for idx1, idx2 in pairs:
+ assert get_utt_id(tsv_lines[idx1]) == get_utt_id(tsv_lines[idx2])
+ src_tsv.append(tsv_lines[idx1])
+ trg_tsv.append(tsv_lines[idx2])
+ km_line_idx1 = km_lines[idx1]
+ km_line_idx2 = km_lines[idx2]
+ km_line_idx1 = km_line_idx1 if not args.dedup else dedup(km_line_idx1)
+ km_line_idx2 = km_line_idx2 if not args.dedup else dedup(km_line_idx2)
+ src_km.append(km_line_idx1)
+ trg_km.append(km_line_idx2)
+ assert len(src_tsv) == len(trg_tsv) == len(src_km) == len(trg_km)
+ print(f"{len(src_tsv)} pairs")
+
+ if len(src_tsv) == 0:
+ raise Exception("ERROR: generated 0 pairs!")
+
+ if args.dry_run: continue
+
+ # create files
+ os.makedirs(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", exist_ok=True)
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{SRC_EMOTION}", "w").writelines([root] + src_tsv)
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{TRG_EMOTION}", "w").writelines([root] + trg_tsv)
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{SRC_EMOTION}", "w").writelines(src_km)
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{TRG_EMOTION}", "w").writelines(trg_km)
+
+
+ # fairseq-preprocess the denoising data
+ for EMOTION in EMOTIONS + denoising_data:
+ denoising_preprocess(denoising_dir, EMOTION, args.dict)
+ os.system(f"cp {args.dict} {denoising_dir}/tokenized/dict.txt")
+
+ # fairseq-preprocess the translation data
+ os.makedirs(translation_dir / "tokenized", exist_ok=True)
+ for SRC_EMOTION in EMOTIONS:
+ TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION])
+ for TRG_EMOTION in TRG_EMOTIONS:
+ translation_preprocess(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", SRC_EMOTION, TRG_EMOTION, args.dict)#, only_train=SRC_EMOTION==TRG_EMOTION)
+ os.system(f"cp -rf {translation_dir}/**/tokenized/* {translation_dir}/tokenized")
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/emotion_conversion/preprocess/create_core_manifest.py b/examples/emotion_conversion/preprocess/create_core_manifest.py
new file mode 100644
index 0000000000..b55740e00b
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/create_core_manifest.py
@@ -0,0 +1,91 @@
+from pathlib import Path
+import os
+import sys
+import subprocess
+import argparse
+from datetime import datetime
+import logging
+
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s [%(levelname)s] %(message)s',
+ handlers=[logging.FileHandler('debug.log'), logging.StreamHandler()]
+)
+logger = logging.getLogger(__name__)
+
+
+def verify_dict_size(km, dict):
+ logger.info(f"verifying: {km}")
+ dict_size = len(open(dict, "r").readlines())
+ km_vocab = set(open(km, "r").read().replace("\n", " ").split(" "))
+ if "" in km_vocab: km_vocab.remove("")
+ km_vocab_size = len(km_vocab)
+ return dict_size == km_vocab_size
+
+
+def verify_files_exist(l):
+ for f in l:
+ if not f.exists():
+ logging.error(f"{f} doesn't exist!")
+ return False
+ return True
+
+
+def run_cmd(cmd, print_output=True):
+ try:
+ out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True, shell=True)
+ if print_output:
+ logger.info(f"command output:\n{out}")
+ return out
+ except subprocess.CalledProcessError as grepexc:
+ logger.info(f"error executing command!:\n{cmd}")
+ logger.info(grepexc.output)
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--tsv", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/data.tsv", type=Path)
+ parser.add_argument("--emov-km", required=True, type=Path)
+ parser.add_argument("--km", nargs='+', required=True, type=Path)
+ parser.add_argument("--seed", type=int, default=1)
+ parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt")
+ parser.add_argument("--manifests-dir", type=Path, default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz")
+ args = parser.parse_args()
+
+ manifests_dir = args.manifests_dir
+ date = datetime.now().strftime('%d%m%y')
+ outdir = manifests_dir / f"{date}"
+
+ # verify input and create folders
+ all_kms = args.km + [args.emov_km]
+ assert verify_files_exist(all_kms), "make sure the km dir contains: train-clean-all.km, blizzard2013.km, data.km"
+ for codes in all_kms:
+ assert verify_dict_size(codes, args.dict), "dict argument doesn't match the vocabulary of the km file!"
+ assert not outdir.exists(), "data dir already exists!"
+ outdir.mkdir(parents=True, exist_ok=True)
+
+ logger.info("generating denoising split (emov)")
+ run_cmd(f"python preprocess/split_km_tsv.py {args.tsv} {args.emov_km} --destdir {outdir}/denoising/emov -sh --seed {args.seed}")
+ for codes in args.km:
+ codes_name = os.path.basename(codes)
+ run_cmd(f"python preprocess/split_km.py {codes} --destdir {outdir}/denoising/{codes_name} -sh --seed {args.seed}")
+
+ logger.info("generating translation split")
+ run_cmd(f"python preprocess/split_emov_km_tsv_by_uttid.py {args.tsv} {args.emov_km} --destdir {outdir}/translation --seed {args.seed}")
+
+ emov_code_name = os.path.basename(args.emov_km)
+ logger.info("generating hifigan split")
+ run_cmd(
+ f"mkdir -p {outdir}/hifigan &&"
+ f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/train.tsv --km {outdir}/denoising/emov/train.km > {outdir}/hifigan/train.txt &&"
+ f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/valid.tsv --km {outdir}/denoising/emov/valid.km > {outdir}/hifigan/valid.txt &&"
+ f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/test.tsv --km {outdir}/denoising/emov/test.km > {outdir}/hifigan/test.txt"
+ )
+
+ logger.info("generating fairseq manifests")
+ run_cmd(f"python preprocess/build_translation_manifests.py {outdir} {outdir}/fairseq-data -dd -cs --dict {args.dict}")
+
+ logger.info(f"finished processing data at:\n{outdir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/emotion_conversion/preprocess/extract_f0.py b/examples/emotion_conversion/preprocess/extract_f0.py
new file mode 100644
index 0000000000..4204aa4db1
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/extract_f0.py
@@ -0,0 +1,57 @@
+import argparse
+from tqdm import tqdm
+from multiprocessing import Manager, Pool
+
+from scipy.io.wavfile import read
+from librosa.util import normalize
+import numpy as np
+import amfm_decompy.pYAAPT as pYAAPT
+import amfm_decompy.basic_tools as basic
+
+MAX_WAV_VALUE = 32768.0
+
+parser = argparse.ArgumentParser(description="")
+parser.add_argument("tsv", help="")
+parser.add_argument("--extractor", choices=["crepe", "pyaapt"], default="pyaapt", help="")
+parser.add_argument("--interp", action="store_true", help="")
+parser.add_argument("--n_workers", type=int, default=40, help="")
+args = parser.parse_args()
+
+tsv_lines = open(args.tsv, "r").readlines()
+root, tsv_lines = tsv_lines[0].strip(), tsv_lines[1:]
+
+
+def extract_f0(tsv_line):
+ wav_path, _ = tsv_line.split("\t")
+ wav_path = root.strip() + "/" + wav_path
+ sr, wav = read(wav_path)
+ wav = wav / MAX_WAV_VALUE
+ wav = normalize(wav) * 0.95
+
+ if args.extractor == "pyaapt":
+ frame_length = 20.0
+ pad = int(frame_length / 1000 * sr) // 2
+ wav = np.pad(wav.squeeze(), (pad, pad), "constant", constant_values=0)
+ signal = basic.SignalObj(wav, sr)
+ pitch = pYAAPT.yaapt(
+ signal,
+ **{
+ 'frame_length': frame_length,
+ 'frame_space': 5.0,
+ 'nccf_thresh1': 0.25,
+ 'tda_frame_length': 25.0
+ })
+ pitch = pitch.samp_interp[None, None, :] if args.interp else pitch.samp_values[None, None, :]
+ pitch = pitch[0, 0]
+ f0_path = wav_path.replace(".wav", ".yaapt")
+ f0_path += ".interp.f0" if args.interp else ".f0"
+ np.save(f0_path, pitch)
+
+
+def main():
+ with Pool(args.n_workers) as p:
+ r = list(tqdm(p.imap(extract_f0, tsv_lines), total=len(tsv_lines)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/emotion_conversion/preprocess/process_km.py b/examples/emotion_conversion/preprocess/process_km.py
new file mode 100644
index 0000000000..864a022105
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/process_km.py
@@ -0,0 +1,40 @@
+import sys
+import argparse
+from tqdm import tqdm
+from build_translation_manifests import dedup, remove_under_k
+
+
+if __name__ == "__main__":
+ """
+ this is a standalone script to process a km file
+ specifically, to dedup or remove tokens that repeat less
+ than k times in a row
+ """
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("km", type=str, help="path to km file")
+ parser.add_argument("--dedup", action='store_true')
+ parser.add_argument("--remove-under-k", type=int, default=0)
+ parser.add_argument("--output", default=None)
+ args = parser.parse_args()
+
+ if not args.dedup and args.remove_under_k == 0:
+ print("nothing to do! quitting...")
+ sys.exit(0)
+
+ km = open(args.km, "r").readlines()
+ out = []
+ for line in tqdm(km):
+ if args.remove_under_k > 0:
+ line = remove_under_k(line, args.remove_under_k)
+ if args.dedup:
+ line = dedup(line)
+ out.append(line)
+
+ path = args.km if args.output is None else args.output
+ if args.remove_under_k > 0:
+ path = path.replace(".km", f"-k{args.remove_under_k}.km")
+ if args.dedup:
+ path = path.replace(".km", f"-deduped.km")
+
+ open(path, "w").writelines(out)
+ print(f"written to {path}")
diff --git a/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py b/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py
new file mode 100644
index 0000000000..94221afba7
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py
@@ -0,0 +1,70 @@
+from pathlib import Path
+import os
+import sys
+import argparse
+import random
+import numpy as np
+from tqdm import tqdm
+from sklearn.model_selection import train_test_split
+from build_translation_manifests import get_utt_id
+
+
+def train_val_test_split(tsv_lines, km_lines, valid_percent, test_percent, seed=42):
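+    """Split parallel .tsv/.km lines into train/valid/test so that all lines
+    sharing the same utterance id land in the same split."""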
+ utt_ids = list(sorted(set([get_utt_id(x) for x in tsv_lines])))
+ utt_ids, valid_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=valid_percent, shuffle=True, random_state=seed)
+ train_utt_ids, test_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=test_percent, shuffle=True, random_state=seed)
+
+ train_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in train_utt_ids]
+ valid_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in valid_utt_ids]
+ test_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in test_utt_ids]
+
+ train_tsv, train_km = [tsv_lines[i] for i in train_idx], [km_lines[i] for i in train_idx]
+ valid_tsv, valid_km = [tsv_lines[i] for i in valid_idx], [km_lines[i] for i in valid_idx]
+ test_tsv, test_km = [tsv_lines[i] for i in test_idx], [km_lines[i] for i in test_idx]
+
+ print(f"train {len(train_km)}")
+ print(f"valid {len(valid_km)}")
+ print(f"test {len(test_km)}")
+
+ return train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km
+
+
+if __name__ == "__main__":
+ """
+ this is a standalone script to process a km file
+ specifically, to dedup or remove tokens that repeat less
+ than k times in a row
+ """
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("tsv", type=str, help="path to tsv file")
+ parser.add_argument("km", type=str, help="path to km file")
+ parser.add_argument("--destdir", required=True, type=str)
+ parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
+ parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
+ parser.add_argument("--seed", type=int, default=42, help="")
+ args = parser.parse_args()
+
+ np.random.seed(args.seed)
+ random.seed(args.seed)
+
+ os.makedirs(args.destdir, exist_ok=True)
+ km = open(args.km, "r").readlines()
+ tsv = open(args.tsv, "r").readlines()
+ root, tsv = tsv[0], tsv[1:]
+
+ assert args.tsv.endswith(".tsv") and args.km.endswith(".km")
+ assert len(tsv) == len(km)
+
+ train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km = train_val_test_split(tsv, km, args.valid_percent, args.test_percent, args.seed)
+
+ assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv)
+ assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km)
+
+ dir = Path(args.destdir)
+ open(dir / f"train.tsv", "w").writelines([root] + train_tsv)
+ open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv)
+ open(dir / f"test.tsv", "w").writelines([root] + test_tsv)
+ open(dir / f"train.km", "w").writelines(train_km)
+ open(dir / f"valid.km", "w").writelines(valid_km)
+ open(dir / f"test.km", "w").writelines(test_km)
+ print("done")
diff --git a/examples/emotion_conversion/preprocess/split_km.py b/examples/emotion_conversion/preprocess/split_km.py
new file mode 100644
index 0000000000..d145fc2bde
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/split_km.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+import os
+import argparse
+import random
+import numpy as np
+from sklearn.utils import shuffle
+
+
+if __name__ == "__main__":
+ """
+ this is a standalone script to process a km file
+ specifically, to dedup or remove tokens that repeat less
+ than k times in a row
+ """
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("km", type=str, help="path to km file")
+ parser.add_argument("--destdir", required=True, type=str)
+ parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
+ parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
+ parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file")
+ parser.add_argument("--seed", type=int, default=42, help="")
+ args = parser.parse_args()
+
+ np.random.seed(args.seed)
+ random.seed(args.seed)
+
+ os.makedirs(args.destdir, exist_ok=True)
+ km = open(args.km, "r").readlines()
+
+ if args.shuffle:
+ km = shuffle(km)
+ print(f"shuffled")
+
+ N = len(km)
+ N_tt = int(N * args.test_percent)
+ N_cv = int(N * args.valid_percent)
+ N_tr = N - N_tt - N_cv
+
+ train_km = km[:N_tr]
+ valid_km = km[N_tr:N_tr + N_cv]
+ test_km = km[N_tr + N_cv:]
+
+ dir = Path(args.destdir)
+ open(dir / f"train.km", "w").writelines(train_km)
+ open(dir / f"valid.km", "w").writelines(valid_km)
+ open(dir / f"test.km", "w").writelines(test_km)
+ print(f"train: {len(train_km)}")
+ print(f"valid: {len(valid_km)}")
+ print(f"test: {len(test_km)}")
+ print("done")
diff --git a/examples/emotion_conversion/preprocess/split_km_tsv.py b/examples/emotion_conversion/preprocess/split_km_tsv.py
new file mode 100644
index 0000000000..2113aa718d
--- /dev/null
+++ b/examples/emotion_conversion/preprocess/split_km_tsv.py
@@ -0,0 +1,65 @@
+from pathlib import Path
+import os
+import argparse
+import random
+import numpy as np
+from sklearn.utils import shuffle
+
+
+if __name__ == "__main__":
+ """
+ this is a standalone script to process a km file
+ specifically, to dedup or remove tokens that repeat less
+ than k times in a row
+ """
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("tsv", type=str, help="path to tsv file")
+ parser.add_argument("km", type=str, help="path to km file")
+ parser.add_argument("--destdir", required=True, type=str)
+ parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
+ parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
+ parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file")
+ parser.add_argument("--seed", type=int, default=42, help="")
+ args = parser.parse_args()
+
+ np.random.seed(args.seed)
+ random.seed(args.seed)
+
+ os.makedirs(args.destdir, exist_ok=True)
+ km = open(args.km, "r").readlines()
+ tsv = open(args.tsv, "r").readlines()
+ root, tsv = tsv[0], tsv[1:]
+
+ assert args.tsv.endswith(".tsv") and args.km.endswith(".km")
+ assert len(tsv) == len(km)
+
+ if args.shuffle:
+ tsv, km = shuffle(tsv, km)
+ print(f"shuffled")
+
+ N = len(tsv)
+ N_tt = int(N * args.test_percent)
+ N_cv = int(N * args.valid_percent)
+ N_tr = N - N_tt - N_cv
+
+ train_tsv = tsv[:N_tr]
+ valid_tsv = tsv[N_tr:N_tr + N_cv]
+ test_tsv = tsv[N_tr + N_cv:]
+ train_km = km[:N_tr]
+ valid_km = km[N_tr:N_tr + N_cv]
+ test_km = km[N_tr + N_cv:]
+
+ assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv)
+ assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km)
+
+ dir = Path(args.destdir)
+ open(dir / f"train.tsv", "w").writelines([root] + train_tsv)
+ open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv)
+ open(dir / f"test.tsv", "w").writelines([root] + test_tsv)
+ open(dir / f"train.km", "w").writelines(train_km)
+ open(dir / f"valid.km", "w").writelines(valid_km)
+ open(dir / f"test.km", "w").writelines(test_km)
+ print(f"train: {len(train_km)}")
+ print(f"valid: {len(valid_km)}")
+ print(f"test: {len(test_km)}")
+ print("done")
diff --git a/examples/emotion_conversion/requirements.txt b/examples/emotion_conversion/requirements.txt
new file mode 100644
index 0000000000..fc94c5a547
--- /dev/null
+++ b/examples/emotion_conversion/requirements.txt
@@ -0,0 +1,11 @@
+scipy
+einops
+amfm_decompy
+joblib
+numba
+decorator
+requests
+appdirs
+packaging
+six
+scikit-learn
diff --git a/examples/emotion_conversion/synthesize.py b/examples/emotion_conversion/synthesize.py
new file mode 100644
index 0000000000..327fdaf4ea
--- /dev/null
+++ b/examples/emotion_conversion/synthesize.py
@@ -0,0 +1,322 @@
+import logging
+import argparse
+import random
+import sys
+import os
+import numpy as np
+import torch
+import soundfile as sf
+import shutil
+import librosa
+import json
+from pathlib import Path
+from tqdm import tqdm
+import amfm_decompy.basic_tools as basic
+import amfm_decompy.pYAAPT as pYAAPT
+
+dir_path = os.path.dirname(__file__)
+resynth_path = os.path.dirname(os.path.abspath(__file__)) + "/speech-resynthesis"
+sys.path.append(resynth_path)
+
+from models import CodeGenerator
+from inference import scan_checkpoint, load_checkpoint, generate
+from emotion_models.pitch_predictor import load_ckpt as load_pitch_predictor
+from emotion_models.duration_predictor import load_ckpt as load_duration_predictor
+from dataset import load_audio, MAX_WAV_VALUE, parse_style, parse_speaker, EMOV_SPK2ID, EMOV_STYLE2ID
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s [%(levelname)s] %(message)s',
+ handlers=[logging.FileHandler('debug.log'), logging.StreamHandler()]
+)
+logger = logging.getLogger(__name__)
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def parse_generation_file(fname):
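+    """Parse fairseq-generate output into {sample_id: {'S': source, 'T': target, 'H': hypothesis}};
+    if a sample has no hypothesis line, its source sequence is used as a fallback."""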
+ lines = open(fname).read()
+ lines = lines.split('\n')
+
+ results = {}
+ for l in lines:
+ if len(l) == 0:
+ continue
+
+ if l[0] == 'H':
+ parts = l[2:].split('\t')
+ if len(parts) == 2:
+ sid, utt = parts
+ else:
+ sid, _, utt = parts
+ sid = int(sid)
+ utt = [int(x) for x in utt.split()]
+ if sid in results:
+ results[sid]['H'] = utt
+ else:
+ results[sid] = {'H': utt}
+ elif l[0] == 'S':
+ sid, utt = l[2:].split('\t')
+ sid = int(sid)
+ utt = [x for x in utt.split()]
+ if sid in results:
+ results[sid]['S'] = utt
+ else:
+ results[sid] = {'S': utt}
+ elif l[0] == 'T':
+ sid, utt = l[2:].split('\t')
+ sid = int(sid)
+ utt = [int(x) for x in utt.split()]
+ if sid in results:
+ results[sid]['T'] = utt
+ else:
+ results[sid] = {'T': utt}
+
+ for d, result in results.items():
+ if 'H' not in result:
+ result['H'] = result['S']
+
+ return results
+
+
+def get_code_to_fname(manifest, tokens):
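+    """Map each unit-code string back to its audio filename. If `tokens` is None the
+    manifest lines are "<fname> <codes>" pairs; otherwise codes are read from `tokens`
+    and paired line-by-line with the .tsv manifest entries."""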
+ if tokens is None:
+ code_to_fname = {}
+ with open(manifest) as f:
+ for line in f:
+ line = line.strip()
+ fname, code = line.split()
+ code = code.replace(',', ' ')
+ code_to_fname[code] = fname
+
+ return code_to_fname
+
+ with open(manifest) as f:
+ fnames = [l.strip() for l in f.readlines()]
+ root = Path(fnames[0])
+ fnames = fnames[1:]
+ if '\t' in fnames[0]:
+ fnames = [x.split()[0] for x in fnames]
+
+ with open(tokens) as f:
+ codes = [l.strip() for l in f.readlines()]
+
+ code_to_fname = {}
+ for fname, code in zip(fnames, codes):
+ code = code.replace(',', ' ')
+ code_to_fname[code] = str(root / fname)
+
+ return root, code_to_fname
+
+
+def code_to_str(s):
+ k = ' '.join([str(x) for x in s])
+ return k
+
+
+def get_praat_f0(audio, rate=16000, interp=False):
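+    """Extract per-utterance F0 with YAAPT (20 ms frames, 5 ms hop); when `interp`
+    is set, unvoiced regions are filled with interpolated values."""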
+ frame_length = 20.0
+ to_pad = int(frame_length / 1000 * rate) // 2
+
+ f0s = []
+ for y in audio.astype(np.float64):
+ y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0)
+ signal = basic.SignalObj(y_pad, rate)
+ pitch = pYAAPT.yaapt(signal, **{'frame_length': frame_length, 'frame_space': 5.0, 'nccf_thresh1': 0.25,
+ 'tda_frame_length': 25.0})
+ if interp:
+ f0s += [pitch.samp_interp[None, None, :]]
+ else:
+ f0s += [pitch.samp_values[None, None, :]]
+
+ f0 = np.vstack(f0s)
+ return f0
+
+
+def generate_from_code(generator, h, code, spkr=None, f0=None, gst=None, device="cpu"):
+ batch = {
+ 'code': torch.LongTensor(code).to(device).view(1, -1),
+ }
+ if spkr is not None:
+ batch['spkr'] = spkr.to(device).unsqueeze(0)
+ if f0 is not None:
+ batch['f0'] = f0.to(device)
+ if gst is not None:
+ batch['style'] = gst.to(device)
+
+ with torch.no_grad():
+ audio, rtf = generate(h, generator, batch)
+ audio = librosa.util.normalize(audio / 2 ** 15)
+
+ return audio
+
+
+@torch.no_grad()
+def synth(argv, interactive=False):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--result-path', type=Path, help='Translation Model Output', required=True)
+ parser.add_argument('--data', type=Path, help='a directory with the files: src.tsv, src.km, trg.tsv, trg.km, orig.tsv, orig.km')
+ parser.add_argument("--orig-tsv", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/data.tsv")
+ parser.add_argument("--orig-km", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/core_manifests/emov_16khz_km_100/data.km")
+
+ parser.add_argument('--checkpoint-file', type=Path, help='Generator Checkpoint', required=True)
+ parser.add_argument('--dur-model', type=Path, help='a token duration prediction model (if tokens were deduped)')
+ parser.add_argument('--f0-model', type=Path, help='a f0 prediction model')
+
+ parser.add_argument('-s', '--src-emotion', default=None)
+ parser.add_argument('-t', '--trg-emotion', default=None)
+ parser.add_argument('-N', type=int, default=10)
+ parser.add_argument('--split', default="test")
+
+ parser.add_argument('--outdir', type=Path, default=Path('results'))
+ parser.add_argument('--orig-filename', action='store_true')
+
+ parser.add_argument('--device', type=int, default=0)
+ a = parser.parse_args(argv)
+
+ seed = 52
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ if os.path.isdir(a.checkpoint_file):
+ config_file = os.path.join(a.checkpoint_file, 'config.json')
+ else:
+ config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
+ with open(config_file) as f:
+ data = f.read()
+ json_config = json.loads(data)
+ h = AttrDict(json_config)
+
+ generator = CodeGenerator(h).to(a.device)
+ if os.path.isdir(a.checkpoint_file):
+ cp_g = scan_checkpoint(a.checkpoint_file, 'g_')
+ else:
+ cp_g = a.checkpoint_file
+ state_dict_g = load_checkpoint(cp_g)
+ generator.load_state_dict(state_dict_g['generator'])
+
+ generator.eval()
+ generator.remove_weight_norm()
+
+ dur_models = {
+ "neutral": load_duration_predictor(f"{a.dur_model}/neutral.ckpt"),
+ "amused": load_duration_predictor(f"{a.dur_model}/amused.ckpt"),
+ "disgusted": load_duration_predictor(f"{a.dur_model}/disgusted.ckpt"),
+ "angry": load_duration_predictor(f"{a.dur_model}/angry.ckpt"),
+ "sleepy": load_duration_predictor(f"{a.dur_model}/sleepy.ckpt"),
+ }
+ logger.info(f"loaded duration prediction model from {a.dur_model}")
+
+ f0_model = load_pitch_predictor(a.f0_model).to(a.device)
+ logger.info(f"loaded f0 prediction model from {a.f0_model}")
+
+                        raise ValueError("unable to load model due to mismatched dims!")
+ # (if we want the original files names as output)
+ results = parse_generation_file(a.result_path)
+ _, src_code_to_fname = get_code_to_fname(f'{a.data}/files.{a.split}.{a.src_emotion}', f'{a.data}/{a.split}.{a.src_emotion}')
+ _, tgt_code_to_fname = get_code_to_fname(f'{a.data}/files.{a.split}.{a.trg_emotion}', f'{a.data}/{a.split}.{a.trg_emotion}')
+
+ # we need the originals (before dedup) to get the ground-truth durations
+ orig_tsv = open(a.orig_tsv, 'r').readlines()
+ orig_tsv_root, orig_tsv = orig_tsv[0].strip(), orig_tsv[1:]
+ orig_km = open(a.orig_km, 'r').readlines()
+ fname_to_idx = {orig_tsv_root + "/" + line.split("\t")[0]: i for i, line in enumerate(orig_tsv)}
+
+ outdir = a.outdir
+ outdir.mkdir(parents=True, exist_ok=True)
+ (outdir / '0-source').mkdir(exist_ok=True)
+ (outdir / '1-src-tokens-src-style-src-f0').mkdir(exist_ok=True)
+ (outdir / '2-src-tokens-trg-style-src-f0').mkdir(exist_ok=True)
+ (outdir / '2.5-src-tokens-trg-style-src-f0').mkdir(exist_ok=True)
+ (outdir / '3-src-tokens-trg-style-pred-f0').mkdir(exist_ok=True)
+ (outdir / '4-gen-tokens-trg-style-pred-f0').mkdir(exist_ok=True)
+ (outdir / '5-target').mkdir(exist_ok=True)
+
+ N = 0
+ results = list(results.items())
+ random.shuffle(results)
+ for i, (sid, result) in tqdm(enumerate(results)):
+ N += 1
+ if N > a.N and a.N != -1:
+ break
+
+ if '[' in result['S'][0]:
+ result['S'] = result['S'][1:]
+ if '_' in result['S'][-1]:
+ result['S'] = result['S'][:-1]
+ src_ref = src_code_to_fname[code_to_str(result['S'])]
+ trg_ref = tgt_code_to_fname[code_to_str(result['T'])]
+
+ src_style, trg_style = None, None
+ src_spkr, trg_spkr = None, None
+ src_f0 = None
+ src_audio = (load_audio(src_ref)[0] / MAX_WAV_VALUE) * 0.95
+ trg_audio = (load_audio(trg_ref)[0] / MAX_WAV_VALUE) * 0.95
+ src_audio = torch.FloatTensor(src_audio).unsqueeze(0).cuda()
+ trg_audio = torch.FloatTensor(trg_audio).unsqueeze(0).cuda()
+
+ src_spkr = parse_speaker(src_ref, h.multispkr)
+ src_spkr = src_spkr if src_spkr in EMOV_SPK2ID else random.choice(list(EMOV_SPK2ID.keys()))
+ src_spkr = EMOV_SPK2ID[src_spkr]
+ src_spkr = torch.LongTensor([src_spkr])
+ trg_spkr = parse_speaker(trg_ref, h.multispkr)
+ trg_spkr = trg_spkr if trg_spkr in EMOV_SPK2ID else random.choice(list(EMOV_SPK2ID.keys()))
+ trg_spkr = EMOV_SPK2ID[trg_spkr]
+ trg_spkr = torch.LongTensor([trg_spkr])
+
+ src_style = EMOV_STYLE2ID[a.src_emotion]
+ src_style = torch.LongTensor([src_style]).cuda()
+ trg_style_str = a.trg_emotion
+ trg_style = EMOV_STYLE2ID[a.trg_emotion]
+ trg_style = torch.LongTensor([trg_style]).cuda()
+
+ src_tokens = list(map(int, orig_km[fname_to_idx[src_ref]].strip().split(" ")))
+ src_tokens = torch.LongTensor(src_tokens).unsqueeze(0)
+ src_tokens_dur_pred = torch.LongTensor(list(map(int, result['S']))).unsqueeze(0)
+ src_tokens_dur_pred = dur_models[trg_style_str].inflate_input(src_tokens_dur_pred)
+ gen_tokens = torch.LongTensor(result['H']).unsqueeze(0)
+ gen_tokens = dur_models[trg_style_str].inflate_input(gen_tokens)
+ trg_tokens = torch.LongTensor(result['T']).unsqueeze(0)
+ trg_tokens = dur_models[trg_style_str].inflate_input(trg_tokens)
+
+ src_f0 = get_praat_f0(src_audio.unsqueeze(0).cpu().numpy())
+ src_f0 = torch.FloatTensor(src_f0).cuda()
+
+ pred_src_f0 = f0_model.inference(torch.LongTensor(src_tokens).to(a.device), src_spkr, trg_style).unsqueeze(0)
+ pred_src_dur_pred_f0 = f0_model.inference(torch.LongTensor(src_tokens_dur_pred).to(a.device), src_spkr, trg_style).unsqueeze(0)
+ pred_gen_f0 = f0_model.inference(torch.LongTensor(gen_tokens).to(a.device), src_spkr, trg_style).unsqueeze(0)
+ pred_trg_f0 = f0_model.inference(torch.LongTensor(trg_tokens).to(a.device), src_spkr, trg_style).unsqueeze(0)
+
+ if a.orig_filename:
+ path = src_code_to_fname[code_to_str(result['S'])]
+ sid = str(sid) + "__" + Path(path).stem
+ shutil.copy(src_code_to_fname[code_to_str(result['S'])], outdir / '0-source' / f'{sid}.wav')
+
+ audio = generate_from_code(generator, h, src_tokens, spkr=src_spkr, f0=src_f0, gst=src_style, device=a.device)
+ sf.write(outdir / '1-src-tokens-src-style-src-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate)
+
+ audio = generate_from_code(generator, h, src_tokens, spkr=src_spkr, f0=src_f0, gst=trg_style, device=a.device)
+ sf.write(outdir / '2-src-tokens-trg-style-src-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate)
+
+ audio = generate_from_code(generator, h, src_tokens_dur_pred, spkr=src_spkr, f0=src_f0, gst=trg_style, device=a.device)
+ sf.write(outdir / '2.5-src-tokens-trg-style-src-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate)
+
+ audio = generate_from_code(generator, h, src_tokens_dur_pred, spkr=src_spkr, f0=pred_src_dur_pred_f0, gst=trg_style, device=a.device)
+ sf.write(outdir / '3-src-tokens-trg-style-pred-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate)
+
+ audio = generate_from_code(generator, h, gen_tokens, spkr=src_spkr, f0=pred_gen_f0, gst=trg_style, device=a.device)
+ sf.write(outdir / '4-gen-tokens-trg-style-pred-f0' / f'{sid}.wav', audio, samplerate=h.sampling_rate)
+
+ shutil.copy(tgt_code_to_fname[code_to_str(result['T'])], outdir / '5-target' / f'{sid}.wav')
+
+ logger.info("Done.")
+
+
+if __name__ == '__main__':
+ synth(sys.argv[1:])
diff --git a/examples/fast_noisy_channel/README.md b/examples/fast_noisy_channel/README.md
new file mode 100644
index 0000000000..f2631a8c34
--- /dev/null
+++ b/examples/fast_noisy_channel/README.md
@@ -0,0 +1,345 @@
+# Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling
+
+## Introduction
+- [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical.
+- To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduce three simple approximations that make this approach fast and practical without much loss in accuracy.
+- This README provides instructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy.
+
+## Noisy Channel Modeling
+
+[Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) apply Bayes' rule to model `P(y|x)`, the probability of the target `y` given the source `x`.
+```P(y|x) = P(x|y) * P(y) / P(x)```
+- `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model**
+- `P(y)` is a **language model** over the target `y`
+- `P(x)` is generally not modeled since it is constant for all `y`.
+
+We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`.
+
+During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores.
+
+```(1 / t) * log(P(y|x)) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y)) )```
+- `t` - Target Prefix Length
+- `s` - Source Length
+- `λ1` - Channel Model Weight
+- `λ2` - Language Model Weight
+
+The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search.
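+
+As a concrete illustration, the combined score for a single candidate can be written as a small function. This is only a sketch of the scoring rule above (all names below are made up for illustration; it is not the fairseq implementation):
+
+```python
+def combined_score(log_p_direct, log_p_channel, log_p_lm,
+                   tgt_prefix_len, src_len, ch_wt, lm_wt):
+    """log_p_direct = log P(y|x), log_p_channel = log P(x|y), log_p_lm = log P(y)."""
+    return (1.0 / tgt_prefix_len) * log_p_direct + \
+           (1.0 / src_len) * (ch_wt * log_p_channel + lm_wt * log_p_lm)
+
+# e.g. a candidate prefix of length 4 for a source of length 20,
+# with λ1 (--ch-wt) = 0.3 and λ2 (--lm-wt) = 0.5
+score = combined_score(-2.1, -3.4, -5.0, tgt_prefix_len=4, src_len=20, ch_wt=0.3, lm_wt=0.5)
+```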
+
+This framework provides a great way to utilize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source.
+
+### Training Translation Models and Language Models
+
+For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/translation)
+
+For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model)
+
+### Generation with Language Model for German-English translation with fairseq
+
+Here are instructions to generate using a direct model and a target-side language model.
+
+Note:
+- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
+- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
+
+```sh
+binarized_data=data_dir/binarized
+direct_model=de_en_seed4.pt
+lm_model=en_lm.pt
+lm_data=lm_data
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
+mkdir -p ${lm_data}
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
+
+k2=10
+lenpen=0.16
+lm_wt=0.14
+fairseq-generate ${binarized_data} \
+ --user-dir examples/fast_noisy_channel \
+ --beam 5 \
+ --path ${direct_model} \
+ --lm-model ${lm_model} \
+ --lm-data ${lm_data} \
+ --k2 ${k2} \
+ --combine-method lm_only \
+ --task noisy_channel_translation \
+ --lenpen ${lenpen} \
+ --lm-wt ${lm_wt} \
+ --gen-subset valid \
+ --remove-bpe \
+ --fp16 \
+ --batch-size 10
+```
+### Noisy Channel Generation for German-English translation with fairseq
+
+Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling).
+
+Note:
+- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
+- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
+
+```sh
+binarized_data=data_dir/binarized
+direct_model=de_en_seed4.pt
+lm_model=en_lm.pt
+lm_data=lm_data
+ch_model=en_de.big.seed4.pt
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
+mkdir -p ${lm_data}
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model}
+
+k2=10
+lenpen=0.21
+lm_wt=0.50
+bw_wt=0.30
+fairseq-generate ${binarized_data} \
+ --user-dir examples/fast_noisy_channel \
+ --beam 5 \
+ --path ${direct_model} \
+ --lm-model ${lm_model} \
+ --lm-data ${lm_data} \
+ --channel-model ${ch_model} \
+ --k2 ${k2} \
+ --combine-method noisy_channel \
+ --task noisy_channel_translation \
+ --lenpen ${lenpen} \
+ --lm-wt ${lm_wt} \
+ --ch-wt ${bw_wt} \
+ --gen-subset test \
+ --remove-bpe \
+ --fp16 \
+ --batch-size 1
+```
+## Fast Noisy Channel Modeling
+
+[Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduce three approximations that speed up online noisy channel decoding -
+- Smaller channel models (`Transformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`)
+ - This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model.
+ - Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose.
+- Smaller output vocabulary size for the channel model (~30,000 -> ~1000)
+  - The channel model doesn't need to score the full output vocabulary; it just needs to score the source tokens, which are completely known.
+  - This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500`
+  - This means that the output vocabulary for the channel model will be the source tokens of all examples in the batch plus the top-K most frequent tokens in the vocabulary.
+  - This significantly reduces the memory needed to store channel model scores (see the sketch after this list).
+- Smaller number of candidates (`k2`) scored per beam
+ - This is specified by reducing the argument `--k2`
+
+
+### Fast Noisy Channel Generation for German-English translation with fairseq
+
+Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`.
+
+Note:
+- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
+- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
+
+```sh
+binarized_data=data_dir/binarized
+direct_model=de_en_seed4.pt
+lm_model=en_lm.pt
+lm_data=lm_data
+small_ch_model=en_de.base_1_1.seed4.pt
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
+mkdir -p ${lm_data}
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
+wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model}
+
+k2=3
+lenpen=0.23
+lm_wt=0.58
+bw_wt=0.26
+fairseq-generate ${binarized_data} \
+ --user-dir examples/fast_noisy_channel \
+ --beam 5 \
+ --path ${direct_model} \
+ --lm-model ${lm_model} \
+ --lm-data ${lm_data} \
+ --channel-model ${small_ch_model} \
+ --k2 ${k2} \
+ --combine-method noisy_channel \
+ --task noisy_channel_translation \
+ --lenpen ${lenpen} \
+ --lm-wt ${lm_wt} \
+ --ch-wt ${bw_wt} \
+ --gen-subset test \
+ --remove-bpe \
+ --fp16 \
+ --batch-size 50 \
+ --channel-scoring-type src_vocab --top-k-vocab 500
+```
+
+## Test Data Preprocessing
+
+For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script:
+
+```sh
+FAIRSEQ=/path/to/fairseq
+cd $FAIRSEQ
+SCRIPTS=$FAIRSEQ/mosesdecoder/scripts
+if [ ! -d "${SCRIPTS}" ]; then
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
+ git clone https://github.com/moses-smt/mosesdecoder.git
+fi
+TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
+NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl
+
+s=de
+t=en
+test=wmt18
+
+mkdir -p data_dir
+
+# Tokenization
+if [ $s == "ro" ] ; then
+ # Note: Get normalise-romanian.py and remove-diacritics.py from
+ # https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess
+ sacrebleu -t $test -l $s-$t --echo src | \
+ $NORMALIZE -l $s | \
+ python normalise-romanian.py | \
+ python remove-diacritics.py | \
+ $TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s
+else
+ sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s
+fi
+
+sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t
+
+
+# Applying BPE
+src_bpe_code=/path/to/source/language/bpe/code
+tgt_bpe_code=/path/to/target/language/bpe/code
+src_dict=/path/to/source/language/dict
+tgt_dict=/path/to/target/language/dict
+
+FASTBPE=$FAIRSEQ/fastBPE
+if [ ! -d "${FASTBPE}" ] ; then
+  git clone https://github.com/glample/fastBPE.git
+  # Compile inside the cloned repo so the binary ends up at ${FASTBPE}/fast
+  # (see https://github.com/glample/fastBPE for compilation details)
+  (cd ${FASTBPE} && g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast)
+fi
+
+${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code}
+${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$t data_dir/$test.$s-$t.$t ${tgt_bpe_code}
+
+fairseq-preprocess -s $s -t $t \
+ --testpref data_dir/bpe.$test.$s-$t \
+ --destdir data_dir/binarized \
+ --srcdict ${src_dict} \
+ --tgtdict ${tgt_dict}
+```
+
+## Calculating BLEU
+
+This assumes `${generation_output}` contains the output of `fairseq-generate` and that the variables from the preprocessing script above (`$SCRIPTS`, `$s`, `$t`, `$test`) are still set. The `grep -P "^H"` keeps the hypothesis lines, `sort -V` restores the original sentence order, and `cut -f 3-` drops the index and score columns before detokenization and scoring.
+
+```sh
+DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
+cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t
+```
+
+
+## Romanian-English Translation
+
+The direct and channel models are trained on bitext data (WMT16) combined with backtranslated data. The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c).
+
+The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling.
+
+### BPE Codes and Dictionary
+
+We learn a joint BPE vocabulary of 18K types on the bitext training data, which is used for both the source and the target.
+| | Path |
+|----------|------|
+| BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) |
+| Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) |
+
+### Direct Models
+For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture.
+
+| Seed | Model |
+|----|----|
+| 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt)
+| 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt)
+| 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt)
+
+### Channel Models
+For the channel models, we follow the same steps as for the direct models, but the backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/).
+The best lenpen, LM weight and channel weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5 (a sketch of such a sweep follows the table below).
+| Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 |
+|----|----|----|----|----|----|----|
+| `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) |
+| `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) |
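+
+A minimal sketch of how such a sweep could be scripted, assuming the validation set has been binarized analogously to the test set (the preprocessing script above only handles the test set) and using model files downloaded as in the generation examples above; the grid values and file names here are illustrative:
+
+```python
+import itertools
+import subprocess
+
+# Illustrative grid; these are not the ranges used in the paper.
+grid = itertools.product([0.6, 0.8, 1.0],   # --lenpen
+                         [0.4, 0.5, 0.6],   # --lm-wt
+                         [0.3, 0.4, 0.5])   # --ch-wt
+
+for lenpen, lm_wt, ch_wt in grid:
+    out = f"gen.valid.lenpen{lenpen}.lm{lm_wt}.ch{ch_wt}.txt"
+    cmd = [
+        "fairseq-generate", "data_dir/binarized",
+        "--user-dir", "examples/fast_noisy_channel",
+        "--task", "noisy_channel_translation",
+        "--combine-method", "noisy_channel",
+        "--path", "ro_en_seed2.pt",
+        "--channel-model", "base_1_1.seed2.pt",
+        "--lm-model", "en_lm.pt", "--lm-data", "lm_data",
+        "--beam", "5", "--k2", "10",
+        "--lenpen", str(lenpen), "--lm-wt", str(lm_wt), "--ch-wt", str(ch_wt),
+        "--gen-subset", "valid", "--remove-bpe",
+    ]
+    with open(out, "w") as f:
+        subprocess.run(cmd, stdout=f, check=True)
+    # BLEU for each setting can then be computed as in "Calculating BLEU" above.
+```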
+
+### Language Model
+The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
+| | Path |
+|----|----|
+| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) |
+| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict)
+
+## German-English Translation
+
+### BPE Codes and Dictionaries
+
+| | Path|
+|----------|------|
+| Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) |
+| Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K)
+| Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) |
+| Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) |
+
+### Direct Models
+We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
+We use the Transformer-Big architecture for the direct model.
+
+| Seed | Model |
+|:----:|----|
+| 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt)
+| 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt)
+| 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt)
+
+### Channel Models
+
+We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
+
+| Model Size | Seed 4 | Seed 5 | Seed 6 |
+|----|----|----|----|
+| `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) |
+| `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) |
+| `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) |
+| `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) |
+| `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) |
+| `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | [half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) |
+| `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) |
+| `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) |
+| `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) |
+| `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) |
+| `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) |
+| `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | [16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) |
+
+### Language Model
+The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
+| | Path |
+|----|----|
+| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) |
+| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/)
+
+
+## Citation
+
+```bibtex
+@inproceedings{bhosale2020language,
+ title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling},
+ author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli},
+ booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)},
+ year={2020},
+}
+
+@inproceedings{yee2019simple,
+ title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
+ author={Yee, Kyra and Dauphin, Yann and Auli, Michael},
+ booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
+ pages={5700--5705},
+ year={2019}
+}
+```
diff --git a/examples/fast_noisy_channel/__init__.py b/examples/fast_noisy_channel/__init__.py
new file mode 100644
index 0000000000..9b248c3a24
--- /dev/null
+++ b/examples/fast_noisy_channel/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import noisy_channel_translation # noqa
+from . import noisy_channel_sequence_generator # noqa
+from . import noisy_channel_beam_search # noqa
diff --git a/examples/fast_noisy_channel/noisy_channel_beam_search.py b/examples/fast_noisy_channel/noisy_channel_beam_search.py
new file mode 100644
index 0000000000..23869ebcd0
--- /dev/null
+++ b/examples/fast_noisy_channel/noisy_channel_beam_search.py
@@ -0,0 +1,71 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from fairseq.search import Search
+
+
+class NoisyChannelBeamSearch(Search):
+
+ def __init__(self, tgt_dict):
+ super().__init__(tgt_dict)
+ self.fw_scores_buf = None
+ self.lm_scores_buf = None
+
+ def _init_buffers(self, t):
+ # super()._init_buffers(t)
+ if self.fw_scores_buf is None:
+ self.scores_buf = t.new()
+ self.indices_buf = torch.LongTensor().to(device=t.device)
+ self.beams_buf = torch.LongTensor().to(device=t.device)
+ self.fw_scores_buf = t.new()
+ self.lm_scores_buf = t.new()
+
+ def combine_fw_bw(self, combine_method, fw_cum, bw, step):
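+        # fw_cum holds the cumulative direct-model scores log P(T|S) for each
+        # candidate; bw holds the channel/LM scores computed and combined upstream
+        # (see combine_ch_lm in noisy_channel_sequence_generator.py). For
+        # "noisy_channel", the forward score is normalized by the current target
+        # length before the two are added.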
+ if combine_method == "noisy_channel":
+ fw_norm = fw_cum.div(step + 1)
+ lprobs = bw + fw_norm
+ elif combine_method == "lm_only":
+ lprobs = bw + fw_cum
+
+ return lprobs
+
+ def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
+ self._init_buffers(fw_lprobs)
+ bsz, beam_size, vocab_size = fw_lprobs.size()
+
+ if step == 0:
+ # at the first step all hypotheses are equally likely, so use
+ # only the first beam
+ fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
+ bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
+ # nothing to add since we are at the first step
+ fw_lprobs_cum = fw_lprobs
+
+ else:
+ # make probs contain cumulative scores for each hypothesis
+ raw_scores = (scores[:, :, step - 1].unsqueeze(-1))
+ fw_lprobs_cum = (fw_lprobs.add(raw_scores))
+
+ combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step)
+
+ # choose the top k according to the combined noisy channel model score
+ torch.topk(
+ combined_lprobs.view(bsz, -1),
+ k=min(
+ # Take the best 2 x beam_size predictions. We'll choose the first
+ # beam_size of these which don't predict eos to continue with.
+ beam_size * 2,
+ combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
+ ),
+ out=(self.scores_buf, self.indices_buf),
+ )
+ # save corresponding fw and lm scores
+ self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf)
+ self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf)
+ # Project back into relative indices and beams
+ self.beams_buf = self.indices_buf // vocab_size
+ self.indices_buf.fmod_(vocab_size)
+ return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf
diff --git a/examples/fast_noisy_channel/noisy_channel_sequence_generator.py b/examples/fast_noisy_channel/noisy_channel_sequence_generator.py
new file mode 100644
index 0000000000..ea8fae98e8
--- /dev/null
+++ b/examples/fast_noisy_channel/noisy_channel_sequence_generator.py
@@ -0,0 +1,842 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, List, Optional
+
+import math
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from .noisy_channel_beam_search import NoisyChannelBeamSearch
+from fairseq.sequence_generator import EnsembleModel
+
+
+class NoisyChannelSequenceGenerator(object):
+ def __init__(
+ self,
+ combine_method,
+ tgt_dict,
+ src_dict=None,
+ beam_size=1,
+ max_len_a=0,
+ max_len_b=200,
+ min_len=1,
+ len_penalty=1.0,
+ unk_penalty=0.0,
+ retain_dropout=False,
+ temperature=1.0,
+ match_source_len=False,
+ no_repeat_ngram_size=0,
+ normalize_scores=True,
+ channel_models=None,
+ k2=10,
+ ch_weight=1.0,
+ channel_scoring_type='log_norm',
+ top_k_vocab=0,
+ lm_models=None,
+ lm_dict=None,
+ lm_weight=1.0,
+ normalize_lm_scores_by_tgt_len=False,
+ ):
+ """Generates translations of a given source sentence,
+ using beam search with noisy channel decoding.
+
+ Args:
+ combine_method (string, optional): Method to combine direct, LM and
+ channel model scores (default: None)
+ tgt_dict (~fairseq.data.Dictionary): target dictionary
+ src_dict (~fairseq.data.Dictionary): source dictionary
+ beam_size (int, optional): beam width (default: 1)
+ max_len_a/b (int, optional): generate sequences of maximum length
+ ax + b, where x is the source length
+ min_len (int, optional): the minimum length of the generated output
+ (not including end-of-sentence)
+ len_penalty (float, optional): length penalty, where <1.0 favors
+ shorter, >1.0 favors longer sentences (default: 1.0)
+ unk_penalty (float, optional): unknown word penalty, where <0
+ produces more unks, >0 produces fewer (default: 0.0)
+ retain_dropout (bool, optional): use dropout when generating
+ (default: False)
+ temperature (float, optional): temperature, where values
+ >1.0 produce more uniform samples and values <1.0 produce
+ sharper samples (default: 1.0)
+ match_source_len (bool, optional): outputs should match the source
+ length (default: False)
+ no_repeat_ngram_size (int, optional): Size of n-grams that we avoid
+ repeating in the generation (default: 0)
+ normalize_scores (bool, optional): normalize scores by the length
+ of the output (default: True)
+ channel_models (List[~fairseq.models.FairseqModel]): ensemble of models
+ translating from the target to the source
+ k2 (int, optional): Top K2 candidates to score per beam at each step (default:10)
+ ch_weight (int, optional): Weight associated with the channel model score
+ assuming that the direct model score has weight 1.0 (default: 1.0)
+ channel_scoring_type (str, optional): String specifying how to score
+ the channel model (default: 'log_norm')
+ top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or
+ `'src_vocab_batched'`, then this parameter specifies the number of
+ most frequent tokens to include in the channel model output vocabulary,
+ in addition to the source tokens in the input batch (default: 0)
+ lm_models (List[~fairseq.models.FairseqModel]): ensemble of models
+ generating text in the target language
+ lm_dict (~fairseq.data.Dictionary): LM Model dictionary
+ lm_weight (int, optional): Weight associated with the LM model score
+ assuming that the direct model score has weight 1.0 (default: 1.0)
+ normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores
+ by the target length? By default, we normalize the combination of
+ LM and channel model scores by the source length
+ """
+ self.pad = tgt_dict.pad()
+ self.unk = tgt_dict.unk()
+ self.eos = tgt_dict.eos()
+ self.vocab_size = len(tgt_dict)
+ self.beam_size = beam_size
+ # the max beam size is the dictionary size - 1, since we never select pad
+ self.beam_size = min(beam_size, self.vocab_size - 1)
+ self.max_len_a = max_len_a
+ self.max_len_b = max_len_b
+ self.min_len = min_len
+ self.normalize_scores = normalize_scores
+ self.len_penalty = len_penalty
+ self.unk_penalty = unk_penalty
+ self.retain_dropout = retain_dropout
+ self.temperature = temperature
+ self.match_source_len = match_source_len
+ self.no_repeat_ngram_size = no_repeat_ngram_size
+ self.channel_models = channel_models
+ self.src_dict = src_dict
+ self.tgt_dict = tgt_dict
+ self.combine_method = combine_method
+ self.k2 = k2
+ self.ch_weight = ch_weight
+ self.channel_scoring_type = channel_scoring_type
+ self.top_k_vocab = top_k_vocab
+ self.lm_models = lm_models
+ self.lm_dict = lm_dict
+ self.lm_weight = lm_weight
+ self.log_softmax_fn = torch.nn.LogSoftmax(dim=1)
+ self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len
+
+ self.share_tgt_dict = (self.lm_dict == self.tgt_dict)
+ self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict)
+
+ self.ch_scoring_bsz = 3072
+
+ assert temperature > 0, '--temperature must be greater than 0'
+
+ self.search = NoisyChannelBeamSearch(tgt_dict)
+
+ @torch.no_grad()
+ def generate(
+ self,
+ models,
+ sample,
+ prefix_tokens=None,
+ bos_token=None,
+ **kwargs
+ ):
+ """Generate a batch of translations.
+ Args:
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
+ sample (dict): batch
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
+ with these tokens
+ """
+ model = EnsembleModel(models)
+ incremental_states = torch.jit.annotate(
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
+ [
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
+ for i in range(model.models_size)
+ ],
+ )
+ if not self.retain_dropout:
+ model.eval()
+
+ # model.forward normally channels prev_output_tokens into the decoder
+ # separately, but SequenceGenerator directly calls model.encoder
+ encoder_input = {
+ k: v for k, v in sample['net_input'].items()
+ if k != 'prev_output_tokens'
+ }
+ src_tokens = encoder_input['src_tokens']
+ src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
+ input_size = src_tokens.size()
+ # batch dimension goes first followed by source lengths
+ bsz = input_size[0]
+ src_len = input_size[1]
+ beam_size = self.beam_size
+
+ if self.match_source_len:
+ max_len = src_lengths_no_eos.max().item()
+ else:
+ max_len = min(
+ int(self.max_len_a * src_len + self.max_len_b),
+ # exclude the EOS marker
+ model.max_decoder_positions() - 1,
+ )
+
+ # compute the encoder output for each beam
+ encoder_outs = model.forward_encoder(encoder_input)
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+ new_order = new_order.to(src_tokens.device).long()
+ encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
+
+ src_lengths = encoder_input['src_lengths']
+ # initialize buffers
+ scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
+ lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0)
+
+ scores_buf = scores.clone()
+ tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
+ tokens_buf = tokens.clone()
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
+
+ # reorder source tokens so they may be used as a reference in generating P(S|T)
+ src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index)
+
+ src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len)
+ src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1)
+
+ attn, attn_buf = None, None
+ nonpad_idxs = None
+
+ # The cands_to_ignore indicates candidates that should be ignored.
+ # For example, suppose we're sampling and have already finalized 2/5
+ # samples. Then the cands_to_ignore would mark 2 positions as being ignored,
+ # so that we only finalize the remaining 3 samples.
+ cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask
+
+ # list of completed sentences
+ finalized = [[] for i in range(bsz)]
+ finished = [False for i in range(bsz)]
+ num_remaining_sent = bsz
+
+ # number of candidate hypos per step
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
+
+ # offset arrays for converting between different indexing schemes
+ bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens)
+
+ # helper function for allocating buffers on the fly
+ buffers = {}
+
+ def buffer(name, type_of=tokens): # noqa
+ if name not in buffers:
+ buffers[name] = type_of.new()
+ return buffers[name]
+
+ def is_finished(sent, step, unfin_idx):
+ """
+ Check whether we've finished generation for a given sentence, by
+ comparing the worst score among finalized hypotheses to the best
+ possible score among unfinalized hypotheses.
+ """
+ assert len(finalized[sent]) <= beam_size
+ if len(finalized[sent]) == beam_size:
+ return True
+ return False
+
+ def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores):
+ """
+ Finalize the given hypotheses at this step, while keeping the total
+ number of finalized hypotheses per sentence <= beam_size.
+
+ Note: the input must be in the desired finalization order, so that
+ hypotheses that appear earlier in the input are preferred to those
+ that appear later.
+
+ Args:
+ step: current time step
+ bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
+ indicating which hypotheses to finalize
+ eos_scores: A vector of the same size as bbsz_idx containing
+ fw scores for each hypothesis
+ combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing
+ combined noisy channel scores for each hypothesis
+ """
+ assert bbsz_idx.numel() == eos_scores.numel()
+
+ # clone relevant token and attention tensors
+ tokens_clone = tokens.index_select(0, bbsz_idx)
+ tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
+ assert not tokens_clone.eq(self.eos).any()
+ tokens_clone[:, step] = self.eos
+ attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
+
+ # compute scores per token position
+ pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
+ pos_scores[:, step] = eos_scores
+ # convert from cumulative to per-position scores
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
+
+ # normalize sentence-level scores
+ if self.normalize_scores:
+ combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty
+
+ cum_unfin = []
+ prev = 0
+ for f in finished:
+ if f:
+ prev += 1
+ else:
+ cum_unfin.append(prev)
+
+ sents_seen = set()
+ for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())):
+ unfin_idx = idx // beam_size
+ sent = unfin_idx + cum_unfin[unfin_idx]
+
+ sents_seen.add((sent, unfin_idx))
+
+ if self.match_source_len and step > src_lengths_no_eos[unfin_idx]:
+ score = -math.inf
+
+ def get_hypo():
+
+ if attn_clone is not None:
+ # remove padding tokens from attn scores
+ hypo_attn = attn_clone[i][nonpad_idxs[sent]]
+ _, alignment = hypo_attn.max(dim=0)
+ else:
+ hypo_attn = None
+ alignment = None
+
+ return {
+ 'tokens': tokens_clone[i],
+ 'score': score,
+ 'attention': hypo_attn, # src_len x tgt_len
+ 'alignment': alignment,
+ 'positional_scores': pos_scores[i],
+ }
+
+ if len(finalized[sent]) < beam_size:
+ finalized[sent].append(get_hypo())
+
+ newly_finished = []
+ for sent, unfin_idx in sents_seen:
+ # check termination conditions for this sentence
+ if not finished[sent] and is_finished(sent, step, unfin_idx):
+ finished[sent] = True
+ newly_finished.append(unfin_idx)
+ return newly_finished
+
+ def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k):
+ """Rescore the top k hypothesis from each beam using noisy channel modeling
+ Returns:
+ new_fw_lprobs: the direct model probabilities after pruning the top k
+ new_ch_lm_lprobs: the combined channel and language model probabilities
+ new_lm_lprobs: the language model probabilities after pruning the top k
+ """
+ with torch.no_grad():
+ lprobs_size = lprobs.size()
+ if prefix_tokens is not None and step < prefix_tokens.size(1):
+ probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
+ cand_scores = torch.gather(
+ probs_slice, dim=1,
+ index=prefix_tokens[:, step].view(-1, 1).data
+ ).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1)
+ cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1)
+
+ # need to calculate and save fw and lm probs for prefix tokens
+ fw_top_k = cand_scores
+ fw_top_k_idx = cand_indices
+ k = 1
+ else:
+ # take the top k best words for every sentence in batch*beam
+ fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k)
+ eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0]
+ ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0)
+ src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype)
+
+ if self.combine_method != "lm_only":
+ temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
+ not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index
+ cur_tgt_size = step+2
+
+ # add eos to all candidate sentences except those that already end in eos
+ eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1)
+ eos_tokens[eos_idx] = self.tgt_dict.pad_index
+
+ if step == 0:
+ channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1)
+ else:
+ # move eos from beginning to end of target sentence
+ channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1)
+
+ ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size))
+ ch_input_lengths[eos_idx] = cur_tgt_size-1
+ if self.channel_scoring_type == "unnormalized":
+ ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
+ ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
+ del ch_encoder_output
+ ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:])
+ ch_intermed_scores = ch_intermed_scores.float()
+ ch_intermed_scores *= not_padding.float()
+ ch_scores = torch.sum(ch_intermed_scores, dim=1)
+ elif self.channel_scoring_type == "k2_separate":
+ for k_idx in range(k):
+ k_eos_tokens = eos_tokens[k_idx::k, :]
+ if step == 0:
+ k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
+ else:
+ # move eos from beginning to end of target sentence
+ k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
+ k_ch_input_lengths = ch_input_lengths[k_idx::k]
+ k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens)
+ k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True)
+ k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2)
+ k_ch_intermed_scores *= not_padding.float()
+ ch_scores[k_idx::k] = torch.sum(k_ch_intermed_scores, dim=1)
+ elif self.channel_scoring_type == "src_vocab":
+ ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
+ ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
+
+ del ch_encoder_output
+ ch_lprobs = normalized_scores_with_batch_vocab(
+ channel_model.decoder,
+ ch_decoder_output, src_tokens, k, bsz, beam_size,
+ self.src_dict.pad_index, top_k=self.top_k_vocab)
+ ch_scores = torch.sum(ch_lprobs, dim=1)
+ elif self.channel_scoring_type == "src_vocab_batched":
+ ch_bsz_size = temp_src_tokens_full.shape[0]
+ ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz))
+ for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)):
+ end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size)
+ temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :]
+ channel_input_batch = channel_input[start_idx:end_idx, :]
+ ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx]
+ ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch)
+ ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True)
+ ch_lprobs_list[i] = normalized_scores_with_batch_vocab(
+ channel_model.decoder,
+ ch_decoder_output_batch, src_tokens, k, bsz, beam_size,
+ self.src_dict.pad_index, top_k=self.top_k_vocab,
+ start_idx=start_idx, end_idx=end_idx)
+ ch_lprobs = torch.cat(ch_lprobs_list, dim=0)
+ ch_scores = torch.sum(ch_lprobs, dim=1)
+ else:
+ ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full)
+ ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True)
+ ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1)
+ ch_intermed_scores *= not_padding.float()
+ ch_scores = torch.sum(ch_intermed_scores, dim=1)
+
+ else:
+ cur_tgt_size = 0
+ ch_scores = ch_scores.view(bsz*beam_size, k)
+ expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten()
+
+ if self.share_tgt_dict:
+ lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k)
+ else:
+ new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm)
+ new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm)
+ lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k)
+
+ lm_scores.add_(expanded_lm_prefix_scores)
+ ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size)
+ # initialize all as min value
+ new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
+ new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
+ new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
+ new_fw_lprobs[:, self.pad] = -math.inf
+ new_ch_lm_lprobs[:, self.pad] = -math.inf
+ new_lm_lprobs[:, self.pad] = -math.inf
+
+ new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k)
+ new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores)
+ new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k))
+ return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs
+
+ def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size):
+ if self.channel_scoring_type == "unnormalized":
+ ch_scores = self.log_softmax_fn(
+ ch_scores.view(-1, self.beam_size * self.k2)
+ ).view(ch_scores.shape)
+ ch_scores = ch_scores * self.ch_weight
+ lm_scores1 = lm_scores1 * self.lm_weight
+
+ if combine_type == "lm_only":
+ # log P(T|S) + log P(T)
+ ch_scores = lm_scores1.view(ch_scores.size())
+ elif combine_type == "noisy_channel":
+ # 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T)
+ if self.normalize_lm_scores_by_tgt_len:
+ ch_scores.div_(src_size)
+ lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size)
+ ch_scores.add_(lm_scores_norm)
+ # 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T)
+ else:
+ ch_scores.add_(lm_scores1.view(ch_scores.size()))
+ ch_scores.div_(src_size)
+
+ return ch_scores
+
+ if self.channel_models is not None:
+ channel_model = self.channel_models[0] # assume only one channel_model model
+ else:
+ channel_model = None
+
+ lm = EnsembleModel(self.lm_models)
+ lm_incremental_states = torch.jit.annotate(
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
+ [
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
+ for i in range(lm.models_size)
+ ],
+ )
+
+ reorder_state = None
+ batch_idxs = None
+ for step in range(max_len + 1): # one extra step for EOS marker
+ # reorder decoder internal states based on the prev choice of beams
+ if reorder_state is not None:
+ if batch_idxs is not None:
+ # update beam indices to take into account removed sentences
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
+ reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
+ model.reorder_incremental_state(incremental_states, reorder_state)
+ encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
+
+ lm.reorder_incremental_state(lm_incremental_states, reorder_state)
+
+ fw_lprobs, avg_attn_scores = model.forward_decoder(
+ tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature,
+ )
+
+ fw_lprobs[:, self.pad] = -math.inf # never select pad
+ fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
+ fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2)
+
+ # handle min and max length constraints
+ if step >= max_len:
+ fw_lprobs[:, :self.eos] = -math.inf
+ fw_lprobs[:, self.eos + 1:] = -math.inf
+ elif step < self.min_len:
+ fw_lprobs[:, self.eos] = -math.inf
+
+ # handle prefix tokens (possibly with different lengths)
+ if prefix_tokens is not None and step < prefix_tokens.size(1):
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
+ prefix_mask = prefix_toks.ne(self.pad)
+
+ prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
+ fw_lprobs[prefix_mask] = -math.inf
+ fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_(
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs
+ )
+
+ prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
+ ch_lm_lprobs[prefix_mask] = -math.inf
+ ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_(
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs
+ )
+
+ prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
+ lm_lprobs[prefix_mask] = -math.inf
+ lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_(
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs
+ )
+
+ # if prefix includes eos, then we should make sure tokens and
+ # scores are the same across all beams
+ eos_mask = prefix_toks.eq(self.eos)
+ if eos_mask.any():
+ # validate that the first beam matches the prefix
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
+ assert (first_beam == target_prefix).all()
+
+ def replicate_first_beam(tensor, mask):
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
+ tensor[mask] = tensor[mask][:, :1, :]
+ return tensor.view(-1, tensor.size(-1))
+
+ # copy tokens, scores and lprobs from the first beam to all beams
+ tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
+ scores = replicate_first_beam(scores, eos_mask_batch_dim)
+
+ fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim)
+ ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim)
+ lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim)
+
+ if self.no_repeat_ngram_size > 0:
+ # for each beam and batch sentence, generate a list of previous ngrams
+ gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
+ for bbsz_idx in range(bsz * beam_size):
+ gen_tokens = tokens[bbsz_idx].tolist()
+ for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
+ gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
+ gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
+
+ # Record attention scores
+ if avg_attn_scores is not None:
+ if attn is None:
+ attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
+ attn_buf = attn.clone()
+ nonpad_idxs = src_tokens.ne(self.pad)
+ attn[:, :, step + 1].copy_(avg_attn_scores)
+
+ scores = scores.type_as(fw_lprobs)
+ scores_buf = scores_buf.type_as(fw_lprobs)
+
+ self.search.set_src_lengths(src_lengths_no_eos)
+
+ if self.no_repeat_ngram_size > 0:
+ def calculate_banned_tokens(bbsz_idx):
+ # before decoding the next token, prevent decoding of ngrams that have already appeared
+ ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
+ return gen_ngrams[bbsz_idx].get(ngram_index, [])
+
+ if step + 2 - self.no_repeat_ngram_size >= 0:
+ # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+ banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
+ else:
+ banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
+
+ for bbsz_idx in range(bsz * beam_size):
+ fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
+
+ combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step(
+ step,
+ fw_lprobs.view(bsz, -1, self.vocab_size),
+ scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size),
+ lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method
+ )
+
+ # cand_bbsz_idx contains beam indices for the top candidate
+ # hypotheses, with a range of values: [0, bsz*beam_size),
+ # and dimensions: [bsz, cand_size]
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+
+ # finalize hypotheses that end in eos (except for candidates to be ignored)
+ eos_mask = cand_indices.eq(self.eos)
+ eos_mask[:, :beam_size] &= ~cands_to_ignore
+
+ # only consider eos when it's among the top beam_size indices
+ eos_bbsz_idx = torch.masked_select(
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
+ )
+
+ finalized_sents = set()
+ if eos_bbsz_idx.numel() > 0:
+ eos_scores = torch.masked_select(
+ fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size]
+ )
+ combined_noisy_channel_eos_scores = torch.masked_select(
+ combined_noisy_channel_scores[:, :beam_size],
+ mask=eos_mask[:, :beam_size],
+ )
+
+ # finalize hypo using channel model score
+ finalized_sents = finalize_hypos(
+ step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores)
+
+ num_remaining_sent -= len(finalized_sents)
+
+ assert num_remaining_sent >= 0
+ if num_remaining_sent == 0:
+ break
+
+ if len(finalized_sents) > 0:
+ new_bsz = bsz - len(finalized_sents)
+
+ # construct batch_idxs which holds indices of batches to keep for the next pass
+ batch_mask = cand_indices.new_ones(bsz)
+ batch_mask[cand_indices.new(finalized_sents)] = 0
+ batch_idxs = torch.nonzero(batch_mask).squeeze(-1)
+
+ eos_mask = eos_mask[batch_idxs]
+ cand_beams = cand_beams[batch_idxs]
+ bbsz_offsets.resize_(new_bsz, 1)
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+
+ lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs]
+
+ fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs]
+ cand_indices = cand_indices[batch_idxs]
+ if prefix_tokens is not None:
+ prefix_tokens = prefix_tokens[batch_idxs]
+ src_lengths_no_eos = src_lengths_no_eos[batch_idxs]
+ cands_to_ignore = cands_to_ignore[batch_idxs]
+
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+ scores_buf.resize_as_(scores)
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+ tokens_buf.resize_as_(tokens)
+ src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+ src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+ lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze()
+
+ if attn is not None:
+ attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
+ attn_buf.resize_as_(attn)
+ bsz = new_bsz
+ else:
+ batch_idxs = None
+
+ # Set active_mask so that values > cand_size indicate eos or
+ # ignored hypos and values < cand_size indicate candidate
+ # active hypos. After this, the min values per row are the top
+ # candidate active hypos.
+ eos_mask[:, :beam_size] |= cands_to_ignore
+ active_mask = torch.add(
+ eos_mask.type_as(cand_offsets) * cand_size,
+ cand_offsets[: eos_mask.size(1)],
+ )
+
+ # get the top beam_size active hypotheses, which are just the hypos
+ # with the smallest values in active_mask
+ active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore')
+ torch.topk(
+ active_mask, k=beam_size, dim=1, largest=False,
+ out=(new_cands_to_ignore, active_hypos)
+ )
+
+ # update cands_to_ignore to ignore any finalized hypos
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
+ assert (~cands_to_ignore).any(dim=1).all()
+
+ active_bbsz_idx = buffer('active_bbsz_idx')
+ torch.gather(
+ cand_bbsz_idx, dim=1, index=active_hypos,
+ out=active_bbsz_idx,
+ )
+ active_scores = torch.gather(
+ fw_lprobs_top_k, dim=1, index=active_hypos,
+ out=scores[:, step].view(bsz, beam_size),
+ )
+
+ active_bbsz_idx = active_bbsz_idx.view(-1)
+ active_scores = active_scores.view(-1)
+
+ # copy tokens and scores for active hypotheses
+ torch.index_select(
+ tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
+ out=tokens_buf[:, :step + 1],
+ )
+ torch.gather(
+ cand_indices, dim=1, index=active_hypos,
+ out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
+ )
+ if step > 0:
+ torch.index_select(
+ scores[:, :step], dim=0, index=active_bbsz_idx,
+ out=scores_buf[:, :step],
+ )
+ torch.gather(
+ fw_lprobs_top_k, dim=1, index=active_hypos,
+ out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
+ )
+ torch.gather(
+ lm_lprobs_top_k, dim=1, index=active_hypos,
+ out=lm_prefix_scores.view(bsz, beam_size)
+ )
+
+ # copy attention for active hypotheses
+ if attn is not None:
+ torch.index_select(
+ attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
+ out=attn_buf[:, :, :step + 2],
+ )
+
+ # swap buffers
+ tokens, tokens_buf = tokens_buf, tokens
+ scores, scores_buf = scores_buf, scores
+ if attn is not None:
+ attn, attn_buf = attn_buf, attn
+
+ # reorder incremental state in decoder
+ reorder_state = active_bbsz_idx
+
+ # sort by score descending
+ for sent in range(len(finalized)):
+ finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
+
+ return finalized
+
+
+def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k):
+ with torch.no_grad():
+ lm_lprobs, avg_attn_scores = model.forward_decoder(
+ input_tokens, encoder_outs=None, incremental_states=incremental_states,
+ )
+
+ lm_lprobs_size = lm_lprobs.size(0)
+ probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1)
+
+ return probs_next_wrd
+
+
+def make_dict2dict(old_dict, new_dict):
+ dict2dict_map = {}
+ for sym in old_dict.symbols:
+ dict2dict_map[old_dict.index(sym)] = new_dict.index(sym)
+ return dict2dict_map
+
+
+def dict2dict(tokens, dict2dict_map):
+ if tokens.device == torch.device('cpu'):
+ tokens_tmp = tokens
+ else:
+ tokens_tmp = tokens.cpu()
+ return tokens_tmp.map_(
+ tokens_tmp,
+ lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)]
+ ).to(tokens.device)
+
+
+def reorder_tokens(tokens, lengths, eos):
+ # reorder source tokens so they may be used as reference for P(S|T)
+ return torch.cat((tokens.new([eos]), tokens[-lengths:-1], tokens[:-lengths]), 0)
+
+
+def reorder_all_tokens(tokens, lengths, eos):
+    # reorder left-padded src tokens from [<pad> ... <w1> <w2> ... <eos>] to [<eos> <w1> <w2> ... <pad> ...]
+    # so source tokens can be used to predict P(S|T)
+ return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)])
+
+
+def normalized_scores_with_batch_vocab(
+ model_decoder, features, target_ids, k, bsz, beam_size,
+ pad_idx, top_k=0, vocab_size_meter=None, start_idx=None,
+ end_idx=None, **kwargs):
+ """
+ Get normalized probabilities (or log probs) from a net's output
+ w.r.t. vocab consisting of target IDs in the batch
+ """
+ if model_decoder.adaptive_softmax is None:
+ weight = model_decoder.output_projection.weight
+ vocab_ids = torch.unique(
+ torch.cat(
+ (torch.unique(target_ids), torch.arange(top_k, device=target_ids.device))
+ )
+ )
+ id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids))))
+ mapped_target_ids = target_ids.cpu().apply_(
+ lambda x, id_map=id_map: id_map[x]
+ ).to(target_ids.device)
+ expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
+ if start_idx is not None and end_idx is not None:
+ expanded_target_ids = expanded_target_ids[start_idx:end_idx, :]
+ logits = F.linear(features, weight[vocab_ids, :])
+ log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+ intermed_scores = torch.gather(
+ log_softmax[:, :-1, :],
+ 2,
+ expanded_target_ids[:, 1:].unsqueeze(2),
+ ).squeeze()
+ not_padding = expanded_target_ids[:, 1:] != pad_idx
+ intermed_scores *= not_padding.float()
+ return intermed_scores
+ else:
+ raise ValueError("adaptive softmax doesn't work with " +
+ "`normalized_scores_with_batch_vocab()`")
diff --git a/examples/fast_noisy_channel/noisy_channel_translation.py b/examples/fast_noisy_channel/noisy_channel_translation.py
new file mode 100644
index 0000000000..b74bdfd456
--- /dev/null
+++ b/examples/fast_noisy_channel/noisy_channel_translation.py
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.tasks.translation import TranslationTask
+from fairseq.tasks.language_modeling import LanguageModelingTask
+from fairseq import checkpoint_utils
+import argparse
+from fairseq.tasks import register_task
+import torch
+
+
+@register_task("noisy_channel_translation")
+class NoisyChannelTranslation(TranslationTask):
+ """
+ Rescore the top k candidates from each beam using noisy channel modeling
+ """
+
+ @staticmethod
+ def add_args(parser):
+ """Add task-specific arguments to the parser."""
+ TranslationTask.add_args(parser)
+ # fmt: off
+ parser.add_argument('--channel-model', metavar='FILE',
+ help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.')
+ parser.add_argument('--combine-method', default='lm_only',
+ choices=['lm_only', 'noisy_channel'],
+ help="""method for combining direct and channel model scores.
+ lm_only: decode with P(T|S)P(T)
+ noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""")
+ parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False,
+ help='normalize lm score by target length instead of source length')
+ parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'],
+ help="Normalize bw scores with log softmax or return bw scores without log softmax")
+ parser.add_argument('--top-k-vocab', default=0, type=int,
+ help='top k vocab IDs to use with `src_vocab` in channel model scoring')
+ parser.add_argument('--k2', default=50, type=int,
+ help='the top k2 candidates to rescore with the noisy channel model for each beam')
+ parser.add_argument('--ch-wt', default=1, type=float,
+ help='weight for the channel model')
+ parser.add_argument('--lm-model', metavar='FILE',
+ help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side')
+ parser.add_argument('--lm-data', metavar='FILE',
+ help='path to lm model training data for target language, used to properly load LM with correct dictionary')
+ parser.add_argument('--lm-wt', default=1, type=float,
+ help='the weight of the lm in joint decoding')
+ # fmt: on
+
+ def build_generator(
+ self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
+ ):
+ if getattr(args, "score_reference", False):
+ raise NotImplementedError()
+ else:
+ from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator
+ use_cuda = torch.cuda.is_available() and not self.args.cpu
+ assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!'
+ assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs'
+ if self.args.channel_model is not None:
+ import copy
+ ch_args_task = copy.deepcopy(self.args)
+ tmp = ch_args_task.source_lang
+ ch_args_task.source_lang = ch_args_task.target_lang
+ ch_args_task.target_lang = tmp
+ ch_args_task._name = 'translation'
+ channel_task = TranslationTask.setup_task(ch_args_task)
+
+ arg_dict = {}
+ arg_dict['task'] = 'language_modeling'
+ arg_dict['sample_break_mode'] = 'eos'
+ arg_dict['data'] = self.args.lm_data
+ arg_dict['output_dictionary_size'] = -1
+ lm_args = argparse.Namespace(**arg_dict)
+ lm_task = LanguageModelingTask.setup_task(lm_args)
+ lm_dict = lm_task.output_dictionary
+
+ if self.args.channel_model is not None:
+ channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task)
+
+ for model in channel_models:
+ model.make_generation_fast_(
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+ need_attn=args.print_alignment,
+ )
+ if self.args.fp16:
+ model.half()
+ if use_cuda:
+ model.cuda()
+ else:
+ channel_models = None
+
+ lm_models, _ = checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task)
+
+ for model in lm_models:
+ model.make_generation_fast_(
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+ need_attn=args.print_alignment,
+ )
+ if self.args.fp16:
+ model.half()
+ if use_cuda:
+ model.cuda()
+ return NoisyChannelSequenceGenerator(
+ combine_method=self.args.combine_method,
+ tgt_dict=self.target_dictionary,
+ src_dict=self.source_dictionary,
+ beam_size=getattr(args, 'beam', 5),
+ max_len_a=getattr(args, 'max_len_a', 0),
+ max_len_b=getattr(args, 'max_len_b', 200),
+ min_len=getattr(args, 'min_len', 1),
+ len_penalty=getattr(args, 'lenpen', 1),
+ unk_penalty=getattr(args, 'unkpen', 0),
+ temperature=getattr(args, 'temperature', 1.),
+ match_source_len=getattr(args, 'match_source_len', False),
+ no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
+ normalize_scores=(not getattr(args, 'unnormalized', False)),
+ channel_models=channel_models,
+ k2=getattr(self.args, 'k2', 50),
+ ch_weight=getattr(self.args, 'ch_wt', 1),
+ channel_scoring_type=self.args.channel_scoring_type,
+ top_k_vocab=self.args.top_k_vocab,
+ lm_models=lm_models,
+ lm_dict=lm_dict,
+ lm_weight=getattr(self.args, 'lm_wt', 1),
+ normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False),
+ )
diff --git a/examples/flores101/README.md b/examples/flores101/README.md
new file mode 100644
index 0000000000..635c13f40b
--- /dev/null
+++ b/examples/flores101/README.md
@@ -0,0 +1,223 @@
+
+
+
+
+# Flores101: Large-Scale Multilingual Machine Translation
+
+## Introduction
+
+Baseline pretrained models for the small and large tracks of the WMT 21 Large-Scale Multilingual Machine Translation competition.
+
+Flores Task at WMT 21: http://www.statmt.org/wmt21/large-scale-multilingual-translation-task.html
+
+Flores announcement blog post: https://ai.facebook.com/blog/flores-researchers-kick-off-multilingual-translation-challenge-at-wmt-and-call-for-compute-grants/
+
+
+
+## Pretrained models
+
+Model | Num layers | Embed dimension | FFN dimension| Vocab Size | #params | Download
+---|---|---|---|---|---|---
+`flores101_mm100_615M` | 12 | 1024 | 4096 | 256,000 | 615M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
+`flores101_mm100_175M` | 6 | 512 | 2048 | 256,000 | 175M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz
+
+
+These models are trained similarly to [M2M-100](https://arxiv.org/abs/2010.11125), with additional support for the languages that are part of the WMT Large-Scale Multilingual Machine Translation track. The full list of languages can be found at the bottom of this page.
+
+
+## Example Generation code
+
+### Download model, sentencepiece vocab
+
+```bash
+fairseq=/path/to/fairseq
+cd $fairseq
+
+# Download 615M param model.
+wget https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
+
+# Extract
+tar -xvzf flores101_mm100_615M.tar.gz
+```
+
+### Encode using our SentencePiece Model
+Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)
+
+
+```bash
+fairseq=/path/to/fairseq
+cd $fairseq
+
+# Download example dataset from German to French
+sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
+sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
+
+for lang in de fr ; do
+ python scripts/spm_encode.py \
+ --model flores101_mm100_615M/sentencepiece.bpe.model \
+ --output_format=piece \
+ --inputs=raw_input.de-fr.${lang} \
+ --outputs=spm.de-fr.${lang}
+done
+```
+
+### Binarization
+
+```bash
+fairseq-preprocess \
+ --source-lang de --target-lang fr \
+ --testpref spm.de-fr \
+ --thresholdsrc 0 --thresholdtgt 0 \
+ --destdir data_bin \
+ --srcdict flores101_mm100_615M/dict.txt --tgtdict flores101_mm100_615M/dict.txt
+```
+
+### Generation
+
+
+```bash
+fairseq-generate \
+ data_bin \
+ --batch-size 1 \
+ --path flores101_mm100_615M/model.pt \
+ --fixed-dictionary flores101_mm100_615M/dict.txt \
+ -s de -t fr \
+ --remove-bpe 'sentencepiece' \
+ --beam 5 \
+ --task translation_multi_simple_epoch \
+ --lang-pairs flores101_mm100_615M/language_pairs.txt \
+ --decoder-langtok --encoder-langtok src \
+ --gen-subset test \
+ --fp16 \
+ --dataset-impl mmap \
+ --distributed-world-size 1 --distributed-no-spawn
+```
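+
+To sanity-check the translations, one option is to pull the `H-` (hypothesis) lines out of the `fairseq-generate` log and score them against the raw references with sacrebleu. The sketch below assumes the generation output was redirected to a hypothetical file `generate.out`; adapt the parsing and the scoring call to your fairseq/sacrebleu versions as needed.
+
+```python
+import sacrebleu
+
+def read_hyps(path):
+    """Collect H- lines from a fairseq-generate log, ordered by sample id."""
+    hyps = {}
+    with open(path) as f:
+        for line in f:
+            if line.startswith("H-"):
+                # Each hypothesis line looks like: H-<id>\t<score>\t<detokenized text>
+                key, _score, text = line.rstrip("\n").split("\t", 2)
+                hyps[int(key[len("H-"):])] = text
+    return [hyps[i] for i in sorted(hyps)]
+
+with open("raw_input.de-fr.fr") as f:
+    refs = [line.rstrip("\n") for line in f]
+
+hyps = read_hyps("generate.out")  # fairseq-generate output redirected to this file
+print(sacrebleu.corpus_bleu(hyps, [refs]))
+```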
+
+### Supported languages and language codes
+
+Language | lang code
+---|---
+Afrikaans | af
+Amharic | am
+Arabic | ar
+Assamese | as
+Asturian | ast
+Aymara | ay
+Azerbaijani | az
+Bashkir | ba
+Belarusian | be
+Bulgarian | bg
+Bengali | bn
+Breton | br
+Bosnian | bs
+Catalan | ca
+Cebuano | ceb
+Chokwe | cjk
+Czech | cs
+Welsh | cy
+Danish | da
+German | de
+Dyula | dyu
+Greek | el
+English | en
+Spanish | es
+Estonian | et
+Persian | fa
+Fulah | ff
+Finnish | fi
+French | fr
+Western Frisian | fy
+Irish | ga
+Scottish Gaelic | gd
+Galician | gl
+Gujarati | gu
+Hausa | ha
+Hebrew | he
+Hindi | hi
+Croatian | hr
+Haitian Creole | ht
+Hungarian | hu
+Armenian | hy
+Indonesian | id
+Igbo | ig
+Iloko | ilo
+Icelandic | is
+Italian | it
+Japanese | ja
+Javanese | jv
+Georgian | ka
+Kachin | kac
+Kamba | kam
+Kabuverdianu | kea
+Kongo | kg
+Kazakh | kk
+Central Khmer | km
+Kimbundu | kmb
+Northern Kurdish | kmr
+Kannada | kn
+Korean | ko
+Kurdish | ku
+Kyrgyz | ky
+Luxembourgish | lb
+Ganda | lg
+Lingala | ln
+Lao | lo
+Lithuanian | lt
+Luo | luo
+Latvian | lv
+Malagasy | mg
+Maori | mi
+Macedonian | mk
+Malayalam | ml
+Mongolian | mn
+Marathi | mr
+Malay | ms
+Maltese | mt
+Burmese | my
+Nepali | ne
+Dutch | nl
+Norwegian | no
+Northern Sotho | ns
+Nyanja | ny
+Occitan | oc
+Oromo | om
+Oriya | or
+Punjabi | pa
+Polish | pl
+Pashto | ps
+Portuguese | pt
+Quechua | qu
+Romanian | ro
+Russian | ru
+Sindhi | sd
+Shan | shn
+Sinhala | si
+Slovak | sk
+Slovenian | sl
+Shona | sn
+Somali | so
+Albanian | sq
+Serbian | sr
+Swati | ss
+Sundanese | su
+Swedish | sv
+Swahili | sw
+Tamil | ta
+Telugu | te
+Tajik | tg
+Thai | th
+Tigrinya | ti
+Tagalog | tl
+Tswana | tn
+Turkish | tr
+Ukrainian | uk
+Umbundu | umb
+Urdu | ur
+Uzbek | uz
+Vietnamese | vi
+Wolof | wo
+Xhosa | xh
+Yiddish | yi
+Yoruba | yo
+Chinese | zh
+Zulu | zu
diff --git a/examples/flores101/flores_logo.png b/examples/flores101/flores_logo.png
new file mode 100644
index 0000000000..d4d1455c6e
Binary files /dev/null and b/examples/flores101/flores_logo.png differ
diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md
new file mode 100644
index 0000000000..b9e44fef48
--- /dev/null
+++ b/examples/fully_sharded_data_parallel/README.md
@@ -0,0 +1,177 @@
+# Fully Sharded Data Parallel (FSDP)
+
+## Overview
+Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and
+[Google](https://arxiv.org/abs/2004.13336) has shown that data parallel
+training can be made significantly more efficient by sharding the model
+parameters and optimizer state across data parallel workers. These ideas are
+encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided
+by [fairscale](https://github.com/facebookresearch/fairscale/).
+
+Compared to PyTorch DDP:
+* FSDP produces results identical to PyTorch DDP (it's still synchronous data parallel training)
+* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
+* FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass
+* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs
+
+FSDP is fully supported in fairseq via the following new arguments:
+* `--ddp-backend=fully_sharded`: enables full sharding via FSDP
+* `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`)
+* `--no-reshard-after-forward`: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2
+* other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal
+
+### Limitations
+
+FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP):
+* while FSDP is fully compatible with pointwise optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB, etc.)
+* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported
+
+See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
+explanation of these and other limitations.
+
+
+
+### How it works
+
+
+
+See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
+explanation of how FSDP works.
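+
+For intuition only, here is a minimal sketch of what wrapping a plain PyTorch module in fairscale's `FullyShardedDataParallel` looks like; it is not how fairseq integrates FSDP internally, and it assumes a GPU machine plus a distributed launcher such as `torchrun` to set up the process group:
+
+```python
+import torch
+import torch.distributed as dist
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+
+# Assumes the launcher (e.g. torchrun) has set RANK/WORLD_SIZE/MASTER_ADDR env vars.
+dist.init_process_group(backend="nccl")
+torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+
+model = FSDP(torch.nn.Linear(1024, 1024).cuda())  # parameters are flattened and sharded
+optim = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+x = torch.randn(8, 1024, device="cuda")
+loss = model(x).sum()
+loss.backward()   # gradients are reduced so each rank keeps only its own shard
+optim.step()      # each rank updates only the parameters it owns
+optim.zero_grad()
+```
+
+Launched with, e.g., `torchrun --nproc_per_node=8 fsdp_sketch.py` (script name hypothetical), each rank holds only a shard of the flattened parameters and optimizer state, which is the memory saving the fairseq flags above expose.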
+
+
+
+## Example usage
+
+The following examples illustrate how to train a very large language model with
+13 billion parameters on 1 GPU by offloading parameters and optimizer states to
+CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs.
+
+These examples use the WikiText-103 dataset for demonstration purposes, but
+in practice a much larger dataset will be needed to achieve good results.
+Follow the [instructions here](https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.pretraining.md#1-preprocess-the-data)
+to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary.
+
+### 13B params on 1 V100 GPU (with CPU offloading)
+
+The following command trains a 13B parameter GPT-3 model on a single V100 GPU
+using the `--cpu-offload` feature to offload parameters and optimizer states to
+CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the
+`--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)),
+which further saves memory in exchange for a small increase in computation.
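+
+For intuition, activation (gradient) checkpointing in plain PyTorch looks like the toy sketch below; `torch.utils.checkpoint` is the generic PyTorch utility, and fairseq's `--checkpoint-activations` applies the same idea to its transformer layers:
+
+```python
+import torch
+from torch.utils.checkpoint import checkpoint
+
+# A toy block whose intermediate activations we choose not to store.
+block = torch.nn.Sequential(
+    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
+)
+x = torch.randn(8, 64, requires_grad=True)
+
+# Forward: run the block without caching its internal activations.
+# Backward: re-run the block to recompute them, trading compute for memory.
+y = checkpoint(block, x)
+y.sum().backward()
+```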
+
+**Requirements:**
+- Install the latest master version of fairscale: `pip install git+https://github.com/facebookresearch/fairscale.git@master`
+- You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model.
+- If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory; just set `--arch transformer_lm_gpt3_6_7`.
+- We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command.
+
+**Notes:**
+- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow.
+- The `--cpu-offload` feature requires training in mixed precision (`--fp16`).
+- Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading.
+- The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`).
+
+```bash
+OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \
+ fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
+ --cpu-offload --checkpoint-activations \
+ --task language_modeling --tokens-per-sample 2048 --batch-size 8 \
+ --arch transformer_lm_gpt3_13 \
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
+ --max-update 10 --no-save --log-format json --log-interval 1
+```
+
+
+### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding)
+
+FSDP can also shard the parameters and optimizer states across multiple GPUs,
+reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables
+training the same 13B parameter model *without offloading the parameters to
+CPU*. However, without CPU offloading we'd only be able to fit a batch size of
+1 per GPU, which would cause training speed to suffer.
+
+We obtain the best performance on 8 GPUs by combining full sharding and CPU
+offloading. The following command trains the same 13B parameter GPT-3 model as
+before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310
+words per second to ~3200 words per second.
+
+```bash
+OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
+ --cpu-offload --checkpoint-activations \
+ --task language_modeling --tokens-per-sample 2048 --batch-size 8 \
+ --arch transformer_lm_gpt3_13 \
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
+ --max-update 10 --no-save --log-format json --log-interval 1
+```
+
+Example output