Merge pull request apache#814 from chrishkchris/fp16dist
SINGA-511 Compatibility of fp16 to distributed optimizer
joddiy authored Nov 17, 2020
2 parents 207bec2 + f9b180a commit 0a4f433
Showing 12 changed files with 221 additions and 112 deletions.
4 changes: 2 additions & 2 deletions examples/cnn/model/alexnet.py
@@ -81,9 +81,9 @@ def train_one_batch(self, x, y, dist_option, spars):
         out = self.forward(x)
         loss = self.softmax_cross_entropy(out, y)
 
-        if dist_option == 'fp32':
+        if dist_option == 'plain':
             self.optimizer(loss)
-        elif dist_option == 'fp16':
+        elif dist_option == 'half':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
             self.optimizer.backward_and_partial_update(loss)
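Note (not part of the diff): the same renaming of dist_option values ('fp32' to 'plain', 'fp16' to 'half') is applied to train_one_batch in cnn.py, resnet.py, xceptionnet.py and mlp/model.py below. As a reading aid, here is a minimal sketch of the dispatch with the intent of each branch spelled out; the comments are interpretation based on the commit title and on the assertion added in opt.py, and the sparse options accepted by the CLI sit outside these hunks and are omitted.

    # Sketch only: condensed from the hunks in this commit.
    def train_one_batch(self, x, y, dist_option, spars):
        out = self.forward(x)
        loss = self.softmax_cross_entropy(out, y)

        if dist_option == 'plain':
            # gradients are all-reduced in their original precision
            # (this option was previously named 'fp32')
            self.optimizer(loss)
        elif dist_option == 'half':
            # gradients are converted to 16 bits for transmission only;
            # parameters stay in fp32 (previously named 'fp16')
            self.optimizer.backward_and_update_half(loss)
        elif dist_option == 'partialUpdate':
            # each rank updates only its own partition of the parameters
            self.optimizer.backward_and_partial_update(loss)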
4 changes: 2 additions & 2 deletions examples/cnn/model/cnn.py
@@ -53,9 +53,9 @@ def train_one_batch(self, x, y, dist_option, spars):
         out = self.forward(x)
         loss = self.softmax_cross_entropy(out, y)
 
-        if dist_option == 'fp32':
+        if dist_option == 'plain':
             self.optimizer(loss)
-        elif dist_option == 'fp16':
+        elif dist_option == 'half':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
             self.optimizer.backward_and_partial_update(loss)
4 changes: 2 additions & 2 deletions examples/cnn/model/resnet.py
@@ -205,9 +205,9 @@ def train_one_batch(self, x, y, dist_option, spars):
         out = self.forward(x)
         loss = self.softmax_cross_entropy(out, y)
 
-        if dist_option == 'fp32':
+        if dist_option == 'plain':
             self.optimizer(loss)
-        elif dist_option == 'fp16':
+        elif dist_option == 'half':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
             self.optimizer.backward_and_partial_update(loss)
4 changes: 2 additions & 2 deletions examples/cnn/model/xceptionnet.py
@@ -274,9 +274,9 @@ def forward(self, x):
     def train_one_batch(self, x, y, dist_option, spars):
         out = self.forward(x)
         loss = self.softmax_cross_entropy(out, y)
-        if dist_option == 'fp32':
+        if dist_option == 'plain':
             self.optimizer(loss)
-        elif dist_option == 'fp16':
+        elif dist_option == 'half':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
             self.optimizer.backward_and_partial_update(loss)
2 changes: 1 addition & 1 deletion examples/cnn/train_cnn.py
@@ -104,7 +104,7 @@ def run(global_rank,
         sgd,
         graph,
         verbosity,
-        dist_option='fp32',
+        dist_option='plain',
         spars=None,
         precision='float32'):
     dev = device.create_cuda_gpu_on(local_rank)
15 changes: 11 additions & 4 deletions examples/cnn/train_mpi.py
@@ -20,9 +20,12 @@
 
 from singa import singa_wrap as singa
 from singa import opt
+from singa import tensor
 import argparse
 import train_cnn
 
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
 if __name__ == '__main__':
     # use argparse to get command config: max_epoch, model, data, etc. for single gpu training
     parser = argparse.ArgumentParser(
@@ -31,6 +34,10 @@
                         choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
                         default='cnn')
     parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
+    parser.add_argument('-p',
+                        choices=['float32', 'float16'],
+                        default='float32',
+                        dest='precision')
     parser.add_argument('-m',
                         '--max-epoch',
                         default=10,
@@ -51,8 +58,8 @@
                         dest='lr')
     parser.add_argument('-d',
                         '--dist-option',
-                        default='fp32',
-                        choices=['fp32','fp16','partialUpdate','sparseTopK','sparseThreshold'],
+                        default='plain',
+                        choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'],
                         help='distibuted training options',
                         dest='dist_option') # currently partialUpdate support graph=False only
     parser.add_argument('-s',
@@ -76,9 +83,9 @@
 
     args = parser.parse_args()
 
-    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
+    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
    sgd = opt.DistOpt(sgd)
 
     train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
                   args.batch_size, args.model, args.data, sgd, args.graph,
-                  args.verbosity, args.dist_option, args.spars)
+                  args.verbosity, args.dist_option, args.spars, args.precision)
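Usage note (not part of this diff; the launch command and flag combinations are illustrative): the new -p/--precision flag and the existing -d/--dist-option flag are independent. Assuming the script is started under an MPI launcher, an invocation such as

    mpiexec -np 2 python train_mpi.py cnn mnist -d half

transfers gradients in fp16 while keeping parameters in fp32, and

    mpiexec -np 2 python train_mpi.py cnn mnist -p float16

builds the SGD optimizer with dtype=tensor.float16 via the singa_dtype map and forwards precision='float16' to train_cnn.run. The defaults (-p float32, -d plain) reproduce the previous behaviour.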
15 changes: 11 additions & 4 deletions examples/cnn/train_multiprocess.py
@@ -20,16 +20,19 @@
 
 from singa import singa_wrap as singa
 from singa import opt
+from singa import tensor
 import argparse
 import train_cnn
 import multiprocessing
 
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
 def run(args, local_rank, world_size, nccl_id):
-    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
+    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
     sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size)
     train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
                   args.batch_size, args.model, args.data, sgd, args.graph,
-                  args.verbosity, args.dist_option, args.spars)
+                  args.verbosity, args.dist_option, args.spars, args.precision)
 
 
 if __name__ == '__main__':
@@ -40,6 +43,10 @@ def run(args, local_rank, world_size, nccl_id):
                         choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
                         default='cnn')
     parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
+    parser.add_argument('-p',
+                        choices=['float32', 'float16'],
+                        default='float32',
+                        dest='precision')
     parser.add_argument('-m',
                         '--max-epoch',
                         default=10,
@@ -66,8 +73,8 @@ def run(args, local_rank, world_size, nccl_id):
                         dest='world_size')
     parser.add_argument('-d',
                         '--dist-option',
-                        default='fp32',
-                        choices=['fp32','fp16','partialUpdate','sparseTopK','sparseThreshold'],
+                        default='plain',
+                        choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'],
                         help='distibuted training options',
                         dest='dist_option') # currently partialUpdate support graph=False only
     parser.add_argument('-s',
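The spawning side of train_multiprocess.py is outside this diff; run() above now simply receives args.precision along with the NCCL id and rank. For context, a sketch of how such per-GPU processes are typically started in SINGA's multi-process examples; the NcclIdHolder call and the Process arguments below are an assumption based on that common pattern, not lines from this commit.

    # Sketch (assumed, not part of this diff): start one training process per GPU.
    from singa import singa_wrap as singa
    import multiprocessing

    def spawn(args, world_size):
        nccl_id = singa.NcclIdHolder()    # one NCCL unique id shared by all ranks
        procs = []
        for local_rank in range(world_size):
            p = multiprocessing.Process(target=run,
                                        args=(args, local_rank, world_size, nccl_id))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()                      # wait for every rank to finish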
4 changes: 2 additions & 2 deletions examples/mlp/model.py
@@ -52,9 +52,9 @@ def train_one_batch(self, x, y, dist_option, spars):
         out = self.forward(x)
         loss = self.softmax_cross_entropy(out, y)
 
-        if dist_option == 'fp32':
+        if dist_option == 'plain':
             self.optimizer(loss)
-        elif dist_option == 'fp16':
+        elif dist_option == 'half':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
             self.optimizer.backward_and_partial_update(loss)
21 changes: 12 additions & 9 deletions include/singa/io/communicator.h
@@ -106,8 +106,8 @@ class Communicator {
                       float sparsThreshold, bool topK, Context *ctx);
   void _sparsification(Tensor &t, Tensor *accumulation, float sparsThreshold,
                        bool topK, Context *ctx);
-  void valSparsAllReduce(size_t num, float *accumulation, Context *ctx);
-  void topKSparsAllReduce(size_t num, float *accumulation, Context *ctx);
+  void valSparsAllReduce(size_t num, void *accumulation, Context *ctx);
+  void topKSparsAllReduce(size_t num, void *accumulation, Context *ctx);
 
   // last group of synchronized memory blocks
   std::shared_ptr<Device> device_ = nullptr;
@@ -123,13 +123,16 @@
 
   // normal synch
   size_t sendBuffOffset = 0;
-  float *fusedSendBuff;
-  float *fusedRecvBuff;
+  void *fusedSendBuff;
+  void *fusedRecvBuff;
+  void *offsetPointer;
+  size_t dataSize;
+  ncclDataType_t ncclType;
 
   // half synch
   bool halfInitialized;
-  __half *fusedSendBuffHalf;
-  __half *fusedRecvBuffHalf;
+  void *fusedSendBuffHalf;
+  void *fusedRecvBuffHalf;
 
   // sparsification
   cusparseHandle_t cusparse_handle;
@@ -142,9 +145,9 @@
   int *nnzGPU;
   int *nnzAllGPU;
   float threshold;
-  float *sparsSendBuff;
-  float *sparsRecvBuff;
-  float *backupBuff;
+  void *sparsSendBuff;
+  void *sparsRecvBuff;
+  void *backupBuff;
   int *fusedIndex;
 };
 } // namespace singa
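In communicator.h the fused, half-precision and sparsification buffers change from typed pointers (float *, __half *) to void *, and the element type is carried separately by the new ncclType and dataSize members; offsets into the buffers therefore have to be reckoned in bytes rather than in floats. A tiny illustration of that bookkeeping (plain Python, illustrative only; these names mirror the new members loosely and are not SINGA API):

    # Illustration: why an untyped buffer needs the element size tracked separately.
    element_size = {'ncclFloat': 4, 'ncclHalf': 2}   # bytes per element

    def byte_offset(num_elements, nccl_type):
        # an offset of N elements into a void* buffer is N * sizeof(element) bytes
        return num_elements * element_size[nccl_type]

    assert byte_offset(1024, 'ncclHalf') == 2048     # fp16 payload: half the bytes
    assert byte_offset(1024, 'ncclFloat') == 4096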
7 changes: 5 additions & 2 deletions python/singa/opt.py
@@ -254,7 +254,8 @@ def __init__(self,
             self.weight_decay = weight_decay
         else:
             raise TypeError("Wrong weight_decay type")
-        self.decay_value = self.weight_decay(self.step_counter).as_type(self.dtype)
+        self.decay_value = self.weight_decay(self.step_counter).as_type(
+            self.dtype)
 
         # init other params
         self.nesterov = nesterov
@@ -308,7 +309,6 @@ def apply(self, param_name, param_value, param_grad):
         minus_lr = 0.0 - self.lr_value
         singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
 
-
     def step(self):
         # increment step counter, lr and moment
         super().step()
@@ -894,6 +894,9 @@ def backward_and_update_half(self,
         acc = 0
         glist = []
         for p, g in autograd.backward(loss):
+            assert p.dtype == tensor.float32, (
+                'This function is only available for input tensor precision 32 bit, '
+                'which are converted into 16 bits before transmit')
             if clipping:
                 g = autograd.clip(g, -clip_Value, clip_Value)
             if g.size() > threshold:
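Two separate precision mechanisms are touched in opt.py: the optimizer's own dtype (scheduler outputs such as decay_value are cast with .as_type(self.dtype), a line that is only re-wrapped in the first hunk), and the transmission precision of backward_and_update_half, which the new assertion restricts to fp32 parameters whose gradients are cast to 16 bits only for the all-reduce. A short sketch, with the learning-rate value and the usage comments assumed for illustration:

    # Sketch (values are examples, not taken from the diff).
    from singa import opt, tensor

    # (a) optimizer-state precision: scalar states such as the decay value are
    #     kept in the dtype given here (fp16 in this example).
    sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5, dtype=tensor.float16)

    # (b) communication precision: wrapped in DistOpt, backward_and_update_half
    #     keeps parameters in fp32 (enforced by the new assert) and converts only
    #     the transmitted gradients to fp16.
    #     dist = opt.DistOpt(sgd)
    #     dist.backward_and_update_half(loss)   # used when dist_option == 'half'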
2 changes: 0 additions & 2 deletions src/core/tensor/tensor.cc
@@ -72,7 +72,6 @@ Tensor::Tensor(const Tensor &in)
       block_(in.block()),
       shape_(in.shape_),
       stride_(in.stride_) {
-  // printf("i am here in &in\n");
   if (block_ != nullptr) block_->IncRefCount();
 }
 
@@ -81,7 +80,6 @@ Tensor::Tensor(Tensor &&in)
       device_(in.device_),
       shape_(std::move(in.shape_)),
       stride_(std::move(in.stride_)) {
-  // printf("i am here in &&in\n");
   block_ = in.block_;
   in.block_ = nullptr;
 }
(The diff of the remaining changed file was not loaded in this view and is omitted.)