Skip to content

Commit

Permalink
[pytorch] fix buddy-mlir python module and add new demo
Browse files Browse the repository at this point in the history
Signed-off-by: Avimitin <[email protected]>
  • Loading branch information
Avimitin committed Aug 7, 2024
1 parent 1b67436 commit 86cf251
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
9 changes: 8 additions & 1 deletion nix/overlay.nix
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ rec {
dramsim3 = final.callPackage ./pkgs/dramsim3.nix { };
libspike = final.callPackage ./pkgs/libspike.nix { };
libspike_interfaces = final.callPackage ../difftest/spike_interfaces { };
buddy-mlir = final.callPackage ./pkgs/buddy-mlir.nix { };

# DynamoCompiler doesn't support python 3.12+ yet
buddy-mlir = final.callPackage ./pkgs/buddy-mlir.nix { python3 = final.python311; };
buddy-mlir-pyenv = final.buddy-mlir.pythonModule.withPackages (ps: [
final.buddy-mlir
ps.torch
]);

fetchMillDeps = final.callPackage ./pkgs/mill-builder.nix { };
circt-full = final.callPackage ./pkgs/circt-full.nix { };
rvv-codegen = final.callPackage ./pkgs/rvv-codegen.nix { };
Expand Down
23 changes: 19 additions & 4 deletions nix/pkgs/buddy-mlir.nix
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
, llvmPackages_17
, fetchFromGitHub
, fetchpatch
, python3
, callPackage
}:
let
stdenv = llvmPackages_17.stdenv;
bintools = llvmPackages_17.bintools;

buddy-llvm = callPackage ./buddy-llvm.nix { inherit stdenv; };
buddy-llvm = callPackage ./buddy-llvm.nix { inherit stdenv python3; };
in
stdenv.mkDerivation {
pname = "buddy-mlir";
Expand Down Expand Up @@ -38,12 +39,26 @@ stdenv.mkDerivation {
patches = [
(fetchpatch {
url = "https://github.com/buddy-compiler/buddy-mlir/pull/359.patch";
hash = "sha256-hE2nHkuGCBbCCO5VERy5zNctjxydntOu/10J4f+D8to=";
hash = "sha256-UKDO/MWkHNBEs2nyBPGy3AT4Dm100EjHPJ0DkJ8uMvo=";
})
];

passthru.llvm = buddy-llvm;

# No need to do check, and it also takes too much time to finish.
doCheck = false;

# Here we concatenate the LLVM and Buddy python module into one directory for easier import
postFixup = ''
mkdir -p $out/lib/python${python3.pythonVersion}/site-packages
cp -vr $out/python_packages/buddy $out/lib/python${python3.pythonVersion}/site-packages/
cp -vr ${buddy-llvm}/python_packages/mlir_core/mlir $out/lib/python${python3.pythonVersion}/site-packages/
'';

passthru = {
llvm = buddy-llvm;

# Below three fields are black magic that allow site-packages automatically imported with nixpkgs hooks
pythonModule = python3;
pythonPath = [ ];
requiredPythonModules = [ ];
};
}
31 changes: 31 additions & 0 deletions tests/pytorch/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa

# Define the target function or model.
def foo(x, y):
return x * y + x

# Define the input data.
float32_in1 = torch.randn(10).to(torch.float32)
float32_in2 = torch.randn(10).to(torch.float32)
int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32)
int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32)

# Initialize the dynamo compiler.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

# Pass the function and input data to the dynamo compiler's importer, the
# importer will first build a graph. Then, lower the graph to top-level IR.
# (tosa, linalg, etc.). Finally, accepts the generated module and weight parameters.
graphs = dynamo_compiler.importer(foo, *(float32_in1, float32_in2))
graph = graphs[0]
graph.lower_to_top_level_ir()

print(graph._imported_module)

0 comments on commit 86cf251

Please sign in to comment.