Training mobilenet using batch_reduce #5

Open · wants to merge 5 commits into master
8 changes: 8 additions & 0 deletions README.md
@@ -17,6 +17,14 @@ We provide several video tutorials on YouTube.
- [Deploying LambdaML with DynamoDB](https://youtu.be/mWa3NpCcEDU)
- [Deploying LambdaML with Hybrid Parameter Server](https://youtu.be/gjmEV0RCaak)

### Development notes
- Two AWS Lambda functions are needed: a Trigger function and an Execution function. The Trigger function invokes the workers, which then run the code of the Execution function; to do this, the Trigger function needs the name of the Execution function (a minimal invocation sketch follows this list).
- The only capability the Trigger function needs is to invoke AWS Lambda functions. If it is not in a VPC, its execution role needs permission to invoke Lambda functions; if it is in a VPC, it additionally needs a VPC Interface Endpoint to reach AWS Lambda.
- Regardless of the storage choice (S3, ElastiCache, or DynamoDB), the Execution function needs access to S3 buckets. If it is not in a VPC, its execution role needs S3 permissions; if it is in a VPC, it additionally needs a VPC Gateway Endpoint to reach S3.
- Since ElastiCache is designed to be used only from inside a VPC, if Memcached or Redis is your storage choice, the Execution function must run inside a VPC.
- Because of the 50 MB size limit on a Lambda function's deployment package, you cannot zip the whole of LambdaML and upload it to the function. Instead, zip only the packages your project needs.
- For Memcached, you need to increase the maximum item size in the ElastiCache parameter group. See the ElastiCache video for details.
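
The sketch below illustrates the Trigger/Execution split: a minimal Trigger handler that fans the workers out with boto3. The function name, worker count, and payload fields are placeholders to fill in; a real payload carries the dataset, storage, and hyper-parameter settings used by the examples under `examples/lambda/`.

```python
import json

import boto3


def handler(event, context):
    # Name of the deployed Execution function that the workers will run.
    function_name = "Insert Lambda Function Name"
    n_workers = 10

    lambda_client = boto3.client('lambda')
    for worker_index in range(n_workers):
        payload = {
            'n_workers': n_workers,
            'worker_index': worker_index,
            # ... dataset, storage and hyper-parameter settings,
            # as in examples/lambda/elasticache/dl_ec_trigger.py
        }
        # InvocationType='Event' is asynchronous, so all workers start in parallel.
        lambda_client.invoke(FunctionName=function_name,
                             InvocationType='Event',
                             Payload=json.dumps(payload))
```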


## Dependencies
- awscli (version 1)
4 changes: 2 additions & 2 deletions examples/lambda/elasticache/dl_ec_trigger.py
@@ -6,14 +6,14 @@

def handler(event, context):

function_name = "lambda_core"
function_name = "Insert Lambda Function Name"

# dataset setting
dataset_name = 'cifar10'
data_bucket = "cifar10dataset"
n_features = 32 * 32
n_classes = 10
host = "127.0.0.1"
host = "Insert Node Endpoint Here"
port = 11211
tmp_bucket = "tmp-params"
merged_bucket = "merged-params"
278 changes: 278 additions & 0 deletions examples/lambda/elasticache/mobilenet_reduce_batch.py
@@ -0,0 +1,278 @@
import os
import time
import math

import numpy as np
import json

import torch
import torch.nn.functional as F

import boto3

from utils.constants import MLModel, Optimization, Synchronization
from storage import S3Storage, MemcachedStorage
from communicator import MemcachedCommunicator

from model import deep_models
from utils.metric import Accuracy, Average


def handler(event, context):
start_time = time.time()

# dataset setting
train_file = event['train_file']
test_file = event['test_file']
data_bucket = event['data_bucket']
n_features = event['n_features']
n_classes = event['n_classes']
n_workers = event['n_workers']
worker_index = event['worker_index']
host = event['host']
port = event['port']
tmp_bucket = event['tmp_bucket']
merged_bucket = event['merged_bucket']
cp_bucket = event['cp_bucket']

# training setting
model_name = event['model']
optim = event['optim']
sync_mode = event['sync_mode']
assert model_name.lower() in MLModel.Deep_Models
assert optim.lower() in [Optimization.Grad_Avg, Optimization.Model_Avg]
assert sync_mode.lower() in Synchronization.All

# hyper-parameter
learning_rate = event['lr']
batch_size = event['batch_size']
n_epochs = event['n_epochs']
start_epoch = event['start_epoch']
run_epochs = event['run_epochs']

function_name = event['function_name']

print('data bucket = {}'.format(data_bucket))
print("train file = {}".format(train_file))
print("test file = {}".format(test_file))
print('number of workers = {}'.format(n_workers))
print('worker index = {}'.format(worker_index))
print('model = {}'.format(model_name))
print('optimization = {}'.format(optim))
print('sync mode = {}'.format(sync_mode))
print('start epoch = {}'.format(start_epoch))
print('run epochs = {}'.format(run_epochs))
print('host = {}'.format(host))
print('port = {}'.format(port))

print("Run function {}, round: {}/{}, epoch: {}/{} to {}/{}"
.format(function_name, int(start_epoch/run_epochs) + 1, math.ceil(n_epochs / run_epochs),
start_epoch + 1, n_epochs, start_epoch + run_epochs, n_epochs))

s3_storage = S3Storage()
mem_storage = MemcachedStorage(host, port)
communicator = MemcachedCommunicator(mem_storage, tmp_bucket, merged_bucket, n_workers, worker_index)
if worker_index == 0:
mem_storage.clear()
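        # debug: quick set/get round-trip to check that Memcached is reachable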
mem_storage.client.set("key", 3)
print(mem_storage.client.get("key"))

# download file from s3
local_dir = "/tmp"
read_start = time.time()
s3_storage.download(data_bucket, train_file, os.path.join(local_dir, train_file))
s3_storage.download(data_bucket, test_file, os.path.join(local_dir, test_file))
print("download file from s3 cost {} s".format(time.time() - read_start))

train_set = torch.load(os.path.join(local_dir, train_file))
test_set = torch.load(os.path.join(local_dir, test_file))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False)

print("read data cost {} s".format(time.time() - read_start))

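    # fixed seed: every worker initializes the same model weights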
random_seed = 100
torch.manual_seed(random_seed)

device = 'cpu'
net = deep_models.get_models(model_name).to(device)

# Loss and Optimizer
# Softmax is internally computed.
# Set parameters to be updated.
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

# load checkpoint model if it is not the first round
if start_epoch != 0:
checked_file = 'checkpoint_{}.pt'.format(start_epoch - 1)
s3_storage.download(cp_bucket, checked_file, os.path.join(local_dir, checked_file))
checkpoint_model = torch.load(os.path.join(local_dir, checked_file))

net.load_state_dict(checkpoint_model['model_state_dict'])
optimizer.load_state_dict(checkpoint_model['optimizer_state_dict'])
print("load checkpoint model at epoch {}".format(start_epoch - 1))

for epoch in range(start_epoch, min(start_epoch + run_epochs, n_epochs)):

train_loss, train_acc = train_one_epoch(epoch, net, train_loader, optimizer, worker_index, n_workers,
communicator, sync_mode)
test_loss, test_acc = test(epoch, net, test_loader)

print('Epoch: {}/{},'.format(epoch + 1, n_epochs),
'train loss: {}'.format(train_loss),
'train acc: {},'.format(train_acc),
'test loss: {}'.format(test_loss),
'test acc: {}.'.format(test_acc), )

if worker_index == 0:
mem_storage.clear()

# training is not finished yet, invoke next round
if epoch < n_epochs - 1:
checkpoint_model = {
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': train_loss.average
}

checked_file = 'checkpoint_{}.pt'.format(epoch)

if worker_index == 0:
torch.save(checkpoint_model, os.path.join(local_dir, checked_file))
s3_storage.upload(cp_bucket, checked_file, os.path.join(local_dir, checked_file))
print("checkpoint model at epoch {} saved!".format(epoch))

print("Invoking the next round of functions. round: {}/{}, start epoch: {}, run epoch: {}"
.format(int((epoch + 1) / run_epochs) + 1, math.ceil(n_epochs / run_epochs),
epoch + 1, run_epochs))
lambda_client = boto3.client('lambda')
payload = {
'train_file': event['train_file'],
'test_file': event['test_file'],
'data_bucket': event['data_bucket'],
'n_features': event['n_features'],
'n_classes': event['n_classes'],
'n_workers': event['n_workers'],
'worker_index': event['worker_index'],
'host': event['host'],
'port': event['port'],
'tmp_bucket': event['tmp_bucket'],
'merged_bucket': event['merged_bucket'],
'cp_bucket': event['cp_bucket'],
'model': event['model'],
'optim': event['optim'],
'sync_mode': event['sync_mode'],
'lr': event['lr'],
'batch_size': event['batch_size'],
'n_epochs': event['n_epochs'],
'start_epoch': epoch + 1,
'run_epochs': event['run_epochs'],
'function_name': event['function_name']
}
lambda_client.invoke(FunctionName=function_name,
InvocationType='Event',
Payload=json.dumps(payload))

end_time = time.time()
print("Elapsed time = {} s".format(end_time - start_time))


# Train
def train_one_epoch(epoch, net, train_loader, optimizer, worker_index, n_workers,
communicator, sync_mode):
assert isinstance(communicator, MemcachedCommunicator)
net.train()

epoch_start = time.time()

epoch_cal_time = 0
epoch_comm_time = 0

train_acc = Accuracy()
train_loss = Average()

# record the architecture of the network
grads_shape = []
grads_length = []
for param in net.parameters():
grads_shape.append(param.data.numpy().shape)
l = 1
for i in param.data.numpy().shape:
l *= i
grads_length.append(l)
grad_dtype = param.data.numpy().dtype

for batch_idx, (inputs, targets) in enumerate(train_loader):
batch_start = time.time()
outputs = net(inputs)
loss = F.cross_entropy(outputs, targets)

optimizer.zero_grad()
loss.backward()

# flatten the gradient into a 1-d numpy array
grads_vector = np.empty(sum(grads_length), dtype=grad_dtype)
curr = 0
for i, param in enumerate(net.parameters()):
grads_vector[curr:curr + grads_length[i]] = np.ravel(param.grad.data.numpy())
curr += grads_length[i]

batch_cal_time = time.time() - batch_start
epoch_cal_time += batch_cal_time

batch_comm_start = time.time()
if sync_mode == "reduce":
merged_grads_vec = communicator.reduce_batch(grads_vector, epoch, batch_idx)
elif sync_mode == "reduce_scatter":
merged_grads_vec = communicator.reduce_scatter_batch(grads_vector, epoch, batch_idx)
else:
raise ValueError("Synchronization mode has to be either reduce or reduce_scatter")

# reconstruct the gradient from the 1-d vector
curr = 0
for i, param in enumerate(net.parameters()):
curr_grad = merged_grads_vec[curr:curr + grads_length[i]].reshape(grads_shape[i]) / n_workers
param.grad.data = torch.from_numpy(curr_grad)
curr += grads_length[i]

batch_comm_time = time.time() - batch_comm_start
print("one {} round cost {} s".format(sync_mode, batch_comm_time))
epoch_comm_time += batch_comm_time

train_acc.update(outputs, targets)
train_loss.update(loss.item(), inputs.size(0))

if batch_idx % 10 == 0:
print("Epoch: [{}], Batch: [{}], train loss: {}, train acc: {}, batch cost {} s, "
"cal cost {} s, comm cost {} s"
.format(epoch + 1, batch_idx + 1, train_loss, train_acc, time.time() - batch_start,
batch_cal_time, batch_comm_time))

if worker_index == 0:
delete_start = time.time()
communicator.delete_expired_batch(epoch, batch_idx)
epoch_comm_time += time.time() - delete_start

print("Epoch {} has {} batches, cost {} s, cal time = {} s, comm time = {} s"
.format(epoch + 1, batch_idx, time.time() - epoch_start, epoch_cal_time, epoch_comm_time))

return train_loss, train_acc


def test(epoch, net, test_loader):
# global best_acc
net.eval()
test_loss = Average()
test_acc = Accuracy()

with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
outputs = net(inputs)

loss = F.cross_entropy(outputs, targets)

test_loss.update(loss.item(), inputs.size(0))
test_acc.update(outputs, targets)

return test_loss, test_acc
1 change: 1 addition & 0 deletions examples/lambda/hybrid/dl_hybrid.py
@@ -179,6 +179,7 @@ def handler(event, context):
param_grad = np.zeros((1))
for param in model.parameters():
# print("shape of layer = {}".format(param.data.numpy().flatten().shape))
# NOTE(milos) this is very slow since np.concatenate creates a new array every time
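        # (An alternative, used in mobilenet_reduce_batch.py, is to preallocate a single
        # np.empty buffer of the total gradient length and fill it slice by slice.)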
param_grad = np.concatenate((param_grad, param.data.numpy().flatten()))
param_grad = np.delete(param_grad, 0)
#print("model_length = {}".format(param_grad.shape))
3 changes: 2 additions & 1 deletion examples/lambda/s3/dl_s3.py
@@ -135,7 +135,7 @@ def handler(event, context):

if worker_index == 0:
torch.save(checkpoint_model, os.path.join(local_dir, checked_file))
storage.upload_file(cp_bucket, checked_file, os.path.join(local_dir, checked_file))
storage.upload(cp_bucket, checked_file, os.path.join(local_dir, checked_file))
print("checkpoint model at epoch {} saved!".format(epoch))

print("Invoking the next round of functions. round: {}/{}, start epoch: {}, run epoch: {}"
@@ -206,6 +206,7 @@ def train_one_epoch(epoch, net, train_loader, optimizer, worker_index,
if sync_mode == "reduce":
merged_grads = communicator.reduce_batch_nn(pickle.dumps(grads), postfix)
elif sync_mode == "reduce_scatter":
# NOTE(milos) I think this should be reduce_scatter_batch_nn, but this is not in s3_comm
merged_grads = communicator.reduce_batch_nn(pickle.dumps(grads), postfix)

for layer_index, param in enumerate(net.parameters()):