[Done] Feature/mnist train api #971

Merged
merged 35 commits on Dec 27, 2016

Commits (35 total; this view shows changes from 8 commits)
c0e687b
Refine Code
reyoung Dec 20, 2016
06944ee
Merge branch 'feature/add_const_in_parameter_updater' into feature/mn…
reyoung Dec 20, 2016
8b4cbcf
Start doing mnist_train_api
reyoung Dec 20, 2016
ad6cb60
Merge branch 'feature/clean_gradient_machine_start' into feature/mnis…
reyoung Dec 20, 2016
025e3e9
Add GradientMachine::start/finish to API
reyoung Dec 20, 2016
677c79b
Merge branch 'feature/clean_parameter_updater_finish_pass' into featu…
reyoung Dec 20, 2016
27d87db
Wait for reading data.
reyoung Dec 21, 2016
9f5e742
A tiny fix in PyDataProvider2
reyoung Dec 21, 2016
ad93b8f
Merge branch 'feature/fix_param_hidden_in_pydp2' into feature/mnist_t…
reyoung Dec 21, 2016
5f6c4af
Try to read data in mnist
reyoung Dec 21, 2016
36d1e61
Use numpy in DenseScanner.
reyoung Dec 21, 2016
efb5c10
Merge branch 'feature/fix_swig_dense_scanner' into feature/mnist_trai…
reyoung Dec 21, 2016
20249e8
Try expose ParamUpdater::update
reyoung Dec 21, 2016
05ab22c
A simplest train file for mnist added.
reyoung Dec 21, 2016
1f4f044
A tiny fix in PyDataProvider2
reyoung Dec 21, 2016
cf5bf5b
Merge branch 'feature/fix_param_hidden_in_pydp2' into feature/mnist_t…
reyoung Dec 21, 2016
1e6c87b
Merge branch 'feature/add_const_in_gradient_machine_eval' into featur…
reyoung Dec 21, 2016
eaba2e2
Expose Evaluator API
reyoung Dec 21, 2016
409a577
Complete a very simple mnist demo.
reyoung Dec 21, 2016
06dc66b
Merge branch 'feature/fix_param_hidden_in_pydp2' into feature/mnist_t…
reyoung Dec 21, 2016
680dd92
Add AverageOptimizer, Add save parameter
reyoung Dec 22, 2016
5bca268
Add gitignore
reyoung Dec 22, 2016
59009ba
Always use copy method for numpy.
reyoung Dec 22, 2016
a31ef0c
Merge branch 'feature/mnist_train_api' of github.com:reyoung/Paddle i…
reyoung Dec 22, 2016
f06b64f
Test GPU
reyoung Dec 22, 2016
65e957c
Merge branch 'feature/mnist_train_api' of github.com:reyoung/Paddle i…
reyoung Dec 22, 2016
5a68584
Test on GPU
reyoung Dec 22, 2016
16ea66e
Merge branch 'develop' of github.com:baidu/Paddle into feature/mnist_…
reyoung Dec 22, 2016
3a80272
Add comments.
reyoung Dec 22, 2016
843b63b
add config_parser in trainer_config_helpers to seperate trainer config
jacquesqiao Dec 21, 2016
763a30f
add config_parser_utils
jacquesqiao Dec 22, 2016
9b41b08
Remove unnecessary import in api_train.py
reyoung Dec 22, 2016
f8e4b0b
Merge branch 'develop' of github.com:baidu/Paddle into feature/mnist_…
reyoung Dec 26, 2016
eefe5a7
Merge branch 'develop' of github.com:baidu/Paddle into feature/mnist_…
reyoung Dec 27, 2016
eca4592
Fix merge errors.
reyoung Dec 27, 2016
42 changes: 42 additions & 0 deletions demo/mnist/api_train.py
@@ -0,0 +1,42 @@
import py_paddle.swig_paddle as api
import paddle.trainer.config_parser
import numpy as np


def init_parameter(network):
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
        assert isinstance(array, np.ndarray)
        for i in xrange(len(array)):
            array[i] = np.random.uniform(-1.0, 1.0)


def main():
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
    config = paddle.trainer.config_parser.parse_config(
        'simple_mnist_network.py', '')
Collaborator:
I don't follow. What we discussed yesterday was that a Paddle program written with the Python API should live in a single .py file. Why are there still two .py files here?

Is the goal of this PR to implement MNIST training with the current Python API, or to write the demo with a new API?

Member (@jacquesqiao, Dec 20, 2016):
This MNIST demo is a starting point. Following what we agreed yesterday, @reyoung first provides a small demo; on top of it I will remove the config file and add some wrapping, and @reyoung will modify the data provider to feed data directly. From there we will reshape this MNIST demo into the API calling style we agreed on.

Collaborator (Author):
> Is the goal of this PR to implement MNIST training with the current Python API, or to write the demo with a new API?

Step one: implement it with the current Python API. Step two: merge the pile of Python files into one. Step three: settle what the Python API should actually look like.


    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    m = api.GradientMachine.createFromConfigProto(
        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(m, api.GradientMachine)
    init_parameter(network=m)

    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
Collaborator:
I don't use Python often, so I went and looked at how various style guides describe naming. All of the following style guides require function names and method names to be in the function_name form, not the createLocalUpdater form:

  1. PEP style guide
  2. Google style guide
  3. Python.net

If we want our API to be accepted by the Python community, I expect it has to be Python style.

Member:
ParameterUpdater and createLocalUpdater are structs and methods exposed directly from C++ via SWIG, so they follow the C++ naming convention. My understanding is that none of this is exposed to users directly; we wrap it and unify everything into Python style.

Collaborator (Author):
Yes, though this is one of the unfortunate aspects of the SWIG API.

These names are translated automatically from the C++ header PaddleAPI.h; there is no way to customize them.

A C API would make this much better.
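
(For illustration only, not part of this PR's diff: a minimal sketch of the wrapper layer described above, assuming the SWIG names stay as they are underneath. The class name LocalUpdater is hypothetical.)

# hypothetical wrapper, not part of this PR; it only renames the
# SWIG-exposed camelCase methods into the snake_case style PEP 8 expects.
import py_paddle.swig_paddle as api


class LocalUpdater(object):
    """Thin snake_case facade over the SWIG ParameterUpdater."""

    def __init__(self, opt_config):
        # delegate construction to the SWIG factory exposed by this PR
        self._updater = api.ParameterUpdater.createLocalUpdater(opt_config)

    def init(self, gradient_machine):
        self._updater.init(gradient_machine)

    def start_pass(self):
        self._updater.startPass()

    def finish_pass(self):
        self._updater.finishPass()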

    assert isinstance(updater, api.ParameterUpdater)
    updater.init(m)
    m.start()
Collaborator:
start needs an object. When it is a method, sometimes the class itself is the object, and sometimes the object has to be named explicitly. For example:

class BashCommand {
 public:
  void Run();                   // Run the bash command.
  void SetStderrPipe(Pipe* p);  // Set is the predicate, Pipe is the object.
};

What is odd here is that m's type is GradientMachine. If the method is just called start, I cannot tell whether it means "start computing gradients" or "start the machine". From the code below it seems to mean start_training, so the best arrangement seems to be to name m's type Trainer and keep the method name start, which then reads as "start the trainer".

Collaborator (Author):
OK, though user-facing code probably won't look like this. It might be something like:

with gradient_machine.enter_training():
    gradient_machine.forwardBackward(...)

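(A rough sketch of the context-manager idea floated above, building only on the start()/finish() pair this PR adds to GradientMachine; the helper name enter_training is hypothetical and not part of this PR.)

# hypothetical helper, not part of this PR; it pairs GradientMachine.start()
# and GradientMachine.finish() around a training block automatically.
import contextlib


@contextlib.contextmanager
def enter_training(gradient_machine):
    gradient_machine.start()
    try:
        yield gradient_machine
    finally:
        gradient_machine.finish()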

    for _ in xrange(100):
        updater.startPass()
Collaborator (Author):
Yeah, the earlier code was only half finished; some other issues came up and went into another PR. The code after this point has now been filled in.


    m.finish()
Collaborator:
m.complete_training

Also, the difference between finish and complete is the difference between "done for" and "made whole". For example:

If you married the wrong woman, you are finished.

If you married the right woman, you are completed.

Collaborator (Author):
haha



if __name__ == '__main__':
    main()
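
(Context: at this 8-commit snapshot the pass loop above is still empty; later commits in this PR, such as "Try expose ParamUpdater::update" and "Complete a very simple mnist demo", fill it in. The sketch below is only an illustration of the intended shape: startPass/finishPass come from the diff shown here, while forwardBackward, update, and the argument handling are assumptions based on the commit messages and review comments.)

def train_one_pass(m, updater, batches):
    # assumed shape of one training pass; batches is a placeholder iterable
    # of (in_args, out_args) pairs prepared by the data provider.
    updater.startPass()
    for in_args, out_args in batches:
        m.forwardBackward(in_args, out_args, api.PASS_TRAIN)  # assumed call
        for each_param in m.getParameters():
            updater.update(each_param)  # exposed in a later commit
    updater.finishPass()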
16 changes: 16 additions & 0 deletions demo/mnist/simple_mnist_network.py
@@ -0,0 +1,16 @@
from paddle.trainer_config_helpers import *

settings(learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)
Collaborator:
A function name should be a verb or a verb-object phrase; settings is a noun. Shouldn't the function here be named something like config_training_settings?

Collaborator (Author):
This will probably change in the end.


imgs = data_layer(name='pixel', size=784)

hidden1 = fc_layer(input=imgs, size=200)
hidden2 = fc_layer(input=hidden1, size=200)

inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())

cost = classification_cost(
    input=inference, label=data_layer(
        name='label', size=10))

outputs(cost)
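
(For reference: this config file is not run on its own; api_train.py loads it through the config parser, as shown in the diff above. A minimal sketch using only calls that appear in this PR:)

import paddle.trainer.config_parser

# Parse the trainer config; the resulting proto carries both the model
# topology and the optimization settings consumed by the C++ API.
config = paddle.trainer.config_parser.parse_config('simple_mnist_network.py', '')
model_conf = config.model_config  # used by GradientMachine.createFromConfigProto
opt_conf = config.opt_config      # used by OptimizationConfig.createFromProto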
1 change: 1 addition & 0 deletions paddle/api/CMakeLists.txt
@@ -5,6 +5,7 @@ set(API_SOURCES
Matrix.cpp
Parameter.cpp
ParameterOptimizer.cpp
ParameterUpdater.cpp
SequenceGenerator.cpp
Trainer.cpp
Util.cpp
4 changes: 4 additions & 0 deletions paddle/api/GradientMachine.cpp
@@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig(
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
}

void GradientMachine::start() { m->machine->start(); }

void GradientMachine::finish() { m->machine->finish(); }

void GradientMachine::forward(const Arguments& inArgs,
Arguments* outArgs,
PassType passType) {
3 changes: 2 additions & 1 deletion paddle/api/Paddle.swig
@@ -174,6 +174,7 @@ namespace std {
%newobject Parameter::getConfig;
%newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;

%feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide
@@ -193,4 +194,4 @@ namespace std {
%ignore OptimizationConfigPrivate;
%ignore ParameterTraverseCallbackPrivate;
%include "utils/GlobalConstants.h"
%include "api/PaddleAPI.h"
%include "api/PaddleAPI.h"
29 changes: 29 additions & 0 deletions paddle/api/PaddleAPI.h
@@ -519,6 +519,7 @@ class OptimizationConfig {

friend class TrainerConfig;
friend class ParameterOptimizer;
friend class ParameterUpdater;
friend class Trainer;
};

@@ -557,6 +558,7 @@ class Parameter {
ParameterPrivate* m;
friend class UpdateCallbackWrapper;
friend class GradientMachine;
friend class ParameterUpdater;
};

struct ModelConfigPrivate;
@@ -714,6 +716,13 @@ class GradientMachine {
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
const std::vector<int>& parameterTypes = defaultParamTypes);

/**
* @brief finish
*/
void finish();

void start();

/**
* The forward stage of GradientMachine.
*
@@ -772,6 +781,26 @@
// Not to use c++ 11 init-list, so we use static var as function default arg.
static std::vector<int> defaultParamTypes;
friend class Trainer;
friend class ParameterUpdater;
};

struct ParameterUpdaterPrivate;
class ParameterUpdater {
private:
ParameterUpdater();

public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
Collaborator:
In the Google C++ Style Guide, function/method names should be CreateLocalUpdater.

Although Paddle's current C++ naming differs from Google style, I suggest adopting Google style, at least in new code. Then whenever someone violates it, we can paste a link in the code review comments (as above) to tell the developer why the naming needs fixing, instead of typing out the explanation by hand again and again.

Collaborator (Author):
This rule does not seem mandatory; I don't see its advantage. Paddle's current naming rules are:

For type names, use CamelCase with an initial capital, e.g. GradientMachine, NeuralNetwork.

For variable names:

  • Ordinary variables use lowerCamelCase, e.g. gradientMachine.
  • Class member variables use lowerCamelCase with a trailing underscore, e.g. gradientMachine_.
  • Static global variables use CamelCase with a leading g, e.g. gSyncThreadPool.

For namespace names, use underscore-separated names, e.g. cpu_avx.

Paddle already has a complete naming style, so there seems to be no need to change it.

Besides, there are many guides for naming (Qt, STL, Boost, Linux, Google). As long as naming is consistent within one library, and a name tells you roughly what a variable is, it is fine.

Collaborator:
We must commit to one set of rules. The biggest benefit of doing so is not leaving everyone to argue about which rules are worth following; otherwise reviews will keep getting stuck in this kind of endless dispute.

The reason to follow Google style is that it analyzes the pros and cons, and every rule has its own URL. That way, during code review we only need to paste the URL instead of explaining repeatedly.

Collaborator (Author):
Yes. On naming specifically, we can follow the rules Paddle already uses; there is no need to change everything to Google style.

Collaborator:
@reyoung does Paddle's current naming rule set have URLs we can paste in code review comments? If not, I suggest we just use Google style.

~ParameterUpdater();

void init(const GradientMachine& gm);

void startPass();

void finishPass();

private:
ParameterUpdaterPrivate* m;
};

struct TrainerPrivate;
27 changes: 25 additions & 2 deletions paddle/api/PaddleAPIPrivate.h
@@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include "PaddleAPI.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h"

#pragma once
#include "paddle/parameter/ParameterUpdaterBase.h"

struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine;
@@ -65,3 +67,24 @@ struct ArgumentsPrivate {
return *(std::shared_ptr<T>*)(rawPtr);
}
};

struct ParameterUpdaterPrivate {
std::unique_ptr<paddle::ParameterUpdater> updater;
};

struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater,
// in other situation sharedPtr should
// contains value.

ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}

paddle::Parameter* getPtr() {
if (sharedPtr) {
return sharedPtr.get();
} else {
return rawPtr;
}
}
};
16 changes: 1 addition & 15 deletions paddle/api/Parameter.cpp
@@ -14,21 +14,7 @@ limitations under the License. */

#include "paddle/parameter/Parameter.h"
#include "PaddleAPI.h"

struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr;

ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}

paddle::Parameter* getPtr() {
if (sharedPtr) {
return sharedPtr.get();
} else {
return rawPtr;
}
}
};
#include "PaddleAPIPrivate.h"

Parameter::Parameter() : m(new ParameterPrivate()) {}

37 changes: 37 additions & 0 deletions paddle/api/ParameterUpdater.cpp
@@ -0,0 +1,37 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "PaddleAPI.h"

#include "PaddleAPIPrivate.h"
#include "paddle/trainer/ThreadParameterUpdater.h"

ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}

ParameterUpdater *ParameterUpdater::createLocalUpdater(
OptimizationConfig *config) {
auto param = new ParameterUpdater();
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
return param;
}

ParameterUpdater::~ParameterUpdater() { delete m; }

void ParameterUpdater::init(const GradientMachine &gm) {
m->updater->init(gm.m->machine->getParameters());
}

void ParameterUpdater::startPass() { m->updater->startPass(); }

void ParameterUpdater::finishPass() {}
2 changes: 1 addition & 1 deletion paddle/parameter/ParameterUpdaterBase.cpp
@@ -19,7 +19,7 @@ limitations under the License. */

namespace paddle {

void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;
for (ParameterType type : getParameterTypes()) {
for (auto& para : parameters) {
10 changes: 5 additions & 5 deletions paddle/parameter/ParameterUpdaterBase.h
@@ -32,13 +32,13 @@ class ParameterUpdater {
parameterTypes_.push_back(type);
}

virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

// called by Trainer when starting a new pass
virtual void startPass() {}

// called by Trainer then finishing a pass, ruturn true if pass accepted
virtual bool finishPass(real cost = 0) { return true; }
virtual bool finishPass() { return true; }

// called by Trainer before backward() of a batch
// Return the type of pass it needs. This pass type will be passed
@@ -105,16 +105,16 @@ class ParameterUpdaterComposite : public ParameterUpdater {
ParameterUpdaterComposite() {}
virtual ~ParameterUpdaterComposite() {}

virtual void init(std::vector<ParameterPtr>& parameters) = 0;
virtual void init(const std::vector<ParameterPtr>& parameters) = 0;

virtual void startPass() {
syncThreadPool_->execPlusOwner(
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
}

virtual bool finishPass(real cost = 0) {
virtual bool finishPass() {
syncThreadPool_->execPlusOwner(
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); });
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); });
return true;
}

3 changes: 2 additions & 1 deletion paddle/trainer/ParameterUpdater.cpp
@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
}

void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) {
void SgdUpdaterWithCpuAverager::init(
const std::vector<ParameterPtr>& parameters) {
SgdLocalUpdater::init(parameters);
averager_->init(parameters_.size(), nullptr);
copyEvents_.resize(parameters_.size());
12 changes: 6 additions & 6 deletions paddle/trainer/ParameterUpdater.h
@@ -64,7 +64,7 @@ class SgdLocalUpdater : public ParameterUpdater {
* be initialized.
* @param parameters The parameter need to be initialized.
*/
virtual void init(std::vector<ParameterPtr>& parameters) {
virtual void init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
optimizer_->init(parameters_.size(), nullptr);
// check no L1 decay in parameter configs
@@ -102,9 +102,9 @@
* @param cost sum cost during one pass.
* @return true if accept (used for owlqn).
*/
virtual bool finishPass(real cost) {
virtual bool finishPass() {
optimizer_->finishPass();
return ParameterUpdater::finishPass(cost);
return ParameterUpdater::finishPass();
}

/**
@@ -208,7 +208,7 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
* @brief init. Initialize cpu parameters, model average optimizer.
* @param parameters
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

virtual PassType startBatch(int64_t batchSize) {
averager_->startBatch(-1UL);
@@ -220,9 +220,9 @@ averager_->startBatch(-1UL);
averager_->startPass();
SgdLocalUpdater::startPass();
}
virtual bool finishPass(real cost) {
virtual bool finishPass() {
averager_->finishPass();
return SgdLocalUpdater::finishPass(cost);
return SgdLocalUpdater::finishPass();
}

/// apply the averaged parameter to PARAMETER_VALUE
11 changes: 6 additions & 5 deletions paddle/trainer/RemoteParameterUpdater.cpp
@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
addParameterType(PARAMETER_MOMENTUM);
}

void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

if (localUpdater_) {
@@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() {
}
}

bool RemoteParameterUpdater::finishPass(real cost) {
bool RemoteParameterUpdater::finishPass() {
if (localUpdater_) {
localUpdater_->finishPass();
}
@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
testing_(testing),
useApplyInPserver_(false) {}

void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void SparseRemoteParameterUpdater::init(
const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

parameterClient_.reset(new ParameterClient2(
@@ -711,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() {
}
}

bool SparseRemoteParameterUpdater::finishPass(real cost) {
bool SparseRemoteParameterUpdater::finishPass() {
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassFinish();
} else {
@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
}

void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) {
const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;

std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];