Skip to content

Commit

Permalink
feat(cutlass/example): switch to setuptools
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 518cd31
  • Loading branch information
megvii-mge committed Feb 25, 2022
1 parent bf389d6 commit 704eca4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 30 deletions.
4 changes: 2 additions & 2 deletions examples/19_large_depthwise_conv2d_torch_extension/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ Torch Extension for DepthwiseConv2d with Implicit GEMM

### Usage

Compile and run on a worker with GPU:
Compile and install `depthwise_conv2d_implicit_gemm`

```
./depthwise_conv2d_implicit_gemm.py
./setup.py install --user
```

Original file line number Diff line number Diff line change
@@ -1,40 +1,15 @@
#!/usr/bin/env python3
import os

import torch
import torch.nn as nn
import torch.utils.cpp_extension as cpp_extension

CUTLASS_ROOT = os.path.join(os.path.dirname(__file__), "../..")
import _depthwise_conv2d_implicit_gemm_C as _extension

_extension = None

__all__ = ["DepthWiseConv2dImplicitGEMM"]

def _load_extension():
global _extension
if _extension is not None: return _extension
_extension = cpp_extension.load(
name="dwconv_implicitgemm",
sources=[
"frontend.cpp",
"forward_fp32.cu",
"backward_data_fp32.cu",
"backward_filter_fp32.cu",
"forward_fp16.cu",
"backward_data_fp16.cu",
"backward_filter_fp16.cu",
],
extra_include_paths=[
".",
os.path.join(CUTLASS_ROOT, "include"),
os.path.join(CUTLASS_ROOT, "tools", "library", "include"),
os.path.join(CUTLASS_ROOT, "tools", "util", "include"),
os.path.join(CUTLASS_ROOT, "examples", "common"),
],
verbose=True
)
return _extension


class _DepthWiseConv2dImplicitGEMMFP32(torch.autograd.Function):
@staticmethod
Expand Down Expand Up @@ -71,7 +46,7 @@ def backward(ctx, grad):
class DepthWiseConv2dImplicitGEMM(nn.Conv2d):
def __init__(self, channels, kernel, bias=False):
super().__init__(channels, channels, kernel, groups=channels, bias=bias)
_load_extension()
# _load_extension()

def forward(self, x):
if x.dtype == torch.float32:
Expand Down
35 changes: 35 additions & 0 deletions examples/19_large_depthwise_conv2d_torch_extension/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

CUTLASS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))

setup(
name='depthwise_conv2d_implicit_gemm',
py_modules=['depthwise_conv2d_implicit_gemm'],
ext_modules=[
CUDAExtension(
name='_depthwise_conv2d_implicit_gemm_C',
sources=[
"frontend.cpp",
"forward_fp32.cu",
"backward_data_fp32.cu",
"backward_filter_fp32.cu",
"forward_fp16.cu",
"backward_data_fp16.cu",
"backward_filter_fp16.cu",
],
include_dirs=[
".",
os.path.join(CUTLASS_ROOT, "include"),
os.path.join(CUTLASS_ROOT, "tools", "library", "include"),
os.path.join(CUTLASS_ROOT, "tools", "util", "include"),
os.path.join(CUTLASS_ROOT, "examples", "common"),
],
extra_compile_args=['-g']),
],
cmdclass={
'build_ext': BuildExtension
})

0 comments on commit 704eca4

Please sign in to comment.