diff --git a/Inference.py b/Inference.py index 8f0965e..657ad50 100644 --- a/Inference.py +++ b/Inference.py @@ -47,7 +47,7 @@ def predict(audio_file): dim_feedforward=2048, ) - model.load_state_dict(torch.load(model_path, weights_only=False)) + model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu'))) model.to(device) model.eval() diff --git a/utils/Translation.py b/utils/Translation.py index aacf74a..64645a4 100644 --- a/utils/Translation.py +++ b/utils/Translation.py @@ -52,7 +52,7 @@ class Variables: "epoch": 120, "lr": 1e-4, } - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu") var = Variables() @@ -516,8 +516,7 @@ def translate(sentence): model = Seq2Seq(encoder=encoder, decoder=decoder).to(var.device) model_path = "translate_v1.pth" - model.load_state_dict(torch.load(model_path, weights_only=True)) - # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu'))) model.to(var.device) model.eval()