[Done] Feature/mnist train api #971
Changes from 34 commits
@@ -4,3 +4,4 @@ mnist_vgg_model
 plot.png
 train.log
 *pyc
+.ipynb_checkpoints
@@ -0,0 +1,205 @@
""" | ||
A very basic example for how to use current Raw SWIG API to train mnist network. | ||
|
||
Current implementation uses Raw SWIG, which means the API call is directly \ | ||
passed to C++ side of Paddle. | ||
|
||
The user api could be simpler and carefully designed. | ||
""" | ||
import py_paddle.swig_paddle as api | ||
from py_paddle import DataProviderConverter | ||
import paddle.trainer.PyDataProvider2 as dp | ||
import numpy as np | ||
import random | ||
from mnist_util import read_from_mnist | ||
from paddle.trainer_config_helpers import * | ||

def optimizer_config():
    settings(
        learning_rate=1e-4,
        learning_method=AdamOptimizer(),
        batch_size=1000,
        model_average=ModelAverage(average_window=0.5),
        regularization=L2Regularization(rate=0.5))


def network_config():
    imgs = data_layer(name='pixel', size=784)
    hidden1 = fc_layer(input=imgs, size=200)
    hidden2 = fc_layer(input=hidden1, size=200)
    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
    cost = classification_cost(
        input=inference, label=data_layer(name='label', size=10))
    outputs(cost)

def init_parameter(network):
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array_size = len(each_param)
        array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
        each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)

def generator_to_batch(generator, batch_size):
    ret_val = list()
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = list()
    if len(ret_val) != 0:
        yield ret_val
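The helper above groups a generator's items into fixed-size lists and yields a final partial batch if one remains. Since the PR's code is Python 2, here is a minimal standalone Python 3 sketch of the same batching logic:

```python
def generator_to_batch(generator, batch_size):
    """Group items from a generator into lists of at most batch_size."""
    ret_val = []
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = []
    if ret_val:  # final partial batch
        yield ret_val

# 10 items in batches of 4 -> batch sizes 4, 4, 2
batches = list(generator_to_batch(iter(range(10)), 4))
print([len(b) for b in batches])  # -> [4, 4, 2]
```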

class BatchPool(object):
    def __init__(self, generator, batch_size):
        self.data = list(generator)
        self.batch_size = batch_size

    def __call__(self):
        random.shuffle(self.data)
        for offset in xrange(0, len(self.data), self.batch_size):
            limit = min(offset + self.batch_size, len(self.data))
            yield self.data[offset:limit]
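BatchPool materializes the whole generator into memory, then reshuffles and re-batches on every call, so each training pass sees the samples in a fresh order. A Python 3 sketch of the same idea (note that Python slicing already clamps at the end of the list, so the `min()` in the original is not strictly needed):

```python
import random

class BatchPool(object):
    """Hold all samples in memory; shuffle and re-batch on every call."""
    def __init__(self, generator, batch_size):
        self.data = list(generator)
        self.batch_size = batch_size

    def __call__(self):
        random.shuffle(self.data)
        for offset in range(0, len(self.data), self.batch_size):
            yield self.data[offset:offset + self.batch_size]

pool = BatchPool(iter(range(10)), 4)
epoch = list(pool())              # one shuffled epoch
print([len(b) for b in epoch])    # -> [4, 4, 2]
```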

def input_order_converter(generator):
    for each_item in generator:
        yield each_item['pixel'], each_item['label']

def main():
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 CPU cores

    # Get enable_types for each optimizer.
    # enable_types = [value, gradient, momentum, etc.]
    # For each optimizer (SGD, Adam), the GradientMachine should enable
    # different buffers.
    opt_config_proto = parse_optimizer_config(optimizer_config)
    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    # Create a simple GradientMachine.
    model_config = parse_network_config(network_config)
    m = api.GradientMachine.createFromConfigProto(
        model_config, api.CREATE_MODE_NORMAL, enable_types)

    # This type check is not strictly necessary; it only enables type hints
    # in IDEs such as PyCharm.
    assert isinstance(m, api.GradientMachine)

    # Initialize parameters with numpy.
    init_parameter(network=m)

    # Create a local updater. Local means it does not run in a cluster.
    # For cluster training, this could change to createRemoteUpdater in
    # the future.
    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)

    # Initialize the ParameterUpdater.
    updater.init(m)

    # DataProviderConverter is a utility that converts Python objects into
    # Paddle's C++ input format. The input format is the same as Paddle's
    # DataProvider.
    converter = DataProviderConverter(
        input_types=[dp.dense_vector(784), dp.integer_value(10)])

    train_file = './data/raw_data/train'
    test_file = './data/raw_data/t10k'

    # Start the gradient machine.
    # The gradient machine must be started before invoking forward/backward,
    # not just for training but also for inference.
    m.start()
Review comment: `start` needs an object. If it is a method, sometimes the class itself is the object, and sometimes the object has to be specified explicitly. For example ...

Review comment: What is odd here is that the type of `m` is ...

Reply: OK, though user-level code should not look like this. It could be something like `with gradient_machine.enter_training():`
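The reviewer suggests a user-level API shaped like `with gradient_machine.enter_training():` rather than a bare `m.start()`. A hypothetical sketch of such a wrapper (every name here is illustrative, not real Paddle API; a stand-in class replaces the SWIG GradientMachine) shows how a context manager pairs `start()` with `finish()` even when training raises:

```python
from contextlib import contextmanager

class FakeGradientMachine(object):
    """Stand-in for the SWIG GradientMachine; records start/finish calls."""
    def __init__(self):
        self.calls = []
    def start(self):
        self.calls.append('start')
    def finish(self):
        self.calls.append('finish')

@contextmanager
def enter_training(machine):
    """Hypothetical wrapper: guarantees finish() runs even on error."""
    machine.start()
    try:
        yield machine
    finally:
        machine.finish()

m = FakeGradientMachine()
with enter_training(m):
    pass  # forward/backward passes would go here
print(m.calls)  # -> ['start', 'finish']
```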

    # An evaluator can print the error rate, etc. It is a C++ class.
    batch_evaluator = m.makeEvaluator()
    test_evaluator = m.makeEvaluator()

    # Get the training data.
    # The training data is stored in a data pool. The current implementation
    # does not care about memory or speed; it is just a very naive one.
    train_data_generator = input_order_converter(read_from_mnist(train_file))
    train_data = BatchPool(train_data_generator, 512)

    # outArgs holds the neural network's forward result. It is not used here,
    # just passed to gradient_machine.forward.
    outArgs = api.Arguments.createArguments(0)

    for pass_id in xrange(2):  # we train 2 passes.
        updater.startPass()
Review comment: Usually a startXXX should have a corresponding endXXX. For example, https://golang.org/pkg/os/exec/#Cmd.Start has the corresponding https://golang.org/pkg/os/exec/#Cmd.Wait. If the meaning is "do it and finish it", it could be named ...

Reply: Right. The earlier code was only half developed when some other issues came up, so that part went into another PR. The corresponding code is filled in further down.

        for batch_id, data_batch in enumerate(train_data()):
            # data_batch is the input images.
            # For online learning, data_batch could be fetched from the
            # network here.

            # Start updating one batch.
            pass_type = updater.startBatch(len(data_batch))

            # Start the batch evaluator.
            # batch_evaluator can be used between start and finish.
            batch_evaluator.start()

            # forwardBackward is a shortcut for forward then backward.
            # It is sometimes faster than invoking forward and backward
            # separately, because the GradientMachine may run them
            # asynchronously.
            m.forwardBackward(converter(data_batch), outArgs, pass_type)

            for each_param in m.getParameters():
                updater.update(each_param)

            # Get the cost. We use numpy to compute the total cost for this
            # batch.
            cost_vec = outArgs.getSlotValue(0)
            cost_vec = cost_vec.copyToNumpyMat()
            cost = cost_vec.sum() / len(data_batch)
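The per-batch cost above is the sum of a per-sample cost matrix divided by the batch size. A small numpy sketch of that reduction (the cost values here are made up for illustration; in the real script the matrix comes from `outArgs.getSlotValue(0).copyToNumpyMat()`):

```python
import numpy as np

# Pretend forward() produced one cost value per sample for a batch of 4.
cost_vec = np.array([[0.5], [1.0], [1.5], [1.0]], dtype='float32')
batch_size = len(cost_vec)

cost = cost_vec.sum() / batch_size  # mean cost over the batch
print(cost)  # -> 1.0
```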
|
||
# Make evaluator works. | ||
m.eval(batch_evaluator) | ||
|
||
# Print logs. | ||
print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \ | ||
cost, batch_evaluator | ||
|
||
batch_evaluator.finish() | ||
# Finish batch. | ||
# * will clear gradient. | ||
# * ensure all values should be updated. | ||
updater.finishBatch(cost) | ||
|
||
# testing stage. use test data set to test current network. | ||
updater.apply() | ||
test_evaluator.start() | ||
test_data_generator = input_order_converter(read_from_mnist(test_file)) | ||
for data_batch in generator_to_batch(test_data_generator, 512): | ||
# in testing stage, only forward is needed. | ||
m.forward(converter(data_batch), outArgs, api.PASS_TEST) | ||
m.eval(test_evaluator) | ||
|
||
# print error rate for test data set | ||
print 'Pass', pass_id, ' test evaluator: ', test_evaluator | ||
test_evaluator.finish() | ||
updater.restore() | ||
|
||
updater.catchUpWith() | ||
params = m.getParameters() | ||
for each_param in params: | ||
assert isinstance(each_param, api.Parameter) | ||
value = each_param.getBuf(api.PARAMETER_VALUE) | ||
value = value.copyToNumpyArray() | ||
|
||
# Here, we could save parameter to every where you want | ||
print each_param.getName(), value | ||
|
||
updater.finishPass() | ||
|
||
m.finish() | ||
Review comment: Also, the difference between "finish" and "complete" is the difference between "done for" and "done properly". For example: ...

Reply: haha

if __name__ == '__main__':
    main()
@@ -0,0 +1,30 @@
import numpy

__all__ = ['read_from_mnist']


def read_from_mnist(filename):
    imgf = filename + "-images-idx3-ubyte"
    labelf = filename + "-labels-idx1-ubyte"
    f = open(imgf, "rb")
    l = open(labelf, "rb")

    # Skip the IDX headers: 16 bytes for images, 8 bytes for labels.
    f.read(16)
    l.read(8)

    # Define the number of samples for train/test.
    if "train" in filename:
        n = 60000
    else:
        n = 10000

    images = numpy.fromfile(
        f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
    images = images / 255.0 * 2.0 - 1.0
    labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")

    for i in xrange(n):
        yield {"pixel": images[i, :], 'label': labels[i]}

    f.close()
    l.close()
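`read_from_mnist` skips the IDX headers (16 bytes for images, 8 for labels) and rescales pixel bytes from [0, 255] to [-1, 1] via `images / 255.0 * 2.0 - 1.0`. A self-contained Python 3 sketch of the same parsing against synthetic in-memory bytes (`numpy.frombuffer` stands in for `numpy.fromfile` so no real MNIST files are needed; the byte layout merely mimics the IDX format):

```python
import numpy

def parse_idx(img_bytes, label_bytes, n, pixels):
    """Parse IDX-style byte strings, mirroring read_from_mnist's logic."""
    images = numpy.frombuffer(
        img_bytes[16:16 + n * pixels], 'ubyte').reshape(
            (n, pixels)).astype('float32')
    images = images / 255.0 * 2.0 - 1.0  # scale [0, 255] -> [-1, 1]
    labels = numpy.frombuffer(label_bytes[8:8 + n], 'ubyte').astype('int')
    return [{'pixel': images[i, :], 'label': labels[i]} for i in range(n)]

# Synthetic IDX data: dummy 16/8-byte headers, then 2 samples of 3 pixels.
img_bytes = b'\x00' * 16 + bytes([0, 128, 255, 255, 128, 0])
label_bytes = b'\x00' * 8 + bytes([7, 2])
samples = parse_idx(img_bytes, label_bytes, n=2, pixels=3)
print(samples[0]['label'], samples[0]['pixel'])  # pixel 0 -> -1.0, 255 -> 1.0
```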
@@ -0,0 +1,29 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"

Evaluator::Evaluator() : m(new EvaluatorPrivate()) {}
Evaluator::~Evaluator() { delete m; }

void Evaluator::start() { m->rawPtr->start(); }

void Evaluator::finish() { m->rawPtr->finish(); }

std::string Evaluator::toString() {
  std::ostringstream sout;
  m->rawPtr->printStats(sout);
  return sout.str();
}
Review comment: I don't use Python very often, so I went and read what various style guides say about naming. The style guides I checked all require function and method names to be of the form `function_name`, not `createLocalUpdater`. If we want our API to be accepted by the Python community, I suppose it has to be Python style.

Reply: `ParameterUpdater` and `createLocalUpdater` are structures and methods exposed directly from C++ via SWIG, so they follow the C++ naming convention. My understanding is that none of these are exposed to users directly; we wrap them and unify the naming into Python style.

Reply: Yes, but that is one of the sad points of the SWIG API. The names are translated automatically from the C++ header `PaddleAPI.h`, and there is no way to customize them. A C API would be much better.
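One common way to resolve the naming mismatch discussed above is a thin Python facade that re-exposes the SWIG-generated camelCase methods under PEP 8 snake_case names. A hypothetical sketch (the SWIG class here is a stand-in, not real Paddle code):

```python
class SwigParameterUpdater(object):
    """Stand-in for the camelCase API that SWIG generates from C++."""
    def startPass(self):
        return 'pass started'
    def finishPass(self):
        return 'pass finished'

class ParameterUpdater(object):
    """PEP 8-style facade delegating to the SWIG object."""
    def __init__(self, swig_updater):
        self._u = swig_updater

    def start_pass(self):
        return self._u.startPass()

    def finish_pass(self):
        return self._u.finishPass()

updater = ParameterUpdater(SwigParameterUpdater())
print(updater.start_pass())   # -> 'pass started'
print(updater.finish_pass())  # -> 'pass finished'
```

The facade keeps the SWIG layer untouched while giving users a Pythonic surface, which matches the reply above: wrap the exposed names rather than fight the code generator.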