Skip to content

Commit

Permalink
更新不同系统路径适配问题 (Fix path handling for cross-platform compatibility: replace Windows-style `\\` path separators with `/` in default paths)
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] authored and [email protected] committed May 11, 2021
1 parent a921f8b commit 62ea40a
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 36 deletions.
10 changes: 5 additions & 5 deletions actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import sys
from argparse import ArgumentParser
from dialogue.pytorch.seq2seq.actuator import torch_seq2seq
# from dialogue.pytorch.seq2seq.actuator import torch_seq2seq
from dialogue.tensorflow.seq2seq.actuator import tf_seq2seq
from dialogue.tensorflow.smn.actuator import tf_smn
from dialogue.tensorflow.transformer.actuator import tf_transformer
Expand All @@ -37,10 +37,10 @@ def main() -> None:
"seq2seq": lambda: tf_seq2seq(),
"smn": lambda: tf_smn(),
},
"torch": {
"transformer": lambda: None,
"seq2seq": lambda: torch_seq2seq(),
}
# "torch": {
# "transformer": lambda: None,
# "seq2seq": lambda: torch_seq2seq(),
# }
}

options = parser.parse_args(sys.argv[1:5])
Expand Down
18 changes: 9 additions & 9 deletions dialogue/tensorflow/seq2seq/actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,18 @@ def tf_seq2seq() -> NoReturn:
parser.add_argument("--max_train_data_size", default=0, type=int, required=False, help="用于训练的最大数据大小")
parser.add_argument("--max_valid_data_size", default=0, type=int, required=False, help="用于验证的最大数据大小")
parser.add_argument("--max_sentence", default=40, type=int, required=False, help="单个序列的最大长度")
parser.add_argument("--dict_path", default="data\\preprocess\\seq2seq_dict.json",
parser.add_argument("--dict_path", default="data/preprocess/seq2seq_dict.json",
type=str, required=False, help="字典路径")
parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\seq2seq",
parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/seq2seq",
type=str, required=False, help="检查点路径")
parser.add_argument("--resource_data_path", default="data\\LCCC.json", type=str, required=False, help="原始数据集路径")
parser.add_argument("--tokenized_data_path", default="data\\preprocess\\lccc_tokenized.txt",
parser.add_argument("--resource_data_path", default="data/LCCC.json", type=str, required=False, help="原始数据集路径")
parser.add_argument("--tokenized_data_path", default="data/preprocess/lccc_tokenized.txt",
type=str, required=False, help="处理好的多轮分词数据集路径")
parser.add_argument("--preprocess_data_path", default="data\\preprocess\\single_tokenized.txt",
parser.add_argument("--preprocess_data_path", default="data/preprocess/single_tokenized.txt",
type=str, required=False, help="处理好的单轮分词数据集路径")
parser.add_argument("--valid_data_path", default="data\\preprocess\\single_tokenized.txt", type=str,
parser.add_argument("--valid_data_path", default="data/preprocess/single_tokenized.txt", type=str,
required=False, help="处理好的单轮分词验证评估用数据集路径")
parser.add_argument("--history_image_dir", default="data\\history\\seq2seq\\", type=str, required=False,
parser.add_argument("--history_image_dir", default="data/history/seq2seq/", type=str, required=False,
help="数据指标图表保存路径")
parser.add_argument("--valid_freq", default=5, type=int, required=False, help="验证频率")
parser.add_argument("--checkpoint_save_freq", default=2, type=int, required=False, help="检查点保存频率")
Expand All @@ -72,9 +72,9 @@ def tf_seq2seq() -> NoReturn:
parser.add_argument("--start_sign", default="<start>", type=str, required=False, help="序列开始标记")
parser.add_argument("--end_sign", default="<end>", type=str, required=False, help="序列结束标记")
parser.add_argument("--unk_sign", default="<unk>", type=str, required=False, help="未登录词")
parser.add_argument("--encoder_save_path", default="models\\tensorflow\\seq2seq\\encoder", type=str,
parser.add_argument("--encoder_save_path", default="models/tensorflow/seq2seq/encoder", type=str,
required=False, help="Encoder的SaveModel格式保存路径")
parser.add_argument("--decoder_save_path", default="models\\tensorflow\\seq2seq\\decoder", type=str,
parser.add_argument("--decoder_save_path", default="models/tensorflow/seq2seq/decoder", type=str,
required=False, help="Decoder的SaveModel格式保存路径")

options = parser.parse_args().__dict__
Expand Down
12 changes: 6 additions & 6 deletions dialogue/tensorflow/smn/actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,24 @@ def tf_smn() -> NoReturn:
parser.add_argument("--valid_data_split", default=0.0, type=float, required=False, help="从训练数据集中划分验证数据的比例")
parser.add_argument("--learning_rate", default=0.001, type=float, required=False, help="学习率")
parser.add_argument("--max_database_size", default=0, type=int, required=False, help="最大数据候选数量")
parser.add_argument("--dict_path", default="data\\preprocess\\smn_dict.json", type=str, required=False, help="字典路径")
parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\smn", type=str, required=False,
parser.add_argument("--dict_path", default="data/preprocess/smn_dict.json", type=str, required=False, help="字典路径")
parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/smn", type=str, required=False,
help="检查点路径")
parser.add_argument("--train_data_path", default="data\\ubuntu_train.txt", type=str, required=False,
parser.add_argument("--train_data_path", default="data/ubuntu_train.txt", type=str, required=False,
help="处理好的多轮分词训练数据集路径")
parser.add_argument("--valid_data_path", default="data\\ubuntu_valid.txt", type=str, required=False,
parser.add_argument("--valid_data_path", default="data/ubuntu_valid.txt", type=str, required=False,
help="处理好的多轮分词验证数据集路径")
parser.add_argument("--solr_server", default="http://49.235.33.100:8983/solr/smn/", type=str, required=False,
help="solr服务地址")
parser.add_argument("--candidate_database", default="data\\preprocess\\candidate.json", type=str, required=False,
parser.add_argument("--candidate_database", default="data/preprocess/candidate.json", type=str, required=False,
help="候选回复数据库")
parser.add_argument("--epochs", default=5, type=int, required=False, help="训练步数")
parser.add_argument("--batch_size", default=64, type=int, required=False, help="batch大小")
parser.add_argument("--buffer_size", default=20000, type=int, required=False, help="Dataset加载缓冲大小")
parser.add_argument("--start_sign", default="<start>", type=str, required=False, help="序列开始标记")
parser.add_argument("--end_sign", default="<end>", type=str, required=False, help="序列结束标记")
parser.add_argument("--unk_sign", default="<unk>", type=str, required=False, help="未登录词")
parser.add_argument("--model_save_path", default="models\\tensorflow\\smn", type=str,
parser.add_argument("--model_save_path", default="models/tensorflow/smn", type=str,
required=False, help="SaveModel格式保存路径")

options = parser.parse_args().__dict__
Expand Down
18 changes: 9 additions & 9 deletions dialogue/tensorflow/transformer/actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,22 @@ def tf_transformer() -> NoReturn:
parser.add_argument("--start_sign", default="<start>", type=str, required=False, help="序列开始标记")
parser.add_argument("--end_sign", default="<end>", type=str, required=False, help="序列结束标记")
parser.add_argument("--unk_sign", default="<unk>", type=str, required=False, help="未登录词")
parser.add_argument("--dict_path", default="data\\preprocess\\transformer_dict.json", type=str, required=False,
parser.add_argument("--dict_path", default="data/preprocess/transformer_dict.json", type=str, required=False,
help="字典路径")
parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\transformer", type=str, required=False,
parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/transformer", type=str, required=False,
help="检查点路径")
parser.add_argument("--raw_data_path", default="data\\LCCC.json", type=str, required=False, help="原始数据集路径")
parser.add_argument("--tokenized_data_path", default="data\\preprocess\\lccc_tokenized.txt", type=str,
parser.add_argument("--raw_data_path", default="data/LCCC.json", type=str, required=False, help="原始数据集路径")
parser.add_argument("--tokenized_data_path", default="data/preprocess/lccc_tokenized.txt", type=str,
required=False, help="处理好的多轮分词数据集路径")
parser.add_argument("--preprocess_data_path", default="data\\preprocess\\single_tokenized.txt", type=str,
parser.add_argument("--preprocess_data_path", default="data/preprocess/single_tokenized.txt", type=str,
required=False, help="处理好的单轮分词训练用数据集路径")
parser.add_argument("--valid_data_path", default="data\\preprocess\\single_tokenized.txt", type=str,
parser.add_argument("--valid_data_path", default="data/preprocess/single_tokenized.txt", type=str,
required=False, help="处理好的单轮分词验证评估用数据集路径")
parser.add_argument("--history_image_dir", default="data\\history\\transformer\\", type=str, required=False,
parser.add_argument("--history_image_dir", default="data/history/transformer/", type=str, required=False,
help="数据指标图表保存路径")
parser.add_argument("--encoder_save_path", default="models\\tensorflow\\transformer\\encoder", type=str,
parser.add_argument("--encoder_save_path", default="models/tensorflow/transformer/encoder", type=str,
required=False, help="Encoder的SaveModel格式保存路径")
parser.add_argument("--decoder_save_path", default="models\\tensorflow\\transformer\\decoder", type=str,
parser.add_argument("--decoder_save_path", default="models/tensorflow/transformer/decoder", type=str,
required=False, help="Decoder的SaveModel格式保存路径")

options = parser.parse_args().__dict__
Expand Down
60 changes: 53 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,55 @@
tensorflow==2.3.1
pytorch==1.7.1
absl-py==0.12.0
astunparse==1.6.3
cachetools==4.2.2
certifi==2020.12.5
chardet==4.0.0
click==7.1.2
cycler==0.10.0
Flask==1.1.2
flatbuffers==1.12
gast==0.3.3
google-auth==1.30.0
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
importlib-metadata==4.0.1
inflect==5.3.0
itsdangerous==1.1.0
jieba==0.42.1
inflect==5.0.2
numpy==1.19.0
Jinja2==2.11.3
joblib==1.0.1
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.4
MarkupSafe==1.1.1
matplotlib==3.4.2
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
Pillow==8.2.0
protobuf==3.16.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
pysolr==3.9.0
flask==1.1.2
scikit-learn==0.23.2
flask-cors==3.0.9
python-dateutil==2.8.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.6.3
six==1.15.0
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-estimator==2.4.0
termcolor==1.1.0
threadpoolctl==2.1.0
typing-extensions==3.7.4.3
urllib3==1.26.4
Werkzeug==1.0.1
wrapt==1.12.1
zipp==3.4.1

0 comments on commit 62ea40a

Please sign in to comment.