Skip to content

Commit

Permalink
Strengthened style constraints (#1527)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr authored Mar 4, 2024
1 parent 29b195a commit 242002e
Show file tree
Hide file tree
Showing 70 changed files with 166 additions and 164 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:

- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
# Click issue fixed in https://github.com/psf/black/pull/2966
- name: Run flake8
Expand All @@ -67,3 +67,9 @@ jobs:
working-directory: ${{github.workspace}}
run: |
black --check --diff .
- name: Run isort
shell: bash
working-directory: ${{github.workspace}}
run: |
isort --check --diff .
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
# E121,E123,E126,E226,E24,E704,W503,W504

- repo: https://github.com/pycqa/isort
rev: 5.11.5
rev: 5.10.1
hooks:
- id: isort
args: ["--profile=black"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@
import argparse
import logging

from icefall import is_module_available
import torch
from onnx_pretrained import OnnxModel

import torch
from icefall import is_module_available


def get_parser():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@
from pathlib import Path

import torch
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
average_checkpoints,
Expand Down
1 change: 1 addition & 0 deletions egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import str2bool

# Similar text filtering and normalization procedure as in:
Expand Down
1 change: 1 addition & 0 deletions egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model

from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
Expand Down
2 changes: 1 addition & 1 deletion egs/gigaspeech/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
average_checkpoints,
Expand Down
2 changes: 1 addition & 1 deletion egs/gigaspeech/ASR/zipformer/streaming_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
average_checkpoints,
Expand Down
6 changes: 2 additions & 4 deletions egs/gigaspeech/KWS/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,10 @@
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import (
keywords_search,
)
from beam_search import keywords_search
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params

from lhotse.cut import Cut
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
Expand Down
29 changes: 14 additions & 15 deletions egs/gigaspeech/KWS/zipformer/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train import (
add_model_arguments,
add_training_arguments,
compute_loss,
compute_validation_loss,
display_and_save_batch,
get_adjusted_batch_count,
get_model,
get_params,
load_checkpoint_if_available,
save_checkpoint,
scan_pessimistic_batches_for_oom,
set_batch_count,
)

from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
Expand All @@ -95,21 +109,6 @@
str2bool,
)

from train import (
add_model_arguments,
add_training_arguments,
compute_loss,
compute_validation_loss,
display_and_save_batch,
get_adjusted_batch_count,
get_model,
get_params,
load_checkpoint_if_available,
save_checkpoint,
scan_pessimistic_batches_for_oom,
set_batch_count,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


Expand Down
3 changes: 1 addition & 2 deletions egs/librispeech/ASR/conformer_ctc3/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
"""

import torch

from train import get_params, get_ctc_model
from train import get_ctc_model, get_params


def test_model():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@
import torch
import torch.nn as nn
from decoder import Decoder
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
average_checkpoints,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@
import argparse
import logging
import math
from typing import List
from typing import List, Optional

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence
from typing import Optional, List


def get_parser():
Expand Down
14 changes: 7 additions & 7 deletions egs/librispeech/ASR/long_file_recog/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,28 @@
"""

import argparse
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Tuple

from pathlib import Path
from typing import List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.supervision import AlignmentItem
from lhotse.serialization import SequentialJsonlWriter
from lhotse.supervision import AlignmentItem

from icefall.utils import AttributeDict, convert_timestamp, setup_logger


def get_parser():
Expand Down
3 changes: 1 addition & 2 deletions egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,11 @@
import argparse
import logging

import torch
from onnx_pretrained import OnnxModel

from icefall import is_module_available

import torch


def get_parser():
parser = argparse.ArgumentParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

import argparse
import logging

import sentencepiece as spm
import torch
from train import add_model_arguments, get_encoder_model, get_params

from icefall.profiler import get_model_profile
from train import get_encoder_model, add_model_arguments, get_params


def get_parser():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule

from onnx_pretrained import greedy_search, OnnxModel
from onnx_pretrained import OnnxModel, greedy_search

from icefall.utils import setup_logger, store_transcripts, write_error_stats

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@
import argparse
import logging

from icefall import is_module_available
import torch
from onnx_pretrained import OnnxModel

import torch
from icefall import is_module_available


def get_parser():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@
import torch.nn as nn
from asr_datamodule import AsrDataModule
from librispeech import LibriSpeech

from onnx_pretrained import greedy_search, OnnxModel
from onnx_pretrained import OnnxModel, greedy_search

from icefall.utils import setup_logger, store_transcripts, write_error_stats

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@

import argparse
import logging
from typing import Tuple

import sentencepiece as spm
import torch

from typing import Tuple
from scaling import BasicNorm, DoubleSwish
from torch import Tensor, nn
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params

from icefall.profiler import get_model_profile
from scaling import BasicNorm, DoubleSwish
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params


def get_parser():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule

from onnx_pretrained import greedy_search, OnnxModel
from onnx_pretrained import OnnxModel, greedy_search

from icefall.utils import setup_logger, store_transcripts, write_error_stats

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import k2
import torch

from beam_search import Hypothesis, HypothesisList, get_hyps_shape

# The force alignment problem can be formulated as finding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@
import sentencepiece as spm
import torch
import torch.nn as nn

# from asr_datamodule import LibriSpeechAsrDataModule
from gigaspeech import GigaSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
Expand All @@ -120,6 +117,9 @@
greedy_search_batch,
modified_beam_search,
)

# from asr_datamodule import LibriSpeechAsrDataModule
from gigaspeech import GigaSpeechAsrDataModule
from gigaspeech_scoring import asr_text_post_processing
from train import add_model_arguments, get_params, get_transducer_model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,15 @@

import sentencepiece as spm
import torch

from train import add_model_arguments, get_params, get_transducer_model

from icefall.utils import str2bool
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool


def get_parser():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@

import argparse
import logging
from typing import Tuple

import sentencepiece as spm
import torch

from typing import Tuple
from scaling import BasicNorm, DoubleSwish
from torch import Tensor, nn
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params

from icefall.profiler import get_model_profile
from scaling import BasicNorm, DoubleSwish
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params


def get_parser():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule

from onnx_pretrained import greedy_search, OnnxModel
from onnx_pretrained import OnnxModel, greedy_search

from icefall.utils import setup_logger, store_transcripts, write_error_stats

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"""

import torch

from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

Expand Down
Loading

0 comments on commit 242002e

Please sign in to comment.