nmt

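PyTorch implementation of neural machine translation with RNNs and an attention mechanism (Chinese to English only). Translation quality depends on the quality of the training set.

If you have your own training set, or need translation for a specialized domain, put your prepared training, validation, and test sets in the zh_en_data directory. All of them are txt files with one sentence per line.

[image: model]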
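Before training on your own data, it can help to confirm that the source and target files line up. The following is a minimal sketch, assuming the two files are parallel (line N of one is the translation of line N of the other); the helper name `check_parallel` is hypothetical and not part of this repository.

```python
def _read_lines(path):
    """Read a UTF-8 text file and return its lines without trailing newlines."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def check_parallel(src_path, tgt_path):
    """Return (line_count, empty_line_numbers) for a pair of aligned text files.

    Raises ValueError if the two files have different numbers of lines.
    """
    src_lines = _read_lines(src_path)
    tgt_lines = _read_lines(tgt_path)
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"line count mismatch: {len(src_lines)} source vs {len(tgt_lines)} target"
        )
    # 1-based line numbers where either side is blank
    empty = [
        i
        for i, (s, t) in enumerate(zip(src_lines, tgt_lines), start=1)
        if not s.strip() or not t.strip()
    ]
    return len(src_lines), empty
```

For example, `check_parallel("zh_en_data/train.zh", "zh_en_data/train.en")` returns the number of sentence pairs and the line numbers of any blank entries.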

Steps to run

Clone this repository

git clone https://github.com/showsunny/nmt.git

Install the dependencies

pip install -r requirements.txt

Build your own vocabulary (only needed if you use your own training set)

python vocab.py --train-src=zh_en_data/train.zh --train-tgt=zh_en_data/train.en vocab.json

Train

source run.sh train

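Loss curve. When training locally, the first line can be omitted. If the plot does not display, omit the first line, download the 0 file generated under the runs path, change the path in the code to the folder that contains the 0 file (the folder only, without the file name), and run the command in your local conda environment.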

load_ext tensorboard
tensorboard --logdir runs/nmt
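If it is unclear which folder to pass to `--logdir`, a small script can list the candidates. This is a minimal sketch, assuming the logs are standard TensorBoard event files whose names contain `tfevents` (the 0 file mentioned above is presumably one of these); the helper name `find_event_dirs` is hypothetical.

```python
from pathlib import Path


def find_event_dirs(root="runs"):
    """Return the sorted list of folders under `root` that contain event files."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    # Collect the parent folder of every file that looks like an event file
    dirs = {p.parent for p in root_path.rglob("*tfevents*") if p.is_file()}
    return sorted(str(d) for d in dirs)


if __name__ == "__main__":
    for d in find_event_dirs("runs"):
        # Pass this folder (not the file itself) to: tensorboard --logdir <folder>
        print(d)
```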

[image: loss]

Test on the test set

source run.sh test

Translate a single sentence

from nmt_model import NMT

import torch
import jieba

def process_jieba(text):
    words = list(jieba.cut(text))  # jieba.cut returns a generator; convert it to a list
    return words

def detokenize(tokens):
    """ Detokenize a list of tokens into a string.
    @param tokens (list[str]): List of tokens
    @returns sentence (str): Detokenized sentence
    """
    return ''.join(tokens).replace('▁', ' ').strip()

def translate_sentence(model, src_sentence):
    """ Translate a single source sentence to target language.
    @param model (NMT): Trained NMT model
    @param src_sentence (str): Source sentence
    @returns translation (str): Translated sentence
    """
    # Tokenize the source sentence
    src_tokens = process_jieba(src_sentence)
    # Perform translation
    with torch.no_grad():
        translation_hypotheses = model.beam_search(src_tokens, beam_size=5, max_decoding_time_step=70)  # Adjust beam size and max decoding time step accordingly
        # Assuming the best hypothesis is the first one
        best_translation = translation_hypotheses[0][0]
    return best_translation

def main():
    # Load the trained model
    model = NMT.load("model.bin")
    # Set the model to evaluation mode
    model.eval()

    # Example source sentence to translate
    src_sentence = "几乎已经没有地方容纳这些人, 资源已经用尽。"

    # Translate the sentence
    translation = translate_sentence(model, src_sentence)

    print("Source Sentence:", src_sentence)
    print("Translation:", detokenize(translation))

if __name__ == '__main__':
    main()
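To translate several sentences, the same call can be wrapped in a loop. The sketch below assumes only what the script above already relies on: a model whose `beam_search(tokens, beam_size=..., max_decoding_time_step=...)` returns hypotheses with the token list at index 0. The function name `translate_many` and its `tokenize` parameter are hypothetical, introduced here for illustration.

```python
def detokenize(tokens):
    """Join tokens into a string, turning the '▁' marker into spaces."""
    return "".join(tokens).replace("▁", " ").strip()


def translate_many(model, sentences, tokenize, beam_size=5, max_decoding_time_step=70):
    """Translate each sentence and return detokenized strings in input order.

    `tokenize` is any callable mapping a string to a list of tokens
    (for example, the process_jieba function defined above).
    """
    results = []
    for sentence in sentences:
        tokens = tokenize(sentence)
        if not tokens:
            # Nothing to translate; keep the output aligned with the input
            results.append("")
            continue
        hypotheses = model.beam_search(
            tokens,
            beam_size=beam_size,
            max_decoding_time_step=max_decoding_time_step,
        )
        # As in the script above, take the first hypothesis as the best one
        results.append(detokenize(hypotheses[0][0]))
    return results
```

With a loaded model this could be called as `translate_many(model, sentences, process_jieba)`, ideally inside a `with torch.no_grad():` block as in `translate_sentence`.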