Skip to content

Commit

Permalink
Fix/SK-1076 | FutureWarning torch.load (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankJonasmoelle authored Oct 4, 2024
1 parent 64004d5 commit 932129d
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/FedSimSiam/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def load_data(data_path, is_train=True):
data_path = os.environ.get(
"FEDN_DATA_PATH", abs_path+"/data/clients/1/cifar10.pt")

data = torch.load(data_path)
data = torch.load(data_path, weights_only=True)

if is_train:
X = data["x_train"]
Expand Down
2 changes: 1 addition & 1 deletion examples/huggingface/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def load_data(data_path=None, is_train=True):
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/enron_spam.pt")
data = torch.load(data_path)
data = torch.load(data_path, weights_only=True)
if is_train:
X = data["X_train"]
y = data["y_train"]
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist-pytorch-DPSGD/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def load_data(data_path, is_train=True):
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

print("data path: ", data_path)
data = torch.load(data_path)
data = torch.load(data_path, weights_only=True)

if is_train:
X = data["x_train"]
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist-pytorch/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def load_data(data_path, is_train=True):
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

data = torch.load(data_path)
data = torch.load(data_path, weights_only=True)

if is_train:
X = data["x_train"]
Expand Down

0 comments on commit 932129d

Please sign in to comment.