From 3d32e351c3940026212ed76e2341ee6914be960a Mon Sep 17 00:00:00 2001 From: marwan2232004 <118024824+marwan2232004@users.noreply.github.com> Date: Sat, 12 Oct 2024 11:17:32 +0300 Subject: [PATCH] Fix: Deployment device problem --- Inference.py | 2 +- utils/Translation.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Inference.py b/Inference.py index 8f0965e..657ad50 100644 --- a/Inference.py +++ b/Inference.py @@ -47,7 +47,7 @@ def predict(audio_file): dim_feedforward=2048, ) - model.load_state_dict(torch.load(model_path, weights_only=False)) + model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu'))) model.to(device) model.eval() diff --git a/utils/Translation.py b/utils/Translation.py index aacf74a..64645a4 100644 --- a/utils/Translation.py +++ b/utils/Translation.py @@ -52,7 +52,7 @@ class Variables: "epoch": 120, "lr": 1e-4, } - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu") var = Variables() @@ -516,8 +516,7 @@ def translate(sentence): model = Seq2Seq(encoder=encoder, decoder=decoder).to(var.device) model_path = "translate_v1.pth" - model.load_state_dict(torch.load(model_path, weights_only=True)) - # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu'))) model.to(var.device) model.eval()