Add statistics collecting algo
lsago committed Sep 4, 2023
1 parent 2d349ee commit 32ebbd6
Showing 5 changed files with 173 additions and 1 deletion.
42 changes: 42 additions & 0 deletions algos_server/stats.py
@@ -0,0 +1,42 @@
import json
import logging
from typing import Any, Dict

from .algo import FederatedServerAlgo

logger = logging.getLogger(__name__)
logger.setLevel(level=logging.DEBUG)

class StatsFederatedServerAlgo(FederatedServerAlgo):

def __init__(self, params: Dict[str, Any]):
super().__init__(name="stats", params=params, model_suffix="json")

def initialize(self):
# parameters to be passed to the worker
self.params["cutoff"] = self.params.get("cutoff", 730)
self.params["delta"] = self.params.get("delta", 30)
logger.debug("Parameters to be passed to the worker have been initialized: %s", self.params)

        # We definitely don't need an initial model for this algorithm, but the
        # server makes workers start training by sharing an initial model, so
        # we create an empty file here...
        # Hopefully it's clear by now that this whole thing is just a PoC.
        # This is not an actual proper FL implementation with IDS/TSG!
with open(self.model_aggregated_path, "w+") as f:
f.write("")

def aggregate(self, current_round):
        # no real aggregation, just concatenate all partial results
        aggregated_results = []
        for file in self.round_partial_models[current_round]:
            with open(file) as f:
                aggregated_results.append(json.load(f))
logger.info("Concatenated %s partial results", len(aggregated_results))
logger.info("Saving final compiled stats")
with open(self.model_aggregated_path, "w+") as f:
json.dump(aggregated_results, f)

# we only require one round for this algorithm, so aggregation happens
# only once and we can signal we are done
return True
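
Since aggregate() does no real aggregation, the "final model" is simply the per-site JSON results collected into a list. A minimal sketch of that step in isolation, with hypothetical file names and the FederatedServerAlgo plumbing elided:

# Minimal sketch of the concatenation step above, outside the
# FederatedServerAlgo plumbing. File names here are hypothetical.
import json

partial_models = ["site_a.json", "site_b.json"]  # per-worker result files

aggregated = []
for path in partial_models:
    with open(path) as f:
        aggregated.append(json.load(f))  # one dict of stats per site

# The aggregated "model" is just the list of per-site stats dicts.
with open("stats_aggregated.json", "w") as f:
    json.dump(aggregated, f)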
2 changes: 1 addition & 1 deletion algos_worker/nn.py
@@ -17,7 +17,7 @@

class NNFederatedWorkerAlgo(FederatedWorkerAlgo):
def __init__(self, params: dict = {}):
-        self.data: Optional[ndarray] = None #
+        self.data: Optional[ndarray] = None
self.labels: Optional[ndarray] = None
self.model: Optional[keras.Model] = None
self.unique_labels: Optional[int] = None
121 changes: 121 additions & 0 deletions algos_worker/stats.py
@@ -0,0 +1,121 @@
import json
import logging
from typing import Optional

import numpy as np
import pandas as pd

from .algo import FederatedWorkerAlgo

logger = logging.getLogger(__name__)
logger.setLevel(level=logging.DEBUG)

class StatsFederatedWorkerAlgo(FederatedWorkerAlgo):
def __init__(self, params: dict = {}):
        self.data: Optional[pd.DataFrame] = None
super().__init__(name="stats", params=params, model_suffix="json")

def initialize(self):
self.cutoff = self.params.get('cutoff')
self.delta = self.params.get('delta')

# will be used to store results during "training"
self.results = {'logs': ''}

logger.info("Reading data from %s.csv", self.key)
# "training" will read from here (self.data)
self.data = pd.read_csv(f"{self.key}.csv")
logger.info("Data shape: %s", self.data.shape)
logger.info("Initialized, ready for training")

def survival_rate(self, df: pd.DataFrame, cutoff: int, delta: int) -> list:
""" Compute survival rate at certain time points after diagnosis
Parameters
----------
df
DataFrame with TNM data
cutoff
Maximum number of days for the survival rate profile
delta
Number of days between the time points in the profile
Returns
-------
survival_rates
Survival rate profile
"""

        # Get survival days; here we assume the date of last follow-up stands
        # in for the death date
        df['date_of_diagnosis'] = pd.to_datetime(df['date_of_diagnosis'])
        df['date_of_fu'] = pd.to_datetime(df['date_of_fu'])
        df['survival_days'] = (df['date_of_fu'] - df['date_of_diagnosis']).dt.days

# Get survival rate after a certain number of days
times = list(range(0, cutoff, delta))
all_alive = len(df[df['vital_status'] == 'alive'])
all_dead = len(df[df['vital_status'] == 'dead'])
survival_rates = []
for time in times:
dead = len(
df[(df['survival_days'] <= time) & (df['vital_status'] == 'dead')]
)
alive = (all_dead - dead) + all_alive
survival_rates.append(alive / len(df))

return survival_rates


def train(self, callback=None):
# statistics adapted from: https://github.com/MaastrichtU-CDS/v6-healthai-dashboard-py
logger.info('Getting centre name')
column = 'centre'
if column in self.data.columns:
centre = self.data[column].unique()[0]
self.results['organisation'] = centre
else:
self.results['organisation'] = None
self.results['logs'] += f'Column {column} not found in the data\n'

logger.info('Counting number of unique ids')
column = 'id'
if column in self.data.columns:
nids = self.data[column].nunique()
self.results['nids'] = nids
else:
self.results['logs'] += f'Column {column} not found in the data\n'

logger.info('Counting number of unique ids per stage')
column = 'stage'
if column in self.data.columns:
self.data[column] = self.data[column].str.upper()
stages = self.data.groupby([column])['id'].nunique().reset_index()
self.results[column] = stages.to_dict()
else:
            self.results['logs'] += f'Column {column} not found in the data\n'

logger.info('Counting number of unique ids per vital status')
column = 'vital_status'
if column in self.data.columns:
vital_status = self.data.groupby([column])['id'].nunique().reset_index()
self.results[column] = vital_status.to_dict()
else:
            self.results['logs'] += f'Column {column} not found in the data\n'

logger.info('Getting survival rates')
columns = ['date_of_diagnosis', 'date_of_fu']
if (columns[0] in self.data.columns) and (columns[1] in self.data.columns):
survival = self.survival_rate(self.data, self.cutoff, self.delta)
self.results['survival'] = survival
else:
            self.results['logs'] += \
                f'Columns {columns[0]} and/or {columns[1]} not found in the data\n'

# Save results
logger.info("Saving local statistics results")
with open(self.model_path, "w+") as f:
json.dump(self.results, f)
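
The profile above evaluates, at each time point t in range(0, cutoff, delta), the fraction of patients still alive: (all_dead - dead_by_t + all_alive) / len(df). A small worked example of that logic, using invented data:

# Worked example of the survival-rate logic above; all values are
# invented for illustration.
import pandas as pd

df = pd.DataFrame({
    "date_of_diagnosis": ["2020-01-01"] * 4,
    "date_of_fu": ["2020-03-01", "2020-07-01", "2021-01-01", "2021-01-01"],
    "vital_status": ["dead", "dead", "alive", "alive"],
})
df["survival_days"] = (
    pd.to_datetime(df["date_of_fu"]) - pd.to_datetime(df["date_of_diagnosis"])
).dt.days  # 60, 182, 366, 366

all_alive = len(df[df["vital_status"] == "alive"])  # 2
all_dead = len(df[df["vital_status"] == "dead"])    # 2
for t in range(0, 730, 365):
    dead = len(df[(df["survival_days"] <= t) & (df["vital_status"] == "dead")])
    alive = (all_dead - dead) + all_alive
    print(t, alive / len(df))
# 0 1.0    (nobody has died by day 0)
# 365 0.5  (both deaths fall within the first year)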
4 changes: 4 additions & 0 deletions federated_learning.py
Expand Up @@ -13,6 +13,7 @@
from algos_worker.algo import FederatedWorkerAlgo
from algos_worker.kmeans import KmeansFederatedWorkerAlgo
from algos_worker.nn import NNFederatedWorkerAlgo
from algos_worker.stats import StatsFederatedWorkerAlgo
from dataset_handler import DataSetHandler

debug_sleep_time = int(os.environ.get("DEBUG_SLEEP_TIME", "10"))
@@ -61,6 +62,9 @@ async def initialize(self, request: Request) -> web.Response:
if self._params.get("algo") == "kmeans":
logger.info("Initializing Kmeans Federated Learning")
self.algo = KmeansFederatedWorkerAlgo(params=self._params)
elif self._params.get("algo") == "stats":
logger.info("Initializing Stats collection")
self.algo = StatsFederatedWorkerAlgo(params=self._params)
else:
logger.info("Initializing NN Federated Learning")
self.algo = NNFederatedWorkerAlgo(params=self._params)
5 changes: 5 additions & 0 deletions federated_learning_server.py
Expand Up @@ -11,6 +11,7 @@

from algos_server.algo import FederatedServerAlgo
from algos_server.kmeans import KmeansFederatedServerAlgo
from algos_server.stats import StatsFederatedServerAlgo
from algos_server.nn import NNFederatedServerAlgo


@@ -96,7 +97,11 @@ async def initialize(self, request: Request) -> web.Response:
self._state = FederatedLearningState.INITIALIZED

if self._params.get("algo", None) == "kmeans":
logger.info("Initializing Kmeans Federated Server Algo")
self.algo = KmeansFederatedServerAlgo(params=self._params)
elif self._params.get("algo", None) == "stats":
logger.info("Initializing Stats collection")
self.algo = StatsFederatedServerAlgo(params=self._params)
elif self._params.get("algo", None) == "nn" or self._params.get("algo", None) is None:
# researcher-gui does not send 'algo', but it's meant for NN
logger.info("Initializing NN Federated Server Algo")
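
Both dispatchers key off params['algo'], so a stats run would presumably be initialized with a payload along these lines (parameter names taken from the defaults in algos_server/stats.py; the exact request shape is not shown in this diff and is assumed here):

# Hypothetical initialization params selecting the stats algorithm;
# "cutoff" and "delta" fall back to 730 and 30 if omitted (see
# algos_server/stats.py).
params = {
    "algo": "stats",
    "cutoff": 730,  # days covered by the survival profile
    "delta": 30,    # days between time points
}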
