Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support validation set and FedEM for MF datasets #310

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

yxdyc
Copy link
Collaborator

@yxdyc yxdyc commented Aug 10, 2022

as the title says. Please double check the modifications related to MF. Thanks @rayrayraykk @DavdGao

@yxdyc yxdyc added the enhancement New feature or request label Aug 10, 2022
Copy link
Collaborator

@DavdGao DavdGao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see the inline comments

"""
Ensemble evaluation for matrix factorization model
"""
cur_data = ctx.cur_mode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please ensure that the usage of cur_mode is correct here.

  • cur_mode: the type of our routine, chosen from "train"/"test"/"val"/"finetune"
  • cur_split: the chosen data split
    Besides, do we still need to name the variables with cur_data, since they are all removed at the end of the routine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, here we should use cur_split

# set the eval_metrics
if ctx.num_samples == 0:
results = {
f"{cur_data}_avg_loss": ctx.get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The metric calculator uses cur_split instead, please check if it's correct to use cur_data(actually cur_mode)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed as above replied

}
else:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a little confused to use ctx.cur_mode here, since we use cur_data in line 236.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed accordingly

else:
self._split_n_clients_rating_vmf(ratings, num_client, split)

def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the class HMFDataset and VMFDataset also have the function _split_n_clients_rating for HMF and VMF resepectively, maybe we don't need the functions _split_n_clients_rating_hmf and _split_n_clients_rating_vmf here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it in the new pr

}
self.data = data

def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it in the new pr

@@ -45,7 +45,8 @@ def forward(self, indices, ratings):
device=pred.device,
dtype=torch.float32).to_dense()

return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
return mask * pred, label, torch.Tensor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we convert it to a Tensor, and do we need to consider the device of the Tensor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the conversion is for flop counting. The device is not important since after counting the flop, the tensor will be discarded.

if ctx.get("num_samples") == 0:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.get(
"loss_batch_total_{}".format(ctx.cur_mode)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little confused that in line 53, we use loss_batch_total_{ctx.cur_mode}, while in line 58 it is ctx.loss_batch_total

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed into loss_batch_total_{ctx.cur_mode} in line 58

@@ -66,6 +82,13 @@ def _hook_on_batch_end(self, ctx):
ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

if self.cfg.federate.method.lower() in ["fedem"]:
# cache label for evaluation ensemble
ctx.get("{}_y_true".format(ctx.cur_mode)).append(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attribute y_true is a matrix here and can be very large for MF dataset, I'm not sure it's appropriate to storage all the labels and probs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The appended one is sparse csr_matrix

@@ -18,16 +18,20 @@ class VMFDataset:

"""
def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
test_portion: float):
split: list):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about enabling this change to FedNetflix?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FedNetflix is inherited from MovieLensData, thus this change should be valid to FedNetflix

Copy link
Collaborator Author

@yxdyc yxdyc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modified according to the comments

"""
Ensemble evaluation for matrix factorization model
"""
cur_data = ctx.cur_mode
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, here we should use cur_split

# set the eval_metrics
if ctx.num_samples == 0:
results = {
f"{cur_data}_avg_loss": ctx.get(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed as above replied

@@ -45,7 +45,8 @@ def forward(self, indices, ratings):
device=pred.device,
dtype=torch.float32).to_dense()

return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
return mask * pred, label, torch.Tensor(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the conversion is for flop counting. The device is not important since after counting the flop, the tensor will be discarded.

@@ -66,6 +82,13 @@ def _hook_on_batch_end(self, ctx):
ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

if self.cfg.federate.method.lower() in ["fedem"]:
# cache label for evaluation ensemble
ctx.get("{}_y_true".format(ctx.cur_mode)).append(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The appended one is sparse csr_matrix

}
else:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.get(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed accordingly

if ctx.get("num_samples") == 0:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.get(
"loss_batch_total_{}".format(ctx.cur_mode)),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed into loss_batch_total_{ctx.cur_mode} in line 58

@@ -18,16 +18,20 @@ class VMFDataset:

"""
def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
test_portion: float):
split: list):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FedNetflix is inherited from MovieLensData, thus this change should be valid to FedNetflix

else:
self._split_n_clients_rating_vmf(ratings, num_client, split)

def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it in the new pr

}
self.data = data

def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it in the new pr

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants