diff --git a/nix/pkgs/buddy-mlir.nix b/nix/pkgs/buddy-mlir.nix index d382d3cec..f2e4e89a7 100644 --- a/nix/pkgs/buddy-mlir.nix +++ b/nix/pkgs/buddy-mlir.nix @@ -24,6 +24,42 @@ let hash = "sha256-Z78I9S8g9WexoX6HhxwbOD0K0awCTzsqW1ZiWObQNw0="; }; })); + + accelerate = (prev.accelerate.overridePythonAttrs (old: rec { + pname = "accelerate"; + version = "0.32.0"; + + src = fetchFromGitHub { + owner = "huggingface"; + repo = "accelerate"; + rev = "refs/tags/v${version}"; + hash = "sha256-/Is5aKTYHxvgUJSkF7HxMbEA6dgn/y5F1B3D6qSCSaE="; + }; + })); + + torch = (prev.torch.overridePythonAttrs (old: rec { + version = "2.3.1"; + src = fetchFromGitHub { + owner = "pytorch"; + repo = "pytorch"; + rev = "refs/tags/v${version}"; + fetchSubmodules = true; + hash = "sha256-vpgtOqzIDKgRuqdT8lB/g6j+oMIH1RPxdbjtlzZFjV8="; + }; + PYTORCH_BUILD_VERSION = version; + PYTORCH_BUILD_NUMBER = 0; + })); + + torchvision = prev.torchvision.overridePythonAttrs rec { + version = "0.18.1"; + + src = fetchFromGitHub { + owner = "pytorch"; + repo = "vision"; + rev = "refs/tags/v${version}"; + hash = "sha256-aFm6CyoMA8HtpOAVF5Q35n3JRaOXYswWEqfooORUKsw="; + }; + }; }; }; diff --git a/tests/pytorch/lenet/lenet.py b/tests/pytorch/lenet/lenet.py index 9fd364f68..f65a263e2 100644 --- a/tests/pytorch/lenet/lenet.py +++ b/tests/pytorch/lenet/lenet.py @@ -5,6 +5,7 @@ import numpy as np import torch from torch._inductor.decomposition import decompositions as inductor_decomp +import torch._inductor.lowering from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.graph import GraphDriver