Skip to content

Commit

Permalink
Fix: Deployment device problem
Browse files Browse the repository at this point in the history
  • Loading branch information
marwan2232004 committed Oct 12, 2024
1 parent 6dc6fd8 commit 3d32e35
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def predict(audio_file):
dim_feedforward=2048,
)

model.load_state_dict(torch.load(model_path, weights_only=False))
model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu')))
model.to(device)
model.eval()

Expand Down
5 changes: 2 additions & 3 deletions utils/Translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Variables:
"epoch": 120,
"lr": 1e-4,
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")


var = Variables()
Expand Down Expand Up @@ -516,8 +516,7 @@ def translate(sentence):

model = Seq2Seq(encoder=encoder, decoder=decoder).to(var.device)
model_path = "translate_v1.pth"
model.load_state_dict(torch.load(model_path, weights_only=True))
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.load_state_dict(torch.load(model_path, weights_only=False,map_location=torch.device('cpu')))
model.to(var.device)
model.eval()

Expand Down

0 comments on commit 3d32e35

Please sign in to comment.