From 62ea40a1b74b379e066f434a804600df35619195 Mon Sep 17 00:00:00 2001 From: "1210212670@qq.com" Date: Tue, 11 May 2021 23:22:46 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=B8=8D=E5=90=8C=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E8=B7=AF=E5=BE=84=E9=80=82=E9=85=8D=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- actuator.py | 10 ++-- dialogue/tensorflow/seq2seq/actuator.py | 18 +++---- dialogue/tensorflow/smn/actuator.py | 12 ++--- dialogue/tensorflow/transformer/actuator.py | 18 +++---- requirements.txt | 60 ++++++++++++++++++--- 5 files changed, 82 insertions(+), 36 deletions(-) diff --git a/actuator.py b/actuator.py index d657a20..f120657 100644 --- a/actuator.py +++ b/actuator.py @@ -20,7 +20,7 @@ import sys from argparse import ArgumentParser -from dialogue.pytorch.seq2seq.actuator import torch_seq2seq +# from dialogue.pytorch.seq2seq.actuator import torch_seq2seq from dialogue.tensorflow.seq2seq.actuator import tf_seq2seq from dialogue.tensorflow.smn.actuator import tf_smn from dialogue.tensorflow.transformer.actuator import tf_transformer @@ -37,10 +37,10 @@ def main() -> None: "seq2seq": lambda: tf_seq2seq(), "smn": lambda: tf_smn(), }, - "torch": { - "transformer": lambda: None, - "seq2seq": lambda: torch_seq2seq(), - } + # "torch": { + # "transformer": lambda: None, + # "seq2seq": lambda: torch_seq2seq(), + # } } options = parser.parse_args(sys.argv[1:5]) diff --git a/dialogue/tensorflow/seq2seq/actuator.py b/dialogue/tensorflow/seq2seq/actuator.py index dad8055..8608478 100644 --- a/dialogue/tensorflow/seq2seq/actuator.py +++ b/dialogue/tensorflow/seq2seq/actuator.py @@ -48,18 +48,18 @@ def tf_seq2seq() -> NoReturn: parser.add_argument("--max_train_data_size", default=0, type=int, required=False, help="用于训练的最大数据大小") parser.add_argument("--max_valid_data_size", default=0, type=int, required=False, help="用于验证的最大数据大小") parser.add_argument("--max_sentence", default=40, type=int, required=False, help="单个序列的最大长度") - parser.add_argument("--dict_path", default="data\\preprocess\\seq2seq_dict.json", + parser.add_argument("--dict_path", default="data/preprocess/seq2seq_dict.json", type=str, required=False, help="字典路径") - parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\seq2seq", + parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/seq2seq", type=str, required=False, help="检查点路径") - parser.add_argument("--resource_data_path", default="data\\LCCC.json", type=str, required=False, help="原始数据集路径") - parser.add_argument("--tokenized_data_path", default="data\\preprocess\\lccc_tokenized.txt", + parser.add_argument("--resource_data_path", default="data/LCCC.json", type=str, required=False, help="原始数据集路径") + parser.add_argument("--tokenized_data_path", default="data/preprocess/lccc_tokenized.txt", type=str, required=False, help="处理好的多轮分词数据集路径") - parser.add_argument("--preprocess_data_path", default="data\\preprocess\\single_tokenized.txt", + parser.add_argument("--preprocess_data_path", default="data/preprocess/single_tokenized.txt", type=str, required=False, help="处理好的单轮分词数据集路径") - parser.add_argument("--valid_data_path", default="data\\preprocess\\single_tokenized.txt", type=str, + parser.add_argument("--valid_data_path", default="data/preprocess/single_tokenized.txt", type=str, required=False, help="处理好的单轮分词验证评估用数据集路径") - parser.add_argument("--history_image_dir", default="data\\history\\seq2seq\\", type=str, required=False, + parser.add_argument("--history_image_dir", default="data/history/seq2seq/", type=str, required=False, help="数据指标图表保存路径") parser.add_argument("--valid_freq", default=5, type=int, required=False, help="验证频率") parser.add_argument("--checkpoint_save_freq", default=2, type=int, required=False, help="检查点保存频率") @@ -72,9 +72,9 @@ def tf_seq2seq() -> NoReturn: parser.add_argument("--start_sign", default="", type=str, required=False, help="序列开始标记") parser.add_argument("--end_sign", default="", type=str, required=False, help="序列结束标记") parser.add_argument("--unk_sign", default="", type=str, required=False, help="未登录词") - parser.add_argument("--encoder_save_path", default="models\\tensorflow\\seq2seq\\encoder", type=str, + parser.add_argument("--encoder_save_path", default="models/tensorflow/seq2seq/encoder", type=str, required=False, help="Encoder的SaveModel格式保存路径") - parser.add_argument("--decoder_save_path", default="models\\tensorflow\\seq2seq\\decoder", type=str, + parser.add_argument("--decoder_save_path", default="models/tensorflow/seq2seq/decoder", type=str, required=False, help="Decoder的SaveModel格式保存路径") options = parser.parse_args().__dict__ diff --git a/dialogue/tensorflow/smn/actuator.py b/dialogue/tensorflow/smn/actuator.py index 0642742..063cdfa 100644 --- a/dialogue/tensorflow/smn/actuator.py +++ b/dialogue/tensorflow/smn/actuator.py @@ -48,16 +48,16 @@ def tf_smn() -> NoReturn: parser.add_argument("--valid_data_split", default=0.0, type=float, required=False, help="从训练数据集中划分验证数据的比例") parser.add_argument("--learning_rate", default=0.001, type=float, required=False, help="学习率") parser.add_argument("--max_database_size", default=0, type=int, required=False, help="最大数据候选数量") - parser.add_argument("--dict_path", default="data\\preprocess\\smn_dict.json", type=str, required=False, help="字典路径") - parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\smn", type=str, required=False, + parser.add_argument("--dict_path", default="data/preprocess/smn_dict.json", type=str, required=False, help="字典路径") + parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/smn", type=str, required=False, help="检查点路径") - parser.add_argument("--train_data_path", default="data\\ubuntu_train.txt", type=str, required=False, + parser.add_argument("--train_data_path", default="data/ubuntu_train.txt", type=str, required=False, help="处理好的多轮分词训练数据集路径") - parser.add_argument("--valid_data_path", default="data\\ubuntu_valid.txt", type=str, required=False, + parser.add_argument("--valid_data_path", default="data/ubuntu_valid.txt", type=str, required=False, help="处理好的多轮分词验证数据集路径") parser.add_argument("--solr_server", default="http://49.235.33.100:8983/solr/smn/", type=str, required=False, help="solr服务地址") - parser.add_argument("--candidate_database", default="data\\preprocess\\candidate.json", type=str, required=False, + parser.add_argument("--candidate_database", default="data/preprocess/candidate.json", type=str, required=False, help="候选回复数据库") parser.add_argument("--epochs", default=5, type=int, required=False, help="训练步数") parser.add_argument("--batch_size", default=64, type=int, required=False, help="batch大小") @@ -65,7 +65,7 @@ def tf_smn() -> NoReturn: parser.add_argument("--start_sign", default="", type=str, required=False, help="序列开始标记") parser.add_argument("--end_sign", default="", type=str, required=False, help="序列结束标记") parser.add_argument("--unk_sign", default="", type=str, required=False, help="未登录词") - parser.add_argument("--model_save_path", default="models\\tensorflow\\smn", type=str, + parser.add_argument("--model_save_path", default="models/tensorflow/smn", type=str, required=False, help="SaveModel格式保存路径") options = parser.parse_args().__dict__ diff --git a/dialogue/tensorflow/transformer/actuator.py b/dialogue/tensorflow/transformer/actuator.py index fd4e9f4..1194d88 100644 --- a/dialogue/tensorflow/transformer/actuator.py +++ b/dialogue/tensorflow/transformer/actuator.py @@ -63,22 +63,22 @@ def tf_transformer() -> NoReturn: parser.add_argument("--start_sign", default="", type=str, required=False, help="序列开始标记") parser.add_argument("--end_sign", default="", type=str, required=False, help="序列结束标记") parser.add_argument("--unk_sign", default="", type=str, required=False, help="未登录词") - parser.add_argument("--dict_path", default="data\\preprocess\\transformer_dict.json", type=str, required=False, + parser.add_argument("--dict_path", default="data/preprocess/transformer_dict.json", type=str, required=False, help="字典路径") - parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\transformer", type=str, required=False, + parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/transformer", type=str, required=False, help="检查点路径") - parser.add_argument("--raw_data_path", default="data\\LCCC.json", type=str, required=False, help="原始数据集路径") - parser.add_argument("--tokenized_data_path", default="data\\preprocess\\lccc_tokenized.txt", type=str, + parser.add_argument("--raw_data_path", default="data/LCCC.json", type=str, required=False, help="原始数据集路径") + parser.add_argument("--tokenized_data_path", default="data/preprocess/lccc_tokenized.txt", type=str, required=False, help="处理好的多轮分词数据集路径") - parser.add_argument("--preprocess_data_path", default="data\\preprocess\\single_tokenized.txt", type=str, + parser.add_argument("--preprocess_data_path", default="data/preprocess/single_tokenized.txt", type=str, required=False, help="处理好的单轮分词训练用数据集路径") - parser.add_argument("--valid_data_path", default="data\\preprocess\\single_tokenized.txt", type=str, + parser.add_argument("--valid_data_path", default="data/preprocess/single_tokenized.txt", type=str, required=False, help="处理好的单轮分词验证评估用数据集路径") - parser.add_argument("--history_image_dir", default="data\\history\\transformer\\", type=str, required=False, + parser.add_argument("--history_image_dir", default="data/history/transformer/", type=str, required=False, help="数据指标图表保存路径") - parser.add_argument("--encoder_save_path", default="models\\tensorflow\\transformer\\encoder", type=str, + parser.add_argument("--encoder_save_path", default="models/tensorflow/transformer/encoder", type=str, required=False, help="Encoder的SaveModel格式保存路径") - parser.add_argument("--decoder_save_path", default="models\\tensorflow\\transformer\\decoder", type=str, + parser.add_argument("--decoder_save_path", default="models/tensorflow/transformer/decoder", type=str, required=False, help="Decoder的SaveModel格式保存路径") options = parser.parse_args().__dict__ diff --git a/requirements.txt b/requirements.txt index eba5262..536cef2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,55 @@ -tensorflow==2.3.1 -pytorch==1.7.1 +absl-py==0.12.0 +astunparse==1.6.3 +cachetools==4.2.2 +certifi==2020.12.5 +chardet==4.0.0 +click==7.1.2 +cycler==0.10.0 +Flask==1.1.2 +flatbuffers==1.12 +gast==0.3.3 +google-auth==1.30.0 +google-auth-oauthlib==0.4.4 +google-pasta==0.2.0 +grpcio==1.32.0 +h5py==2.10.0 +idna==2.10 +importlib-metadata==4.0.1 +inflect==5.3.0 +itsdangerous==1.1.0 jieba==0.42.1 -inflect==5.0.2 -numpy==1.19.0 +Jinja2==2.11.3 +joblib==1.0.1 +Keras-Preprocessing==1.1.2 +kiwisolver==1.3.1 +Markdown==3.3.4 +MarkupSafe==1.1.1 +matplotlib==3.4.2 +numpy==1.19.5 +oauthlib==3.1.0 +opt-einsum==3.3.0 +Pillow==8.2.0 +protobuf==3.16.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pyparsing==2.4.7 pysolr==3.9.0 -flask==1.1.2 -scikit-learn==0.23.2 -flask-cors==3.0.9 \ No newline at end of file +python-dateutil==2.8.1 +requests==2.25.1 +requests-oauthlib==1.3.0 +rsa==4.7.2 +scikit-learn==0.24.2 +scipy==1.6.3 +six==1.15.0 +tensorboard==2.5.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorflow==2.4.1 +tensorflow-estimator==2.4.0 +termcolor==1.1.0 +threadpoolctl==2.1.0 +typing-extensions==3.7.4.3 +urllib3==1.26.4 +Werkzeug==1.0.1 +wrapt==1.12.1 +zipp==3.4.1