Skip to content

Commit

Permalink
minor fixes (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr authored Oct 27, 2023
1 parent 800bf4b commit ea78b32
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions egs/tedlium3/ASR/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
average_checkpoints,
Expand Down Expand Up @@ -695,7 +695,7 @@ def main():
logging.info(params)

logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)

if not params.use_averaged_model:
if params.iter > 0:
Expand Down
4 changes: 2 additions & 2 deletions egs/tedlium3/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner


def get_transducer_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
Expand Down Expand Up @@ -1083,7 +1083,7 @@ def run(rank, world_size, args):
logging.info(params)

logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
Expand Down

0 comments on commit ea78b32

Please sign in to comment.