Debug and add assertion
chrishkchris committed Nov 5, 2020
1 parent c089958 commit 3f6e15a
Showing 6 changed files with 48 additions and 22 deletions.
11 changes: 9 additions & 2 deletions examples/cnn/train_mpi.py
@@ -20,9 +20,12 @@

from singa import singa_wrap as singa
from singa import opt
from singa import tensor
import argparse
import train_cnn

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}

if __name__ == '__main__':
# use argparse to get command config: max_epoch, model, data, etc. for single gpu training
parser = argparse.ArgumentParser(
@@ -31,6 +34,10 @@
choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
default='cnn')
parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
parser.add_argument('-p',
choices=['float32', 'float16'],
default='float32',
dest='precision')
parser.add_argument('-m',
'--max-epoch',
default=10,
@@ -76,9 +83,9 @@

args = parser.parse_args()

sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
sgd = opt.DistOpt(sgd)

train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
args.batch_size, args.model, args.data, sgd, args.graph,
args.verbosity, args.dist_option, args.spars)
args.verbosity, args.dist_option, args.spars, args.precision)
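
For context, a minimal sketch of how the new precision option reaches the optimizer (the -p argument, the singa_dtype map, and the SGD/DistOpt calls are taken from the diff above; the learning-rate value and the launch command below are illustrative only):

```python
# Sketch of the float16/float32 selection added in this commit: the -p flag is
# mapped to a SINGA dtype and passed to SGD before the optimizer is wrapped for
# distributed training. Assumes SINGA is built with MPI/NCCL support.
import argparse

from singa import opt
from singa import tensor

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}

parser = argparse.ArgumentParser()
parser.add_argument('-p',
                    choices=['float32', 'float16'],
                    default='float32',
                    dest='precision')
args = parser.parse_args()

# the chosen dtype is forwarded to SGD; opt.py casts its hyperparameters with
# as_type(self.dtype), as shown in the opt.py hunk further down
sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5,
              dtype=singa_dtype[args.precision])
sgd = opt.DistOpt(sgd)  # NCCL/MPI wrapper providing global_rank, world_size, local_rank
```

A half-precision run would then be launched along the lines of `mpiexec -n 2 python train_mpi.py cnn mnist -p float16` (illustrative command; adapt the process count and dataset to your setup).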
11 changes: 9 additions & 2 deletions examples/cnn/train_multiprocess.py
@@ -20,16 +20,19 @@

from singa import singa_wrap as singa
from singa import opt
from singa import tensor
import argparse
import train_cnn
import multiprocessing

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}

def run(args, local_rank, world_size, nccl_id):
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size)
train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
args.batch_size, args.model, args.data, sgd, args.graph,
args.verbosity, args.dist_option, args.spars)
args.verbosity, args.dist_option, args.spars, args.precision)


if __name__ == '__main__':
@@ -40,6 +43,10 @@ def run(args, local_rank, world_size, nccl_id):
choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
default='cnn')
parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
parser.add_argument('-p',
choices=['float32', 'float16'],
default='float32',
dest='precision')
parser.add_argument('-m',
'--max-epoch',
default=10,
4 changes: 2 additions & 2 deletions include/singa/io/communicator.h
@@ -131,8 +131,8 @@ class Communicator {

// half synch
bool halfInitialized;
__half *fusedSendBuffHalf;
__half *fusedRecvBuffHalf;
void *fusedSendBuffHalf;
void *fusedRecvBuffHalf;

// sparsification
cusparseHandle_t cusparse_handle;
7 changes: 5 additions & 2 deletions python/singa/opt.py
@@ -254,7 +254,8 @@ def __init__(self,
self.weight_decay = weight_decay
else:
raise TypeError("Wrong weight_decay type")
self.decay_value = self.weight_decay(self.step_counter).as_type(self.dtype)
self.decay_value = self.weight_decay(self.step_counter).as_type(
self.dtype)

# init other params
self.nesterov = nesterov
@@ -308,7 +309,6 @@ def apply(self, param_name, param_value, param_grad):
minus_lr = 0.0 - self.lr_value
singa.Axpy(minus_lr.data, param_grad.data, param_value.data)


def step(self):
# increment step counter, lr and moment
super().step()
@@ -894,6 +894,9 @@ def backward_and_update_half(self,
acc = 0
glist = []
for p, g in autograd.backward(loss):
assert p.dtype == tensor.float32, (
'This function is only available for input tensor precision 32 bit, '
'which are converted into 16 bits before transmit')
if clipping:
g = autograd.clip(g, -clip_Value, clip_Value)
if g.size() > threshold:
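
The assertion added to backward_and_update_half above encodes a precondition rather than new behaviour: the fused half-precision update converts float32 gradients to float16 only for transmission, so everything fed into it must start out as 32-bit. A hedged sketch of the same check from the caller's side (the helper name is hypothetical, not part of the SINGA API):

```python
# Hypothetical pre-flight check mirroring the assert added in opt.py: confirm that
# all parameters are float32 before taking the half-precision all-reduce path.
from singa import tensor


def all_params_fp32(params):
    """True iff every tensor is 32-bit, i.e. safe for backward_and_update_half."""
    return all(p.dtype == tensor.float32 for p in params)
```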
2 changes: 0 additions & 2 deletions src/core/tensor/tensor.cc
@@ -72,7 +72,6 @@ Tensor::Tensor(const Tensor &in)
block_(in.block()),
shape_(in.shape_),
stride_(in.stride_) {
// printf("i am here in &in\n");
if (block_ != nullptr) block_->IncRefCount();
}

@@ -81,7 +80,6 @@ Tensor::Tensor(Tensor &&in)
device_(in.device_),
shape_(std::move(in.shape_)),
stride_(std::move(in.stride_)) {
// printf("i am here in &&in\n");
block_ = in.block_;
in.block_ = nullptr;
}
35 changes: 23 additions & 12 deletions src/io/communicator.cc
@@ -237,11 +237,11 @@ void Communicator::fusedSynch(vector<Tensor> &t, bool send) {
// memory copy to fusedBuff
for (size_t i = 0; i < t.size(); i++) {
if (t[0].data_type() == kFloat16) {
offsetPointer =
(void *)(static_cast<__half *>(fusedRecvBuff) + sendBuffOffset);
offsetPointer = (void *)(static_cast<__half *>(fusedSendBuff) +
sendBuffOffset);
} else {
offsetPointer =
(void *)(static_cast<float *>(fusedRecvBuff) + sendBuffOffset);
offsetPointer = (void *)(static_cast<float *>(fusedSendBuff) +
sendBuffOffset);
}
CUDA_CHECK(cudaMemcpyAsync(
(void *)offsetPointer,
@@ -264,8 +264,8 @@

device_->Exec(
[this](Context *ctx) mutable {
allReduce((int)sendBuffOffset, (void *)fusedSendBuff,
(void *)fusedRecvBuff, ncclType, ctx);
allReduce((int)sendBuffOffset, fusedSendBuff, fusedRecvBuff, ncclType,
ctx);
sendBuffOffset = 0;
},
prev_blocks_, blocks_, "Dist_s_fusedSynch_allreduce");
@@ -328,6 +328,10 @@ void Communicator::synch(Tensor &t) {
} // namespace singa

void Communicator::fusedSynchHalf(vector<Tensor> &t, bool send) {
CHECK_EQ(t[0].data_type(), kFloat32)
<< "This function is only available for input tensor precision 32 bit, "
"which are converted into 16 bits before transmit";

CHECK_GT(t.size(), 0);

generateBlocks(t);
@@ -375,8 +379,8 @@ void Communicator::fusedSynchHalf(vector<Tensor> &t, bool send) {
blocks_, blocks_, "Waiting");
device_->Exec(
[this](Context *ctx) mutable {
allReduce((int)sendBuffOffset, (void *)fusedSendBuffHalf,
(void *)fusedRecvBuffHalf, ncclHalf, ctx);
allReduce((int)sendBuffOffset, fusedSendBuffHalf, fusedRecvBuffHalf,
ncclHalf, ctx);
},
blocks_, blocks_, "Dist_s_fusedSynchHalf_allreduce");
device_->Exec(
@@ -410,6 +414,11 @@
}

void Communicator::synchHalf(Tensor &t) {
// tensor precision is 32 bit, convert to 16 bit before transmit
CHECK_EQ(t.data_type(), kFloat32)
<< "This function is only available for input tensor precision 32 bit, "
"which are converted into 16 bits before transmit";

generateBlocks(t);

if (halfInitialized == false) halfInit();
@@ -424,7 +433,8 @@
device_->Exec(
[this, t](Context *ctx) mutable {
float *addr = static_cast<float *>(t.block()->mutable_data());
cuda::float2half(t.Size(), addr, fusedSendBuffHalf, ctx->c1);
cuda::float2half(t.Size(), addr,
static_cast<__half *>(fusedSendBuffHalf), ctx->c1);
},
blocks_, blocks_, "Dist_c1_synchHalf_float2half");
device_->Exec(
@@ -436,8 +446,8 @@
blocks_, blocks_, "Waiting");
device_->Exec(
[this, t](Context *ctx) mutable {
allReduce(t.Size(), (void *)fusedSendBuffHalf,
(void *)fusedRecvBuffHalf, ncclHalf, ctx);
allReduce(t.Size(), fusedSendBuffHalf, fusedRecvBuffHalf, ncclHalf,
ctx);
},
blocks_, blocks_, "Dist_s_synchHalf_allreduce");
device_->Exec(
@@ -450,7 +460,8 @@
device_->Exec(
[this, t](Context *ctx) mutable {
float *addr = static_cast<float *>(t.block()->mutable_data());
cuda::half2float(t.Size(), fusedRecvBuffHalf, addr, ctx->c2);
cuda::half2float(t.Size(), static_cast<__half *>(fusedRecvBuffHalf),
addr, ctx->c2);
},
blocks_, blocks_, "Dist_c2_synchHalf_half2float");
}
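
The synchHalf edits above only tidy the casts around an unchanged data path: tensors stay float32 in the framework, are converted to float16 by cuda::float2half for the NCCL all-reduce, and converted back by cuda::half2float afterwards. A small NumPy illustration (not SINGA code) of what that round trip costs numerically:

```python
# Illustration only: the float32 -> float16 -> float32 round trip performed around
# the all-reduce. Array contents are made up for demonstration.
import numpy as np

grad = np.array([1.0e-4, 3.14159265, 6.0e4, 1.0e-8], dtype=np.float32)
on_the_wire = grad.astype(np.float16)       # what is actually transmitted (2 bytes per value)
recovered = on_the_wire.astype(np.float32)  # what the receiver continues computing with

print(on_the_wire)  # roughly 3 significant digits kept; 1e-08 underflows to 0.0
print(recovered)    # back in float32, but precision lost in transit is not recovered
```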
