Network simulator #2

Open · wants to merge 2 commits into master
6 changes: 5 additions & 1 deletion slower/client/proxy/ray_client_proxy.py
@@ -13,6 +13,7 @@
from slower.server.server_model.proxy.ray_private_server_model_proxy import (
    RayPrivateServerModelProxy
)
from slower.simulation.utlis.network_simulator import NetworkSimulator


class RayClientProxy(RayActorClientProxy):
@@ -21,10 +22,12 @@ class RayClientProxy(RayActorClientProxy):
    def __init__(
        self,
        server_model_manager: ServerModelManager,
        network_simulator: Optional[NetworkSimulator] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.server_model_manager = server_model_manager
        self.network_simulator = network_simulator

    def fit(
        self,
@@ -36,7 +39,8 @@ def fit(
        def fit(client: Client) -> common.FitRes:
            server_model_proxy = RayPrivateServerModelProxy(
                server_model,
                request_queue_in_separate_thread=True
                request_queue_in_separate_thread=True,
                network_simulator=self.network_simulator
            )
            # also return the server_model_proxy, so that we can store it outside the
            # ray actor to the shared ray memory
slower/server/server_model/proxy/ray_private_server_model_proxy.py
@@ -1,10 +1,12 @@
from typing import Optional
from queue import SimpleQueue
from typing import Iterator
import threading

from slower.server.server_model.server_model import ServerModel
from slower.common import ControlCode, BatchData
from slower.server.server_model.proxy.server_model_proxy import ServerModelProxy
from slower.simulation.utlis.network_simulator import NetworkSimulator


class RayPrivateServerModelProxy(ServerModelProxy):
@@ -13,21 +15,24 @@ class RayPrivateServerModelProxy(ServerModelProxy):
    def __init__(
        self,
        server_model: ServerModel,
        request_queue_in_separate_thread: bool = True
        request_queue_in_separate_thread: bool = True,
        network_simulator: Optional[NetworkSimulator] = None
    ):
        super().__init__()
        self.server_model = server_model
        self.request_queue = None
        self.server_request_thread = None
        self.request_queue_in_separate_thread = request_queue_in_separate_thread
        self.network_simulator = network_simulator

    def _blocking_request(self, method, batch_data, timeout):
        _ = (timeout,)
        if self.network_simulator is not None:
            self.network_simulator.simulate_network(batch_data=batch_data)
        res = getattr(self.server_model, method)(batch_data=batch_data)
        return res

    def _streaming_request(self, method, batch_data):
        if self.request_queue is not None:
            self.request_queue.put((method, batch_data))
        else:
@@ -43,6 +48,8 @@ def _iterate(server_proxy: ServerModelProxy, iterator: Iterator):
            for method, batch in iterator:
                if batch.control_code == ControlCode.DO_CLOSE_STREAM:
                    break
                if server_proxy.network_simulator is not None:
                    server_proxy.network_simulator.simulate_network(batch_data=batch)
                server_proxy._blocking_request(method, batch, None)

        self.request_queue = SimpleQueue()
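For context, a minimal sketch (not part of this PR) of what the change to RayPrivateServerModelProxy amounts to once a simulator is attached; the numeric values are illustrative assumptions, and no specific server-model method name is implied.

from slower.simulation.utlis.network_simulator import NetworkSimulator

# server_model is an existing ServerModel instance, as in the diff above
proxy = RayPrivateServerModelProxy(
    server_model,
    request_queue_in_separate_thread=True,
    network_simulator=NetworkSimulator(
        avg_latency_ms=50,
        latency_variance_ms=10,
        avg_bandwidth_mbps=20,
        bandwidth_variance_mbps=5,
    ),
)
# Every blocking request now sleeps for a simulated latency plus a transfer time
# proportional to the batch size before the server model method is invoked.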
19 changes: 17 additions & 2 deletions slower/simulation/app.py
@@ -22,6 +22,7 @@


from slower.simulation.ray_transport.split_learning_actor_pool import SplitLearningVirtualClientPool
from slower.simulation.utlis.network_simulator import NetworkSimulator
from slower.client.typing import ClientFn
from slower.client.proxy.ray_client_proxy import RayClientProxy
from slower.server.server import Server
@@ -45,6 +46,7 @@ def start_simulation(
    actor_type: Type[VirtualClientEngineActor] = DefaultActor,
    actor_kwargs: Optional[Dict[str, Any]] = None,
    actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT",
    network_simulator_kwargs: Optional[Dict[str, int]] = None,
) -> History:
    """Start a Ray-based Flower simulation server.

@@ -115,7 +117,9 @@
        compute nodes (e.g. via NodeAffinitySchedulingStrategy). Please note this
        is an advanced feature. For all details, please refer to the Ray documentation:
        https://docs.ray.io/en/latest/ray-core/scheduling/index.html

    network_simulator_kwargs: Optional[Dict[str, int]] (default: None)
        Optional dictionary with the keyword arguments used to construct the
        NetworkSimulator (average latency, latency variance, average bandwidth,
        and bandwidth variance). If provided, network simulation is enabled for
        all client-server communication.

    Returns
    -------
    hist : flwr.server.history.History
@@ -257,13 +261,24 @@ def update_resources(f_stop: threading.Event) -> None:
        pool.num_actors,
    )

    if network_simulator_kwargs:
        network_simulator = NetworkSimulator(**network_simulator_kwargs)
        log(
            INFO,
            "Flower VCE: Network simulator enabled with %s",
            network_simulator_kwargs,
        )
    else:
        network_simulator = None

    # Register one RayClientProxy object for each client with the ClientManager
    for cid in cids:
        client_proxy = RayClientProxy(
            client_fn=client_fn,
            cid=cid,
            actor_pool=pool,
            server_model_manager=initialized_server.server_model_manager
            server_model_manager=initialized_server.server_model_manager,
            network_simulator=network_simulator
        )
        initialized_server.client_manager().register(client=client_proxy)

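As a usage illustration for the new network_simulator_kwargs argument documented above, a minimal sketch follows; client_fn and num_clients stand in for the usual start_simulation arguments, and the numeric values are illustrative assumptions rather than defaults.

from slower.simulation.app import start_simulation

hist = start_simulation(
    client_fn=client_fn,      # your ClientFn, exactly as without network simulation
    num_clients=10,           # placeholder for the usual simulation settings
    network_simulator_kwargs={
        "avg_latency_ms": 50,
        "latency_variance_ms": 10,
        "avg_bandwidth_mbps": 20,
        "bandwidth_variance_mbps": 5,
    },
)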
Empty file.
51 changes: 51 additions & 0 deletions slower/simulation/utlis/network_simulator.py
@@ -0,0 +1,51 @@
import time
import random
import sys
from slower.common import BatchData


class NetworkSimulator:
    def __init__(self, avg_latency_ms, latency_variance_ms, avg_bandwidth_mbps, bandwidth_variance_mbps):
        self.avg_latency_ms = avg_latency_ms
        self.latency_variance_ms = latency_variance_ms
        self.avg_bandwidth_mbps = avg_bandwidth_mbps
        self.bandwidth_variance_mbps = bandwidth_variance_mbps

    def simulate_latency(self):
        # random.gauss expects a standard deviation, so the *_variance_ms value is used as sigma
        latency = random.gauss(self.avg_latency_ms, self.latency_variance_ms)
        time.sleep(max(0, latency) / 1000)

    def simulate_bandwidth(self, data_size):
        # draw an effective bandwidth and sleep for the time needed to transfer data_size bytes
        bandwidth_mbps = max(0, random.gauss(self.avg_bandwidth_mbps, self.bandwidth_variance_mbps))
        if bandwidth_mbps > 0:
            transfer_time = data_size / (bandwidth_mbps * 1024 * 1024 / 8)
            time.sleep(transfer_time)

    def simulate_network(self, batch_data: BatchData):
        data_size = get_data_size(batch_data)
        self.simulate_latency()
        self.simulate_bandwidth(data_size)


def get_data_size(batch_data) -> int:
    def get_size(obj, seen=None):
        """Recursively find the size of an object and everything it references."""
        size = sys.getsizeof(obj)
        if seen is None:
            seen = set()
        obj_id = id(obj)
        if obj_id in seen:
            return 0
        # Important: mark as seen *before* entering recursion to gracefully handle
        # self-referential objects
        seen.add(obj_id)
        if isinstance(obj, dict):
            size += sum([get_size(v, seen) for v in obj.values()])
            size += sum([get_size(k, seen) for k in obj.keys()])
        elif hasattr(obj, '__dict__'):
            size += get_size(vars(obj), seen)
        elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
            size += sum([get_size(i, seen) for i in obj])
        return size

    return get_size(batch_data)
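To get a feel for the delays simulate_network introduces, here is a small worked example; the parameter values and the batch object are illustrative assumptions, not defaults shipped with this PR.

sim = NetworkSimulator(
    avg_latency_ms=50,
    latency_variance_ms=10,
    avg_bandwidth_mbps=20,
    bandwidth_variance_mbps=5,
)
# Ignoring the Gaussian jitter: for a batch whose get_data_size() is ~1 MiB,
# transfer_time = 1_048_576 / (20 * 1024 * 1024 / 8) = 0.4 s, plus 0.05 s of
# latency, so simulate_network() sleeps for roughly 0.45 s in total.
sim.simulate_network(batch_data=some_batch)  # some_batch: a placeholder BatchData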