Support validation set and FedEM for MF datasets #310
base: master
Conversation
Please see the inline comments
    """
    Ensemble evaluation for matrix factorization model
    """
    cur_data = ctx.cur_mode
Please ensure that the usage of cur_mode is correct here. cur_mode is the type of our routine, chosen from "train"/"test"/"val"/"finetune"; cur_split is the chosen data split. Besides, do we still need to name the variables with cur_data, since they are all removed at the end of the routine?
fixed, here we should use cur_split
    # set the eval_metrics
    if ctx.num_samples == 0:
        results = {
            f"{cur_data}_avg_loss": ctx.get(
The metric calculator uses cur_split instead; please check whether it's correct to use cur_data (actually cur_mode) here.
fixed as replied above
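The distinction discussed above can be sketched with a plain dict standing in for the trainer context (illustrative only, not the actual FederatedScope Context class): cur_mode names the routine, while cur_split names the data split that should key the metrics.

```python
# Illustrative sketch with a plain dict standing in for the trainer
# context: cur_mode names the routine, cur_split names the data split.
ctx = {"cur_mode": "finetune",        # routine: train/test/val/finetune
       "cur_split": "val",            # data split actually evaluated
       "loss_batch_total_val": 0.0}

cur_split = ctx["cur_split"]          # key metrics by split, not by mode
results = {f"{cur_split}_avg_loss": ctx.get(f"loss_batch_total_{cur_split}")}
print(results)  # {'val_avg_loss': 0.0}
```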
    }
    else:
        results = {
            f"{ctx.cur_mode}_avg_loss": ctx.get(
It's a little confusing to use ctx.cur_mode here, since we use cur_data in line 236.
fixed accordingly
    else:
        self._split_n_clients_rating_vmf(ratings, num_client, split)

    def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int,
Since the classes HMFDataset and VMFDataset also have the function _split_n_clients_rating for HMF and VMF respectively, maybe we don't need the functions _split_n_clients_rating_hmf and _split_n_clients_rating_vmf here?
deleted it in the new pr
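The simplification suggested above can be sketched as plain method overriding (placeholder bodies, class names from the thread): if HMFDataset and VMFDataset each override _split_n_clients_rating, the base class can call it directly and the orientation-specific helpers become redundant.

```python
# Sketch of the suggested refactor; the bodies are placeholders, not
# the real splitting logic.
class MFDataset:
    def build(self, ratings, num_client, split):
        # dispatch to whichever subclass implementation is present
        return self._split_n_clients_rating(ratings, num_client, split)

class VMFDataset(MFDataset):
    def _split_n_clients_rating(self, ratings, num_client, split):
        return "vertical split"

class HMFDataset(MFDataset):
    def _split_n_clients_rating(self, ratings, num_client, split):
        return "horizontal split"

print(VMFDataset().build(None, 5, [0.8, 0.1, 0.1]))  # vertical split
```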
    }
    self.data = data

    def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int,
The same as above
deleted it in the new pr
@@ -45,7 +45,8 @@ def forward(self, indices, ratings):
         device=pred.device,
         dtype=torch.float32).to_dense()

-        return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
+        return mask * pred, label, torch.Tensor(
Why do we convert it to a Tensor, and do we need to consider the device of the Tensor?
Here the conversion is for FLOPs counting. The device is not important since, after counting the FLOPs, the tensor will be discarded.
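A minimal sketch of the point in this reply, assuming a profiler that only traces tensor-producing ops (the element_ratio helper is hypothetical, not a FederatedScope function):

```python
import numpy as np
import torch

# Hypothetical sketch: returning the element/rating ratio as a Tensor
# instead of a plain float lets a tensor-tracing FLOPs profiler see the
# value; the tensor is discarded after counting, so its device is moot.
def element_ratio(pred: torch.Tensor, num_ratings: int) -> torch.Tensor:
    return torch.Tensor([float(np.prod(pred.size())) / num_ratings])

pred = torch.zeros(4, 5)          # 20 elements
ratio = element_ratio(pred, 10)   # 20 / 10
print(float(ratio))               # 2.0
```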
    if ctx.get("num_samples") == 0:
        results = {
            f"{ctx.cur_mode}_avg_loss": ctx.get(
                "loss_batch_total_{}".format(ctx.cur_mode)),
It's a little confusing that in line 53 we use loss_batch_total_{ctx.cur_mode}, while in line 58 it is ctx.loss_batch_total.
changed into loss_batch_total_{ctx.cur_mode} in line 58
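The agreed-upon convention can be sketched with a plain dict standing in for ctx (illustrative only): derive the accumulator key from cur_mode everywhere, rather than mixing a fixed attribute with a mode-suffixed key.

```python
# Illustrative sketch (plain dict instead of the real ctx): always build
# the accumulator key as loss_batch_total_{mode} so reads and writes agree.
ctx = {"cur_mode": "val"}
key = "loss_batch_total_{}".format(ctx["cur_mode"])
ctx[key] = ctx.get(key, 0.0) + 1.25       # accumulate per-mode batch loss

results = {f"{ctx['cur_mode']}_avg_loss": ctx[key]}
print(results)  # {'val_avg_loss': 1.25}
```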
@@ -66,6 +82,13 @@ def _hook_on_batch_end(self, ctx):
         ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
         ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

+        if self.cfg.federate.method.lower() in ["fedem"]:
+            # cache label for evaluation ensemble
+            ctx.get("{}_y_true".format(ctx.cur_mode)).append(
The attribute y_true is a matrix here and can be very large for MF datasets; I'm not sure it's appropriate to store all the labels and probs.
The appended one is a sparse csr_matrix.
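A small sketch of why the sparse cache is cheap (names are illustrative): a csr_matrix stores only the non-zero ratings, so appending it per batch costs memory proportional to the number of observed entries, not the full user-by-item grid.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Sketch of the reply above: csr_matrix storage grows with the number of
# observed ratings (nnz), not with the dense user x item shape.
dense = np.zeros((1000, 1000))
dense[0, 1] = 5.0
dense[3, 7] = 4.0
sparse = csr_matrix(dense)

cache = []            # stands in for ctx's "{cur_mode}_y_true" list
cache.append(sparse)  # what the batch-end hook appends
print(sparse.nnz)     # 2 stored entries despite the 1000x1000 shape
```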
@@ -18,16 +18,20 @@ class VMFDataset:

    """
    def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
-                                test_portion: float):
+                                split: list):
How about enabling this change to FedNetflix?
FedNetflix inherits from MovieLensData, so this change should also be valid for FedNetflix.
modified according to the comments
As the title says. Please double check the modifications related to MF. Thanks @rayrayraykk @DavdGao