Support to put a torch module into all vineyard instances dispersedly (#1891)

Fixes #1884

Signed-off-by: Ye Cao <[email protected]>
dashanji authored Jun 5, 2024
1 parent 7439561 commit 3e9ff47
Showing 4 changed files with 144 additions and 9 deletions.
4 changes: 4 additions & 0 deletions python/client.cc
@@ -832,6 +832,10 @@ void bind_client(py::module& mod) {
[](RPCClient* self,
const std::vector<std::shared_ptr<RemoteBlobWriter>>&
remote_blob_builders) -> std::vector<ObjectMeta> {
// Release GIL to avoid blocking the other threads
// See also
// https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil
py::gil_scoped_release release;
std::vector<ObjectMeta> blob_metas;
throw_on_error(
self->CreateRemoteBlobs(remote_blob_builders, blob_metas));
24 changes: 24 additions & 0 deletions python/vineyard/contrib/ml/README.md
@@ -87,6 +87,30 @@ with torch_context():
model.load_state_dict(state_dict, assign=True)
```

By default, compression is enabled for the vineyard client. Since compression is not always efficient for torch modules, you can disable it as follows:

```python
from vineyard.contrib.ml.torch import torch_context
# pass the client to torch_context to disable compression
with torch_context(client):
object_id = client.put(model)

# the same applies to get operations
with torch_context(client):
state_dict = client.get(object_id)
```
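
The `compression` property on the client reflects the current setting, so you can verify the effect inside the context. A minimal sketch, assuming a local vineyard instance is reachable via `vineyard.connect()`:

```python
import vineyard

from vineyard.contrib.ml.torch import torch_context

client = vineyard.connect()
with torch_context(client):
    # torch_context(client) turns compression off within the context
    assert not client.compression
```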

In addition, if you want to spread the torch modules across all vineyard workers to aggregate the network bandwidth of the whole cluster, you can enable the spread option as follows:

```python
from vineyard.contrib.ml.torch import torch_context
with torch_context(client, spread=True):
object_id = client.put(model)

with torch_context(client):
state_dict = client.get(object_id)
```
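
Note that `spread` takes effect only when a client is passed to `torch_context`. When enabled, the builder splits the module's tensors into per-instance chunks of roughly equal byte size, uploads each chunk to a different vineyard instance, and marks the resulting metadata as a global object, so a later `get` can fetch the chunks from the instances that hold them.

The chunking is a single greedy pass over the tensors. A standalone sketch of the idea (`split_into_chunks` is a hypothetical name mirroring the internal helper):

```python
from math import ceil

def split_into_chunks(sizes, chunk_count):
    # Start a new chunk once the running total reaches the
    # average bytes-per-chunk target.
    average = ceil(sum(sizes) / chunk_count)
    chunks, current, total = [], [], 0
    for size in sizes:
        if total >= average and current:
            chunks.append(current)
            current, total = [], 0
        current.append(size)
        total += size
    if current:
        chunks.append(current)
    return chunks

print(split_into_chunks([4, 4, 4, 4], 2))  # [[4, 4], [4, 4]]
```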

Reference and Implementation
----------------------------

103 changes: 95 additions & 8 deletions python/vineyard/contrib/ml/torch.py
@@ -18,9 +18,11 @@

import contextlib
import ctypes
import time
import warnings
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
from math import ceil
from typing import Iterable
from typing import Iterator
from typing import List
@@ -33,6 +35,7 @@
import lazy_import

import vineyard
from vineyard import envvars
from vineyard._C import NotEnoughMemoryException
from vineyard._C import ObjectID
from vineyard._C import ObjectMeta
@@ -268,6 +271,83 @@ def datapipe(
return torchdata.datapipes.iter.IterableWrapper(dataset)


def distribute_tensors(client, tensor_values):
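    """Spread ``tensor_values`` across all vineyard instances: split them
    into per-instance chunks and upload each chunk through a client
    connected to that instance."""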
cluster_info = client.meta
instance_ids = cluster_info.keys()
chunk_size = len(cluster_info)

def split_tensors_into_chunks(tensor_values, chunk_size):
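        # Greedy split: pack the tensors into `chunk_size` chunks of
        # roughly equal total byte size.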
average_size = ceil(
sum(t.numel() * t.element_size() for t in tensor_values) / chunk_size
)
current_size = 0
tensor_chunks = []
current_chunk = []
for t in tensor_values:
if current_size >= average_size and current_chunk:
tensor_chunks.append(current_chunk)
current_size = 0
current_chunk = []
current_chunk.append(t)
current_size += t.numel() * t.element_size()
if current_chunk:
tensor_chunks.append(current_chunk)
return tensor_chunks

tensor_chunks = split_tensors_into_chunks(tensor_values, chunk_size)

def thread_put_torch_tensors(
cluster_info, instance_id, tensor_chunk, client, output_objects
):
compression = client.compression
connected_instance_id = (
client.instance_id if client.is_ipc else client.remote_instance_id
)
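        # Reuse the given client when it already points at the target
        # instance; otherwise dial that instance's RPC endpoint directly.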
rpc_client = None
if connected_instance_id != instance_id:
instance_status = cluster_info.get(instance_id)
if instance_status is None or instance_status['rpc_endpoint'] is None:
raise RuntimeError(
"The rpc endpoint of the vineyard instance "
f"{instance_id} is not available."
)

host, port = cluster_info[instance_id]['rpc_endpoint'].split(':')
try:
with envvars('VINEYARD_RPC_SKIP_RETRY', '1'):
rpc_client = vineyard.connect(host=host, port=int(port))
rpc_client.compression = compression
            except Exception as exc:
                raise RuntimeError(
                    f"Failed to connect to the vineyard instance {instance_id} "
                    f"at {host}:{port}."
                ) from exc
used_client = rpc_client if rpc_client else client
result = put_torch_tensors(used_client, tensor_chunk)
output_objects[instance_id] = result

tensor_objects_dict = {}
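    # Upload all chunks concurrently, one task per target instance.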
with ThreadPoolExecutor() as executor:
futures = []
for instance_id, tensor_chunk in zip(instance_ids, tensor_chunks):
future = executor.submit(
thread_put_torch_tensors,
cluster_info,
instance_id,
tensor_chunk,
client,
tensor_objects_dict,
)
futures.append(future)
for future in as_completed(futures):
future.result()

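    # Collect results in instance order so the flattened list of tensor
    # objects matches the original tensor order.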
tensor_objects = []
for instance_id in instance_ids:
tensor_objects.extend(tensor_objects_dict[instance_id])
return tensor_objects


def put_torch_tensors(client, tensors) -> List[Union[ObjectID, ObjectMeta]]:
pointers, sizes = [], []
tensors = [tensor.contiguous() for tensor in tensors]
@@ -359,8 +439,11 @@ def assign(state_dict, key_prefix, tensors):
go(value, 'tensor', tensors)

tensor_keys, tensor_values = list(tensors.keys()), list(tensors.values())

    if client.spread:
        tensor_objects = distribute_tensors(client, tensor_values)
    else:
        tensor_objects = put_torch_tensors(client, tensor_values)
tensors = dict(zip(tensor_keys, tensor_objects))
new_value = assign(value, 'tensor', tensors)

@@ -369,7 +452,10 @@ def assign(state_dict, key_prefix, tensors):
meta['state_dict'] = to_json(new_value)
for key, tensor in tensors.items():
meta.add_member(key, tensor)
    if client.spread:
        meta.set_global(True)
    o = client.create_metadata(meta)
    return o


def torch_module_resolver(obj, resolver, **kw):
@@ -420,13 +506,14 @@ def register_torch_types(builder_ctx, resolver_ctx):


@contextlib.contextmanager
def torch_context(client: Client = None, spread=False):
if client is not None:
with client.with_compression(False):
            with client.with_spread(spread):
                with context() as (builder_ctx, resolver_ctx):
                    with contextlib.suppress(ImportError):
                        register_torch_types(builder_ctx, resolver_ctx)
                    yield builder_ctx, resolver_ctx
else:
with context() as (builder_ctx, resolver_ctx):
with contextlib.suppress(ImportError):
22 changes: 21 additions & 1 deletion python/vineyard/core/client.py
@@ -274,6 +274,8 @@ def __init__(
except VineyardException:
continue

self._spread = False
self._compression = True
if self._ipc_client is None and self._rpc_client is None:
raise ConnectionError(
"Failed to connect to vineyard via both IPC and RPC connection. "
@@ -287,12 +289,22 @@ def compression(self) -> bool:
        '''Whether compression is enabled for the underlying RPC client.'''
if self._rpc_client:
return self._rpc_client.compression
        return self._compression

@compression.setter
def compression(self, value: bool = True):
if self._rpc_client:
self._rpc_client.compression = value
self._compression = value
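        # Remember the value locally as well; it is applied to RPC clients
        # created later (e.g., when spreading tensors across instances).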

@property
def spread(self) -> bool:
        '''Whether spread is enabled for subsequent put operations.'''
return self._spread

@spread.setter
def spread(self, value: bool = False):
self._spread = value

@property
def ipc_client(self) -> IPCClient:
@@ -789,5 +801,13 @@ def with_compression(self, enabled: bool = True):
yield
self.compression = compression

@contextlib.contextmanager
def with_spread(self, enabled: bool = True):
"""Enable spread for the following put operations."""
tmp_spread = self._spread
self.spread = enabled
yield
self.spread = tmp_spread
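    # Note: `torch_context(client, spread=True)` relies on this context
    # manager to scope the spread flag to the enclosed put operations.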


__all__ = ['Client']
