-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Done] Feature/mnist train api #971
Changes from 8 commits
c0e687b
06944ee
8b4cbcf
ad6cb60
025e3e9
677c79b
27d87db
9f5e742
ad93b8f
5f6c4af
36d1e61
efb5c10
20249e8
05ab22c
1f4f044
cf5bf5b
1e6c87b
eaba2e2
409a577
06dc66b
680dd92
5bca268
59009ba
a31ef0c
f06b64f
65e957c
5a68584
16ea66e
3a80272
843b63b
763a30f
9b41b08
f8e4b0b
eefe5a7
eca4592
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import py_paddle.swig_paddle as api | ||
import paddle.trainer.config_parser | ||
import numpy as np | ||
|
||
|
||
def init_parameter(network): | ||
assert isinstance(network, api.GradientMachine) | ||
for each_param in network.getParameters(): | ||
assert isinstance(each_param, api.Parameter) | ||
array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace() | ||
assert isinstance(array, np.ndarray) | ||
for i in xrange(len(array)): | ||
array[i] = np.random.uniform(-1.0, 1.0) | ||
|
||
|
||
def main(): | ||
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores | ||
config = paddle.trainer.config_parser.parse_config( | ||
'simple_mnist_network.py', '') | ||
|
||
opt_config = api.OptimizationConfig.createFromProto(config.opt_config) | ||
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config) | ||
enable_types = _temp_optimizer_.getParameterTypes() | ||
|
||
m = api.GradientMachine.createFromConfigProto( | ||
config.model_config, api.CREATE_MODE_NORMAL, enable_types) | ||
assert isinstance(m, api.GradientMachine) | ||
init_parameter(network=m) | ||
|
||
updater = api.ParameterUpdater.createLocalUpdater(opt_config) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我不常用Python,所以我去看了一下各种style guide里关于naming的描述。以下style guide都要求function names和method names都是 如果我们的API要被Python社区接受,我估计得是Python style的吧。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ParameterUpdater和createLocalUpdater是直接从cpp中用swig expose出来的结构和方法,所以是cpp中的命名规范。我理解这些都不是直接暴露给用户的,而是经过我们封装一下,统一成python style。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,不过这就是swig api的悲剧点之一了。 这个名字直接从C++的头文件PaddleAPI.h自动化翻译过来的,没有办法自定义。 如果是C-API会好很多。 |
||
assert isinstance(updater, api.ParameterUpdater) | ||
updater.init(m) | ||
m.start() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. start得有一个宾语。 如果是method,有时候class就是宾语,也有时候得明确指定宾语。比如
这里比较诡异的是m的类型是 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好,不过用户态的代码应该不这样。。可能类似于 with gradient_machine.enter_training(): |
||
|
||
for _ in xrange(100): | ||
updater.startPass() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 通常startXXX要有对应的endXXX。比如 https://golang.org/pkg/os/exec/#Cmd.Start 有对应的 https://golang.org/pkg/os/exec/#Cmd.Wait。如果意思是“做且做完”,可以叫 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 恩,之前的code开发了一半,发现其他的一些问题,然后交另一个PR了。。 这个后面的code补上了。 |
||
|
||
m.finish() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
另外,finish和complete的区别是“完蛋”和“完美”的区别,比如:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. haha |
||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from paddle.trainer_config_helpers import * | ||
|
||
settings(learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 函数名字应该是动词或者动宾短语,settings是一个名词。这里的函数名看上去应该是 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里估计最后会改。。 |
||
|
||
imgs = data_layer(name='pixel', size=784) | ||
|
||
hidden1 = fc_layer(input=imgs, size=200) | ||
hidden2 = fc_layer(input=hidden1, size=200) | ||
|
||
inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation()) | ||
|
||
cost = classification_cost( | ||
input=inference, label=data_layer( | ||
name='label', size=10)) | ||
|
||
outputs(cost) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -519,6 +519,7 @@ class OptimizationConfig { | |
|
||
friend class TrainerConfig; | ||
friend class ParameterOptimizer; | ||
friend class ParameterUpdater; | ||
friend class Trainer; | ||
}; | ||
|
||
|
@@ -557,6 +558,7 @@ class Parameter { | |
ParameterPrivate* m; | ||
friend class UpdateCallbackWrapper; | ||
friend class GradientMachine; | ||
friend class ParameterUpdater; | ||
}; | ||
|
||
struct ModelConfigPrivate; | ||
|
@@ -714,6 +716,13 @@ class GradientMachine { | |
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, | ||
const std::vector<int>& parameterTypes = defaultParamTypes); | ||
|
||
/** | ||
* @brief finish | ||
*/ | ||
void finish(); | ||
|
||
void start(); | ||
|
||
/** | ||
* The forward stage of GradientMachine. | ||
* | ||
|
@@ -772,6 +781,26 @@ class GradientMachine { | |
// Not to use c++ 11 init-list, so we use static var as function default arg. | ||
static std::vector<int> defaultParamTypes; | ||
friend class Trainer; | ||
friend class ParameterUpdater; | ||
}; | ||
|
||
struct ParameterUpdaterPrivate; | ||
class ParameterUpdater { | ||
private: | ||
ParameterUpdater(); | ||
|
||
public: | ||
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在 Google C++ Style Guide 里,function/method names 应该是 虽然现在Paddle的C++ naming和Google style不同,但是我建议(至少新代码里)采用Google style,因为一旦有人违反,我们在code review comments里可以贴一个链接(如上),告诉开发者为什么naming需要修正,而不需要一遍又一遍地手工输入解释。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个条款看起来不一定是非要遵守,因为没有感受到这个条款的优势,目前Paddle中的变量命名规则是. 对于类型名,采用全大写的CamelCase. 例如 GradientMachine, NeuralNetwork. 对于变量名:
对于namespace的名称,为下划线分割的名字,例如 cpu_avx Paddle目前有一套完整的命名风格了,所以似乎不需要非要改。 另外,关于变量命名方式的guide,有很多风格(Qt,stl,boost,linux,google)。只要在一个库里面统一,看到某一个名字,能够反应出变量大概是什么,其实就还好了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我们必须要遵守一个条款。遵守的最大用处,就是不要留给每个人纷争哪些条款值得遵守,否则review就总是会陷入在这类无休止的纷争里。 准守Google style的原因是:它对pros和cons都有分析,并且每一条条款都有一个URL。这样我们code review的时候只需要贴URL,不需要反复解释。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的。单纯就命名这一点。我们可以遵守Paddle目前使用的条款,而没有必要全改成google风格的。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @reyoung paddle现在有关于naming的条款,可以在code review comments里贴URL的吗?如果没有,我建议就用Google style就好了。 |
||
~ParameterUpdater(); | ||
|
||
void init(const GradientMachine& gm); | ||
|
||
void startPass(); | ||
|
||
void finishPass(); | ||
|
||
private: | ||
ParameterUpdaterPrivate* m; | ||
}; | ||
|
||
struct TrainerPrivate; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "PaddleAPI.h" | ||
|
||
#include "PaddleAPIPrivate.h" | ||
#include "paddle/trainer/ThreadParameterUpdater.h" | ||
|
||
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {} | ||
|
||
ParameterUpdater *ParameterUpdater::createLocalUpdater( | ||
OptimizationConfig *config) { | ||
auto param = new ParameterUpdater(); | ||
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); | ||
return param; | ||
} | ||
|
||
ParameterUpdater::~ParameterUpdater() { delete m; } | ||
|
||
void ParameterUpdater::init(const GradientMachine &gm) { | ||
m->updater->init(gm.m->machine->getParameters()); | ||
} | ||
|
||
void ParameterUpdater::startPass() { m->updater->startPass(); } | ||
|
||
void ParameterUpdater::finishPass() {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没看明白,我们昨天讨论的目的是用Python API写的Paddle程序都在一个py文件里。为什么这是还是分了两个.py文件呢?
这个PR的目的是用目前的Python API来实现MNIST training,还是用新的API来写demo呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个mnist一个基础,按照昨天我们商量的结果,@reyoung 先提供一个小的demo,我在这个基础上去掉配置文件并做一些封装,@reyoung会修改data provider直接feed数据,在这个基础上把这个mnist demo改造成一个符合讨论结果的api调用方式。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用第一步目前的Python API实现,第二步把一堆python文件合成一个,第三步确定Python API究竟是啥样的。