reshape related fusion
centwang committed Oct 29, 2024
1 parent d905dac commit ca59611
Showing 14 changed files with 576 additions and 136 deletions.
41 changes: 33 additions & 8 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -130,7 +130,6 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
k = dim_k.dim_value();
n = dim_n.dim_value();
ORT_ENFORCE(shape_values.back() == k);
m = std::accumulate(shape_values.begin(), shape_values.end() - 1, static_cast<int64_t>(1),
std::multiplies<int64_t>());

[cpplint] Add #include <functional> for multiplies<> (onnxruntime/core/optimizer/matmul_add_fusion.cc:134)
}
@@ -167,8 +166,10 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}

auto gemm_output_defs = add_node.MutableOutputDefs();
Node* input_node = nullptr;
Node* output_node = nullptr;
if (need_reshape) {
auto add_reshape = [&](const std::vector<int64_t>& shape, Graph& graph, bool is_input) {
auto add_reshape = [&](const std::vector<int64_t>& shape, Graph& graph, bool is_input) -> Node* {

[cpplint] Add #include <vector> for vector<> (onnxruntime/core/optimizer/matmul_add_fusion.cc:172)
const std::string name = is_input ? "gemm_input" : "gemm_output";

[cpplint] Add #include <string> for string (onnxruntime/core/optimizer/matmul_add_fusion.cc:173)
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape"));
@@ -187,23 +188,47 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
{is_input ? gemm_input_defs[0] : new_arg, shape_arg},
{is_input ? new_arg : gemm_output_defs[0]});
reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
return new_arg;
return &reshape_node;
};

gemm_input_defs[0] = add_reshape({m, k}, graph, true);
input_node = add_reshape({m, k}, graph, true);
gemm_input_defs[0] = input_node->MutableOutputDefs()[0];
shape_values.back() = n;
gemm_output_defs[0] = add_reshape(shape_values, graph, false);
output_node = add_reshape(shape_values, graph, false);
gemm_output_defs[0] = output_node->MutableInputDefs()[0];
}

Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"), "Gemm",
"fused Matmul and Add", gemm_input_defs, gemm_output_defs);

// Assign provider to this new node. Provider should be the same as the provider for the old node.
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());

if (need_reshape) {
graph.AddEdge(input_node->Index(), gemm_node.Index(), 0, 0);
graph.AddEdge(gemm_node.Index(), output_node->Index(), 0, 0);
} else {
input_node = &gemm_node;
output_node = &gemm_node;
}

auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(matmul_node);
for (auto cur = matmul_input_edges.cbegin(), end = matmul_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 0) {
graph.AddEdge(cur->src_node, input_node->Index(), cur->src_arg_index, 0);
} else if (cur->dst_arg_index == 1) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1);
}
}
graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges);
auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 1) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2);
}
}
graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
graph_utils::RemoveNodeOutputEdges(graph, matmul_node);
graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0);
graph.RemoveNode(matmul_node.Index());
graph_utils::RemoveNodeOutputEdges(graph, add_node);
graph.RemoveNode(add_node.Index());

modified = true;
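Note (not part of the commit): the MatMulAddFusion change above folds a MatMul followed by Add into a single Gemm, flattening a higher-rank input to 2-D for Gemm and restoring the original shape afterwards through the inserted Reshape nodes. The standalone sketch below only illustrates the shape arithmetic involved (the std::accumulate computation of m and the final output shape); the tensor shapes are made-up example values.

// Illustrative shape math only; assumes a hypothetical input A of shape [2, 3, 4] and weight B of shape [4, 5].
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape_values{2, 3, 4};  // shape of MatMul input A
  const int64_t k = 4;                         // last dim of A, must match B's first dim
  const int64_t n = 5;                         // second dim of weight B

  // m = product of all leading dims, mirroring the std::accumulate call in the fusion.
  const int64_t m = std::accumulate(shape_values.begin(), shape_values.end() - 1,
                                    static_cast<int64_t>(1), std::multiplies<int64_t>());
  std::cout << "Reshape A from [2, 3, 4] to [" << m << ", " << k << "] for Gemm\n";  // [6, 4]

  // The Gemm output [m, n] is reshaped back to the original leading dims with n as the last dim.
  shape_values.back() = n;
  std::cout << "Reshape Gemm output back to [" << shape_values[0] << ", " << shape_values[1]
            << ", " << shape_values[2] << "]\n";  // [2, 3, 5]
  return 0;
}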
71 changes: 71 additions & 0 deletions onnxruntime/core/optimizer/reshape_fusion.cc
@@ -48,6 +48,8 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
fused_count++;
LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
modified = true;
} else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph, logger)) {
modified = true;
}
}

@@ -452,4 +454,73 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
return true;
}

bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger) {
ORT_UNUSED_PARAMETER(logger);
InlinedVector<Node*> contiguous_reshapes{&reshape};
InlinedVector<int64_t> shape_value;
while (true) {
Node* p_curr_node = contiguous_reshapes.back();
if (graph.NodeProducesGraphOutput(*p_curr_node) || p_curr_node->GetOutputEdgesCount() != 1) {
break;
}

Node* p_next_node = graph.GetNode(p_curr_node->OutputNodesBegin()->Index());
if (p_next_node->OpType() != "Reshape" && p_next_node->OpType() != "Squeeze" &&
p_next_node->OpType() != "Unsqueeze") {
break;
}

auto shape = p_next_node->OutputDefs()[0]->Shape();
if (!shape) {
break;
}

bool is_concrete_shape = true;
shape_value.clear();
for (const auto& dim : shape->dim()) {
if (dim.has_dim_value()) {
shape_value.emplace_back(dim.dim_value());
} else {
is_concrete_shape = false;
}
}
if (!is_concrete_shape) {
break;
}

contiguous_reshapes.emplace_back(p_next_node);
}

if (contiguous_reshapes.size() < 2) {
return false;
}

const std::string& name = contiguous_reshapes[0]->Name();

[cpplint] Add #include <string> for string (onnxruntime/core/optimizer/reshape_fusion.cc:498)
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape"));
shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t));
NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
{contiguous_reshapes[0]->MutableInputDefs()[0], shape_arg},
{contiguous_reshapes.back()->MutableOutputDefs()[0]});
reshape_node.SetExecutionProviderType(contiguous_reshapes[0]->GetExecutionProviderType());

auto input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*contiguous_reshapes[0]);
for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 0) {
graph.AddEdge(cur->src_node, reshape_node.Index(), cur->src_arg_index, 0);
}
}
graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
graph_utils::ReplaceDownstreamNodeInput(graph, *contiguous_reshapes.back(), 0, reshape_node, 0);
for (Node* p_node : contiguous_reshapes) {
graph_utils::RemoveNodeOutputEdges(graph, *p_node);
graph.RemoveNode(p_node->Index());
}

return true;
}

} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/reshape_fusion.h
@@ -27,6 +27,11 @@ class ReshapeFusion : public GraphTransformer {
static bool Is_One_Element_Input(const Node& cur_node, int index);
static bool Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& root_input, const Node& concat,
int index, gsl::span<const int64_t> shape_value, const logging::Logger& logger);

// Remove contiguous Reshape/Squeeze/Unsqueeze nodes if the shape info is concrete.
// For some EPs, such as QNN EP, these reshape ops are not no-ops: memory is allocated for each output,
// so this fusion helps reduce memory usage on such devices.
static bool FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger);
};

} // namespace onnxruntime
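Note (not part of the commit): the comment above describes the intent of FuseContiguousReshapes: a chain of shape-only ops (Reshape/Squeeze/Unsqueeze) whose output shapes are all concrete is replaced by a single Reshape to the final shape. The sketch below mimics that decision on plain shape vectors, under the assumption that -1 marks a symbolic dimension; it is a simplified illustration, not the transformer's actual graph manipulation.

// Simplified illustration of collapsing contiguous shape-only ops into one target shape.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;  // -1 stands in for a symbolic/unknown dimension

bool IsConcrete(const Shape& s) {
  for (int64_t d : s) {
    if (d < 0) return false;
  }
  return true;
}

// Given the output shape of each node in the chain (first node included), fold the
// longest prefix with concrete shapes. Folding only pays off if it covers >= 2 nodes.
std::optional<Shape> CollapseContiguousReshapes(const std::vector<Shape>& chain_outputs) {
  size_t folded = 0;
  Shape last;
  for (const Shape& s : chain_outputs) {
    if (!IsConcrete(s)) break;
    last = s;
    ++folded;
  }
  if (folded < 2) return std::nullopt;
  return last;
}

int main() {
  // Hypothetical chain: Reshape -> Squeeze -> Unsqueeze with concrete output shapes.
  std::vector<Shape> chain{{1, 6, 4}, {6, 4}, {6, 1, 4}};
  if (auto target = CollapseContiguousReshapes(chain)) {
    std::cout << "Replace the chain with a single Reshape to [";
    for (size_t i = 0; i < target->size(); ++i) {
      std::cout << (*target)[i] << (i + 1 < target->size() ? ", " : "");
    }
    std::cout << "]\n";  // prints [6, 1, 4]
  }
  return 0;
}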
onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc
@@ -7,8 +7,6 @@
#include <core/providers/common.h>

#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/common/safeint.h"

namespace onnxruntime {
@@ -271,37 +269,6 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper&
return Status::OK();
}

Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto& initializer,
const std::vector<size_t>& perm,
std::vector<uint8_t>& transposed_data) const {
const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType();
const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer);
TensorShape tensor_shape{tensor_shape_dims};
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator);

auto rank = perm.size();
std::vector<int64_t> new_tensor_shape_dims;
std::vector<size_t> permutations;
new_tensor_shape_dims.reserve(rank);
permutations.reserve(rank);
for (int64_t p : perm) {
permutations.push_back(p);
new_tensor_shape_dims.push_back(tensor_shape_dims[p]);
}

TensorShape new_tensor_shape(new_tensor_shape_dims);
Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator);
ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(
Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor));
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor));
onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test");
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data));

return Status::OK();
}

Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
Qnn_Scalar_t& axis_qnn_scalar,
82 changes: 0 additions & 82 deletions onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
@@ -214,88 +214,6 @@ class BaseOpBuilder : public IOpBuilder {
return it->second;
}

// NCHW shape to channel last
Status NchwShapeToNhwc(const std::vector<uint32_t>& nchw_shape, std::vector<uint32_t>& nhwc_shape) const {
ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW.");
nhwc_shape[0] = nchw_shape[0];
nhwc_shape[1] = nchw_shape[2];
nhwc_shape[2] = nchw_shape[3];
nhwc_shape[3] = nchw_shape[1];

return Status::OK();
}

// NCHW shape to HWCN shape, required for Conv weight
Status NchwShapeToHwcn(const std::vector<uint32_t>& nchw_shape, std::vector<uint32_t>& hwcn_shape) const {
if (nchw_shape.size() == 4) {
hwcn_shape[0] = nchw_shape[2];
hwcn_shape[1] = nchw_shape[3];
hwcn_shape[2] = nchw_shape[1];
hwcn_shape[3] = nchw_shape[0];
} else if (nchw_shape.size() == 5) {
hwcn_shape[0] = nchw_shape[2];
hwcn_shape[1] = nchw_shape[3];
hwcn_shape[2] = nchw_shape[4];
hwcn_shape[3] = nchw_shape[1];
hwcn_shape[4] = nchw_shape[0];
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5.");
}

return Status::OK();
}

// CNHW shape to HWCN shape, required for Conv weight
Status CnhwShapeToHwcn(const std::vector<uint32_t>& cnhw_shape, std::vector<uint32_t>& hwcn_shape) const {
if (cnhw_shape.size() == 4) {
hwcn_shape[0] = cnhw_shape[2];
hwcn_shape[1] = cnhw_shape[3];
hwcn_shape[2] = cnhw_shape[0];
hwcn_shape[3] = cnhw_shape[1];
} else if (cnhw_shape.size() == 5) {
hwcn_shape[0] = cnhw_shape[2];
hwcn_shape[1] = cnhw_shape[3];
hwcn_shape[2] = cnhw_shape[4];
hwcn_shape[3] = cnhw_shape[0];
hwcn_shape[4] = cnhw_shape[1];
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5.");
}

return Status::OK();
}
Status TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto& initializer,
const std::vector<size_t>& perm,
std::vector<uint8_t>& transposed_data) const;

Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto& initializer,
std::vector<uint8_t>& transposed_data,
bool is_3d = false) const {
auto& perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm;
return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data);
}

Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto& initializer,
std::vector<uint8_t>& transposed_data,
bool is_3d = false) const {
auto& perm = is_3d ? cnhw2hwcn_perm_3d : cnhw2hwcn_perm;
return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data);
}

Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper,
std::vector<uint32_t>& data_shape,
const onnx::TensorProto& initializer,
std::vector<uint8_t>& transposed_data) const {
auto tmp = data_shape[0];
data_shape[0] = data_shape[1];
data_shape[1] = tmp;
std::vector<size_t> two_dim_trans_perm{1, 0};
return TransposeInitializer(qnn_model_wrapper, initializer, two_dim_trans_perm, transposed_data);
}

// Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end]
void ReArranagePads(std::vector<uint32_t>& pads) const {
auto pads_size = pads.size();
onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc
@@ -211,9 +211,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper,

// Change shape to HWCN, it could be initializer or normal input
if (conv_type == OnnxConvType::kConv) {
ORT_RETURN_IF_ERROR(NchwShapeToHwcn(input_info.shape, actual_shape));
ORT_RETURN_IF_ERROR(utils::NchwShapeToHwcn(input_info.shape, actual_shape));
} else if (conv_type == OnnxConvType::kConvTranspose) {
ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(input_info.shape, actual_shape));
ORT_RETURN_IF_ERROR(utils::CnhwShapeToHwcn(input_info.shape, actual_shape));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str());
}
@@ -224,9 +224,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper,
if (input_info.is_initializer) {
// Get transposed initializer bytes.
if (conv_type == OnnxConvType::kConv) {
ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d));
ORT_RETURN_IF_ERROR(utils::TransposeFromNchwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d));

[cpplint] Lines should be <= 120 characters long (onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc:227)
} else if (conv_type == OnnxConvType::kConvTranspose) {
ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d));
ORT_RETURN_IF_ERROR(utils::TransposeFromCnhwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d));

[cpplint] Lines should be <= 120 characters long (onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc:229)
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str());
}
@@ -413,9 +413,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper,

// Create the final shape after the weights are transposed to HWCN.
if (conv_type == OnnxConvType::kConv) {
ORT_RETURN_IF_ERROR(NchwShapeToHwcn(shape_2d, final_shape));
ORT_RETURN_IF_ERROR(utils::NchwShapeToHwcn(shape_2d, final_shape));
} else if (conv_type == OnnxConvType::kConvTranspose) {
ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(shape_2d, final_shape));
ORT_RETURN_IF_ERROR(utils::CnhwShapeToHwcn(shape_2d, final_shape));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str());
}
@@ -453,9 +453,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper,
// Get transposed initializer bytes.
//
if (conv_type == OnnxConvType::kConv) {
ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor));
ORT_RETURN_IF_ERROR(utils::TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor));
} else if (conv_type == OnnxConvType::kConvTranspose) {
ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor));
ORT_RETURN_IF_ERROR(utils::TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str());
}
onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc
@@ -113,10 +113,8 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
if (1 == input_trans_flag.at(input_i)) {
ORT_RETURN_IF_ERROR(quantize_param.HandleTranspose<size_t>(std::vector<size_t>({1, 0})));
ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper,
input_shape,
*input_tensor,
unpacked_tensor));
ORT_RETURN_IF_ERROR(
utils::TwoDimensionTranspose(qnn_model_wrapper, input_shape, *input_tensor, unpacked_tensor));
} else {
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
}
onnxruntime/core/providers/qnn/builder/qnn_node_group.cc
@@ -18,6 +18,7 @@
#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h"

namespace onnxruntime {
namespace qnn {
@@ -92,6 +93,7 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
{"HardSigmoid", HardSigmoidMulFusion::TryFusion},
{"Conv", ConvActivationFusion::TryFusion},
{"ConvTranspose", ConvActivationFusion::TryFusion},
{"Gemm", ReshapeGemmFusion::TryFusion},
};

// For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
