Enable FP16 type in operators (#3059)
Summary:
Pull Request resolved: #3059

## Context

Enable half precision shader computation using the `GL_EXT_shader_16bit_storage` extension, which was enabled in the change just below this one in the stack.
ghstack-source-id: 222727209

Reviewed By: jorgep31415

Differential Revision: D56189470

fbshipit-source-id: 0eb5990651ad34e5a2ada601a0d3944dfe2ae9ea
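
As a rough sketch of the mechanism (the helper name below is hypothetical, not the actual codegen entry point), the shader templates in this diff gate the extension directive on the tensor dtype, which in Python terms amounts to:

    def maybe_require_16bit_storage(shader_src: str, dtype: str) -> str:
        # Hypothetical helper: the GLSL templates below use `$if DTYPE == "half":`
        # to emit this directive only when generating half variants of a shader.
        if dtype == "half":
            return "#extension GL_EXT_shader_16bit_storage : require\n" + shader_src
        return shader_src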
SS-JIA authored and facebook-github-bot committed Apr 16, 2024
1 parent d481c11 commit ab62707
Showing 11 changed files with 121 additions and 68 deletions.
6 changes: 2 additions & 4 deletions backends/vulkan/runtime/api/gen_vulkan_spv.py
@@ -90,12 +90,10 @@ def define_variable(name: str) -> str:


 def get_buffer_scalar_type(dtype: str) -> str:
-    # TODO(ssjia): use float16_t for half types
     if dtype == "half":
-        return "float"
-    # TODO(ssjia): use int8_t for int8 types
+        return "float16_t"
     elif dtype[-1] == "8":
-        return dtype[:-1]
+        return dtype + "_t"

     return dtype
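
Illustrative expectations for the rewritten helper (assertions added here for exposition; the dtype strings follow the codegen's conventions):

    assert get_buffer_scalar_type("half") == "float16_t"   # previously returned "float"
    assert get_buffer_scalar_type("int8") == "int8_t"      # previously "int" (dtype[:-1])
    assert get_buffer_scalar_type("float") == "float"      # unchanged passthrough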
The same guard is added to three shader templates (file headers not shown for these hunks):

@@ -19,6 +19,9 @@

 #include "indexing_utils.h"

+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;

 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"

+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;

 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"

+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;

 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"

+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;

 layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;
26 changes: 15 additions & 11 deletions backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
@@ -20,6 +20,9 @@

 #include "indexing_utils.h"

+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;

 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
@@ -52,20 +55,21 @@ void main() {
   const ivec4 buf_indices =
       base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data);

-  SCALAR_T val_x = SCALAR_T(buffer_in.data[buf_indices.x]);
-  SCALAR_T val_y = SCALAR_T(buffer_in.data[buf_indices.y]);
-  SCALAR_T val_z = SCALAR_T(buffer_in.data[buf_indices.z]);
-  SCALAR_T val_w = SCALAR_T(buffer_in.data[buf_indices.w]);
-
-  VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w);
-
   const int packed_dim_size = get_packed_dim(cpu_sizes.data);
   int packed_idx = get_packed_dim(idx);

-  if (packed_idx + 3 >= packed_dim_size) {
-    ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
-    VEC4_T valid_idx = VEC4_T(lessThan(packed_ind, ivec4(packed_dim_size)));
-    texel = texel * valid_idx;
+  VEC4_T texel = VEC4_T(0);
+  if (packed_idx < packed_dim_size) {
+    texel.x = SCALAR_T(buffer_in.data[buf_indices.x]);
+  }
+  if (packed_idx + 1 < packed_dim_size) {
+    texel.y = SCALAR_T(buffer_in.data[buf_indices.y]);
+  }
+  if (packed_idx + 2 < packed_dim_size) {
+    texel.z = SCALAR_T(buffer_in.data[buf_indices.z]);
+  }
+  if (packed_idx + 3 < packed_dim_size) {
+    texel.w = SCALAR_T(buffer_in.data[buf_indices.w]);
   }

   imageStore(image_out, ${get_pos[NDIM]("pos")}, texel);
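
The rewritten main() body changes how out-of-range lanes are handled: the old code read all four buffer elements unconditionally and zeroed invalid lanes afterwards, so the last texel of a row could still read past the end of the staging buffer; the new code only reads lanes whose packed index is in range. A Python mirror of the new logic (illustrative only):

    def load_texel(buffer_in, buf_indices, packed_idx, packed_dim_size):
        # Each lane loads from the source buffer only when its packed index is
        # in bounds; out-of-range lanes keep the zero fill instead of masking
        # a read that already happened.
        texel = [0.0] * 4
        for lane in range(4):
            if packed_idx + lane < packed_dim_size:
                texel[lane] = buffer_in[buf_indices[lane]]
        return texel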
28 changes: 13 additions & 15 deletions backends/vulkan/test/op_tests/cases.py
@@ -21,7 +21,7 @@


 def get_binary_elementwise_inputs():
-    return VkTestSuite(
+    test_suite = VkTestSuite(
         [
             ((M1, M2), (M1, M2)),
             ((M1, M2), (M1, 1), 2.0),
@@ -31,6 +31,11 @@ def get_binary_elementwise_inputs():
             ((S, S1, S2), (S, 1, S2), 2.0),
         ]
     )
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
+    return test_suite


 def get_mm_inputs():
@@ -41,6 +46,12 @@ def get_mm_inputs():
         ],
     )
     test_suite.prepacked_args = ["mat2"]
+    # ATen matmul doesn't support half
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
     return test_suite
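
These per-suite fields (dtypes, layouts, storage_types) replace the shared `supports` dict removed later in this diff. A hypothetical suite overriding all three defaults, for illustration:

    suite = VkTestSuite([((M1, M2), (M1, M2))])      # input shape cases, as above
    suite.dtypes = ["at::kFloat"]                    # e.g. when an op lacks half support
    suite.layouts = ["api::kWidthPacked", "api::kChannelsPacked"]
    suite.storage_types = ["api::kTexture3D"]        # the VkTestSuite default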


@@ -50,7 +61,6 @@ def get_pool2d_inputs():
             ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -114,7 +124,6 @@ def get_conv2d_inputs():
             ),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -123,10 +132,9 @@ def get_native_layer_norm_inputs():
         [
             ((S1, S2), [S2], (S2), (S2), 0.001),
             ((M, M1, M2), [M2], (M2), (M2), 0.001),
-            ((L, XL, M1, M2), [M2], (M2), (M2), 0.001),
+            ((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -138,7 +146,6 @@ def get_full_inputs():
             ([L, M, M1, M2], 2.72),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -161,7 +168,6 @@ def get_select_int_inputs():
             ((8, 6, 1, 1), 1, 4),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
    return test_suite


@@ -177,11 +183,3 @@ def get_select_int_inputs():
     "aten.full.default": get_full_inputs(),
     "aten.select.int": get_select_int_inputs(),
 }
-
-prepacked_args = {"aten.mm.default": {"mat2"}}
-
-support_exceptions = {
-    "aten.max_pool2d_with_indices.default": {
-        "layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
-    },
-}
22 changes: 21 additions & 1 deletion backends/vulkan/test/op_tests/targets.bzl
@@ -1,4 +1,6 @@
 load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
+load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
+load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

 def define_common_targets(is_fbcode = False):
@@ -43,6 +45,24 @@ def define_common_targets(is_fbcode = False):
         default_outs = ["."],
     )

+    pt_operator_library(
+        name = "all_aten_ops",
+        check_decl = False,
+        include_all_operators = True,
+    )
+
+    runtime.cxx_library(
+        name = "all_aten_ops_lib",
+        srcs = [],
+        define_static_target = False,
+        exported_deps = get_pt_ops_deps(
+            name = "pt_ops_full",
+            deps = [
+                ":all_aten_ops",
+            ],
+        ),
+    )
+
     runtime.cxx_binary(
         name = "compute_graph_op_tests_bin",
         srcs = [
@@ -52,7 +72,7 @@ def define_common_targets(is_fbcode = False):
         deps = [
             "//third-party/googletest:gtest_main",
             "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
+            ":all_aten_ops_lib",
         ],
     )
53 changes: 35 additions & 18 deletions backends/vulkan/test/op_tests/utils/codegen.py
@@ -39,13 +39,10 @@

 @dataclass
 class VkTestSuite(TestSuite):
-    supports = {
-        "storage_types": ["api::StorageType::TEXTURE_3D"],
-        "layouts": [
-            "api::GPUMemoryLayout::TENSOR_WIDTH_PACKED",
-            "api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED",
-        ],
-    }
+    def __init__(self, input_cases: List[Any]):
+        super().__init__(input_cases)
+        self.storage_types: List[str] = ["api::kTexture3D"]
+        self.layouts: List[str] = ["api::kChannelsPacked"]
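
Moving from a class-level `supports` dict to instance attributes set in `__init__` also sidesteps a classic Python pitfall: a mutable class attribute is shared by every instance, so one test suite mutating it would silently affect all others. A self-contained illustration (names hypothetical):

    class Shared:
        opts = {"layouts": []}  # one dict object shared across all instances

    a, b = Shared(), Shared()
    a.opts["layouts"].append("api::kWidthPacked")
    assert b.opts["layouts"] == ["api::kWidthPacked"]  # b observes a's mutation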


##########################
@@ -88,7 +85,6 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
         self.dot = "->"

         self.args = []
-        self.out = None
         self.refs = {}

         self.should_prepack = False
@@ -288,6 +284,7 @@ def set_output(self, ref: ValueRefList) -> str:
         return ret_str

     def virtual_resize(self, ref: ValueRefList) -> str:
+        assert isinstance(ref, ValueRef)
         assert ref.src_cpp_type == AT_TENSOR and ref.is_in
         if self.prepack_ref(ref):
             return ""
@@ -296,6 +293,7 @@ def copy_into_staging(self, ref: ValueRefList) -> str:
         return ret_str

     def copy_into_staging(self, ref: ValueRefList) -> str:
+        assert isinstance(ref, ValueRef)
         assert ref.src_cpp_type == AT_TENSOR and ref.is_in
         if self.prepack_ref(ref):
             return ""
@@ -336,7 +334,7 @@ def check_graph_out(self, ref: ValueRefList) -> str:
                 ret_str += self.check_graph_out(r)
             return ret_str

-        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}));\n"
+        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));\n"

     ## Top level code generation

@@ -374,11 +372,19 @@ def gen_graph_exec_code(self) -> str:

         return graph_exec

+    def gen_conditional_skips(self) -> str:
+        skips = "if (test_dtype == at::kHalf && "
+        skips += f"!{self.graph}{self.dot}context()->adapter_ptr()->has_16bit_storage()) {{\n"
+        skips += " GTEST_SKIP();"
+        skips += "}\n"
+        return skips
+
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
         op_check_fn = self.gen_decl(f"check_{op_name}") + " {"
         if self.should_prepack:
             op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
+        op_check_fn += self.gen_conditional_skips()
         op_check_fn += self.gen_graph_build_code()
         op_check_fn += self.gen_graph_exec_code()
         op_check_fn += self.check_graph_out(self.refs["out"])
@@ -391,19 +397,26 @@ def gen_op_check_fn(self) -> str:
 ##################################

 test_fixture_template = """
-class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {{
+class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
  protected:
   ComputeGraph* graph;
+  at::ScalarType test_dtype = at::kFloat;
+  float rtol = 1e-5;
+  float atol = 1e-5;

   void SetUp() override {{
     GraphConfig config;
     api::StorageType default_storage_type;
     api::GPUMemoryLayout default_memory_layout;
-    std::tie(default_storage_type, default_memory_layout) = GetParam();
+    std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
     config.setStorageTypeOverride(default_storage_type);
     config.setMemoryLayoutOverride(default_memory_layout);
     graph = new ComputeGraph(config);
+    if (test_dtype == at::kHalf) {{
+      rtol = 1e-2;
+      atol = 1e-2;
+    }}
   }}

   void TearDown() override {{
@@ -420,7 +433,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple


 class VkTestSuiteGen(TestSuiteGen):
-    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: List[Any]):
+    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
         super().__init__(f, inputs)
         self.op_reg_name = op_reg_name
         self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
@@ -442,14 +455,16 @@ def generate_fixture_cpp(self) -> str:
         )

     def gen_parameterization(self) -> str:
-        storage_types = self.suite_def.supports["storage_types"]
-        layouts = self.suite_def.supports["layouts"]
+        dtypes = self.suite_def.dtypes
+        storage_types = self.suite_def.storage_types
+        layouts = self.suite_def.layouts

         return f"""
 INSTANTIATE_TEST_SUITE_P(
-    StorageLayoutCombos_{self.op_name},
+    Combos_{self.op_name},
     GeneratedOpsTest_{self.op_name},
     ::testing::Combine(
+        ::testing::Values({', '.join(dtypes)}),
         ::testing::Values({', '.join(storage_types)}),
         ::testing::Values({', '.join(layouts)})));
 """
@@ -494,9 +509,11 @@ def gen_parameterization(self) -> str:
     return true;
   }
   bool is_close = at::allclose(t1, t2, rtol, atol);
-  if (!is_close) {
-    std::cout << "t1:" << t1 << std::endl;
-    std::cout << "t2:" << t2 << std::endl;
+  if (!is_close && t1.numel() < 500) {
+    std::cout << "reference: " << std::endl;
+    print(t1, 150);
+    std::cout << "vulkan: " << std::endl;
+    print(t2, 150);
   }
   return is_close;
 }
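
`check_close` still defers to `at::allclose`; the fixture simply feeds it looser tolerances for half. A small PyTorch sketch of the tolerance policy (values taken from this diff):

    import torch

    def tolerances(dtype: torch.dtype):
        # Half results are compared with looser tolerances than float results.
        return (1e-2, 1e-2) if dtype == torch.float16 else (1e-5, 1e-5)

    ref = torch.randn(8, 8, dtype=torch.float16)
    rtol, atol = tolerances(ref.dtype)
    assert torch.allclose(ref, ref.clone(), rtol=rtol, atol=atol)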