Merge branch 'master' into nncase-studio
FusionBolt authored Nov 17, 2023
2 parents 8a2ad23 + 5364b01 commit 96e15e7
Showing 6 changed files with 227 additions and 95 deletions.
163 changes: 91 additions & 72 deletions src/Native/src/kernels/stackvm/reference/softmax.cpp
@@ -28,80 +28,99 @@ using namespace nncase::kernels::stackvm;
namespace {
// softmax(x) = exp(x - reduce_max(x)) / reduce_sum(exp(x - reduce_max(x)))
template <typename T>
result<void> softmax_impl(const T *input, T *output,
gsl::span<const size_t> in_shape,
gsl::span<const size_t> in_strides,
gsl::span<const size_t> out_strides, int64_t axis,
float beta, bool needLog = false) noexcept {
result<void>
softmax_impl(const T *input, T *output, gsl::span<const size_t> in_shape,
NNCASE_UNUSED gsl::span<const size_t> in_strides,
NNCASE_UNUSED gsl::span<const size_t> out_strides, int64_t axis,
float beta, bool needLog = false) noexcept {
size_t positive_axis = axis < 0 ? in_shape.size() + axis : axis;
dims_t axes{positive_axis};

auto reduced_shape =
kernels::detail::get_reduced_shape(in_shape, axes, true);
auto reduced_strides = get_default_strides(reduced_shape);
auto reduced_size = compute_size(reduced_shape);
std::vector<T> tmp(reduced_size, std::numeric_limits<T>::lowest());

// reduce_max
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
auto in_idx = offset(in_strides, index);
const auto in = input[in_idx];

const auto out_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto out_idx = offset(reduced_strides, out_index);
auto &out = tmp[out_idx];

out = std::max(in, out);
return ok();
}));

// x - reduce_max
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
auto in_idx = offset(in_strides, index);
const auto in = input[in_idx];

const auto out_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto max_idx = offset(reduced_strides, out_index);

auto out_idx = offset(out_strides, index);
output[out_idx] =
static_cast<T>(static_cast<float>(in - tmp[max_idx]) * beta);

return ok();
}));

// exp(x - reduce_max) and sum
tmp.assign(tmp.size(), static_cast<T>(0));
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
auto in_idx = offset(out_strides, index);
const auto in = output[in_idx];

const auto out_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto out_idx = offset(reduced_strides, out_index);
output[in_idx] = static_cast<T>(expf(static_cast<float>(in)));
tmp[out_idx] += static_cast<T>(output[in_idx]);

return ok();
}));

// div
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
const auto in_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto in_idx = offset(reduced_strides, in_index);
auto in = tmp[in_idx];

auto out_idx = offset(out_strides, index);
auto &out = output[out_idx];
out /= in;
        if (needLog) {
            out = static_cast<T>(std::log(static_cast<float>(out)));
        }
        return ok();
    }));

    if (positive_axis == in_shape.size() - 1) {
if (positive_axis == in_shape.size() - 1) {
size_t reduced_size = in_shape[positive_axis];
auto out_size = compute_size(in_shape) / reduced_size;
std::vector<T> tmp(reduced_size, std::numeric_limits<T>::lowest());

for (size_t i = 0; i < out_size; i++) {
auto in_ = input + i * reduced_size;
auto out_ = output + i * reduced_size;

// reduce_max
auto max_value = *in_;
for (size_t j = 0; j < reduced_size; j++) {
max_value = std::max(max_value, in_[j]);
}

// (x - reduce_max) * beta
for (size_t j = 0; j < reduced_size; j++) {
out_[j] = static_cast<T>((static_cast<float>(in_[j]) -
static_cast<float>(max_value)) *
beta);
}

// exp((x - reduce_max) * beta) and sum
T sum = 0;
for (size_t j = 0; j < reduced_size; j++) {
out_[j] = static_cast<T>(expf(static_cast<float>(out_[j])));
sum += out_[j];
}

// div
for (size_t j = 0; j < reduced_size; j++) {
out_[j] /= sum;
if (needLog) {
out_[j] =
static_cast<T>(std::log(static_cast<float>(out_[j])));
}
}
}
} else {
size_t axis_size = in_shape[positive_axis];
size_t reduced_size = 1;
for (size_t i = positive_axis + 1; i < in_shape.size(); i++) {
reduced_size *= in_shape[i];
}
auto out_size = compute_size(in_shape) / reduced_size / axis_size;

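        // View the tensor as [out_size, axis_size, reduced_size]: softmax runs
        // along the middle (axis) dimension for every outer/inner index pair.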
for (size_t i = 0; i < out_size; i++) {
std::vector<T> axis_sum(reduced_size, static_cast<T>(0));
std::vector<T> max_value(reduced_size,
std::numeric_limits<T>::lowest());
auto in_ = input + i * reduced_size * axis_size;
auto out_ = output + i * reduced_size * axis_size;

// reduce_max
for (size_t k = 0; k < axis_size; k++) {
auto in_k = in_ + k * reduced_size;
for (size_t j = 0; j < reduced_size; j++) {
max_value[j] = std::max(max_value[j], in_k[j]);
}
}

// exp((x - reduce_max) * beta) and sum
for (size_t k = 0; k < axis_size; k++) {
auto in_k = in_ + k * reduced_size;
auto out_k = out_ + k * reduced_size;
for (size_t j = 0; j < reduced_size; j++) {
out_k[j] =
static_cast<T>(expf((static_cast<float>(in_k[j]) -
static_cast<float>(max_value[j])) *
beta));
axis_sum[j] += out_k[j];
}
}

// div
for (size_t k = 0; k < axis_size; k++) {
auto out_k = out_ + k * reduced_size;
for (size_t j = 0; j < reduced_size; j++) {
out_k[j] /= axis_sum[j];
if (needLog)
out_k[j] = static_cast<T>(
std::log(static_cast<float>((out_k[j]))));
}
}
}
}

return ok();
}
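
For reference, both branches implement the numerically stable recipe from the comment at the top of the file: subtract the per-axis max, scale by beta, exponentiate, normalize, and optionally take the log. A minimal standalone sketch of the last-axis case (plain C++ over a [rows, cols] float buffer; the function name is hypothetical, and this is not the nncase kernel API):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Standalone stable softmax over the last axis of a [rows, cols] float
// buffer: softmax(x) = exp((x - max(x)) * beta) / sum(exp((x - max(x)) * beta))
void softmax_last_axis(const float *in, float *out, std::size_t rows,
                       std::size_t cols, float beta = 1.0f,
                       bool need_log = false) {
    for (std::size_t i = 0; i < rows; i++) {
        const float *row_in = in + i * cols;
        float *row_out = out + i * cols;
        // Subtracting the row max keeps exp in range; the exp(-max) factor
        // cancels in the quotient, so the result is mathematically unchanged.
        float max_value = *std::max_element(row_in, row_in + cols);
        float sum = 0.0f;
        for (std::size_t j = 0; j < cols; j++) {
            row_out[j] = std::exp((row_in[j] - max_value) * beta);
            sum += row_out[j];
        }
        for (std::size_t j = 0; j < cols; j++) {
            row_out[j] /= sum;
            if (need_log) // the kernel's needLog path (log-softmax)
                row_out[j] = std::log(row_out[j]);
        }
    }
}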
34 changes: 34 additions & 0 deletions src/Native/src/kernels/stackvm/tensor_ops.cpp
@@ -404,6 +404,25 @@ result<value_t> nncase::kernels::stackvm::get_item(
#undef RETURN_RESULT
return err(std::errc::not_supported);
}

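// Fast path: a single scalar index into a 2-D tensor selects one row, so its
// elements can be copied out directly instead of running the general
// strided-slice path below.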
if (input_tensor->shape().size() == 2 && begins_value.size() == 1) {
auto get_item_index = begins_value[0];
auto out_shape = dims_t{input_tensor->shape()[1]};
try_output(out_mem, output, input_tensor->dtype(), out_shape);
auto size = input_tensor->shape()[1];
#define RETURN_RESULT(_in_type) \
if (cmp_type<_in_type>(input_tensor->dtype())) { \
for (int i = 0; i < size; ++i) { \
OUT_CAST(_in_type, out_mem) \
[i] = IN_CAST(_in_type, in_mem)[get_item_index * size + i]; \
} \
return ok(output); \
}
RETURN_RESULT_SELECT(RETURN_RESULT);
#undef RETURN_RESULT
return err(std::errc::not_supported);
}

auto n = begins_value.size();
auto in_shape = input_tensor->shape();
auto ends_value = axes_t(n, 0);
@@ -423,6 +442,8 @@ result<value_t> nncase::kernels::stackvm::get_item(
out_mem, in_shape, input_tensor->strides(),
output_tensor->strides(), begin_values, end_values,
strides_values, context);
output = tensor_reshape(output_tensor,
dims_t(out_shape.begin() + n, out_shape.end()));
KERNEL_FINISH;
}
}
@@ -771,6 +792,18 @@ result<value_t> nncase::kernels::stackvm::bucket_pad(
try_dims_v(shape);
auto in_tensor = input.as<tensor>().expect("input is not a tensor");
auto in_shape = in_tensor->shape();
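    // bucket_pad can only pad: reject inputs that already hold more elements
    // than the requested bucket shape.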
if (compute_size(in_shape) > compute_size(shape_value)) {
std::cout << "in shape" << std::endl;
for (int i = 0; i < in_shape.size(); ++i) {
std::cout << in_shape[i] << std::endl;
}
std::cout << "shape_value shape" << std::endl;
for (int i = 0; i < shape_value.size(); ++i) {
std::cout << shape_value[i] << std::endl;
}
return err(std::errc::invalid_argument);
}

auto paddings = std::vector<int>(8);
auto rank = shape_value.size();
for (int i = 0; i < rank; ++i) {
@@ -1105,6 +1138,7 @@ nncase::kernels::stackvm::squeeze(value_t input, value_t dim, value_t output,
try_var(in_tensor, input.as<tensor>());
auto in_shape = in_tensor->shape();
not_impl_no_contiguous(in_tensor);
// todo: dim is scalar
try_positive_axes(axes, dim, in_tensor->shape().size());
auto new_shape = squeeze_infer_shape(in_shape, axes);
output = tensor_reshape(in_tensor, new_shape);
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/Expr.cs
@@ -116,6 +116,8 @@ public DataType CheckedDataType
{
case TensorType type:
return type.DType;
case DistributedType type:
return type.TensorType.DType;
default:
if (DumpScope.Current.IsEnabled(DumpFlags.Compile))
{
93 changes: 78 additions & 15 deletions src/Nncase.Core/Utilities/DistributedUtility.cs
@@ -16,7 +16,7 @@ public static IReadOnlyList<IRArray<SBP>> GetLeafCandidateNDSBPs(TensorType tens
var ndsbp = new List<SBP>();
for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
{
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]))
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]))
{
ndsbp.Add(SBP.S(axis));
}
@@ -28,7 +28,7 @@ public static IReadOnlyList<IRArray<SBP>> GetLeafCandidateNDSBPs(TensorType tens

return ndsbps.CartesianProduct().
Select(ndsbp => ndsbp.ToArray()).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).
Select(ndsbp => new IRArray<SBP>(ndsbp)).
ToArray();
}
@@ -53,7 +53,7 @@ public static IReadOnlyList<IRArray<SBP>> GetPartialCandidateNDSBPs(DistributedT
candidateNdsbps[i].Add(SBP.B);
for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
{
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis))
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis))
{
candidateNdsbps[i].Add(SBP.S(axis));
}
@@ -67,38 +67,101 @@ public static IReadOnlyList<IRArray<SBP>> GetPartialCandidateNDSBPs(DistributedT

return candidateNdsbps.CartesianProduct().
Select(ndsbp => ndsbp.ToArray()).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).
Select(ndsbp => new IRArray<SBP>(ndsbp)).
ToArray();
}

    public static bool IsDistributable(TensorType tensorType, ReadOnlySpan<SBP> ndsbp, Placement placement, [MaybeNullWhen(false)] out TensorType distType)
    {
        distType = null;
        if (!tensorType.Shape.IsFixed)
        {
            return false;
        }

        var shape = tensorType.Shape.ToValueArray();
        for (int i = 0; i < ndsbp.Length; i++)
        {
            if (ndsbp[i] is SBPSplit { Axis: int axis })
            {
                if (!IsDivisible(shape[axis], placement.Hierarchy[i]))
                {
                    return false;
                }

                shape[axis] /= placement.Hierarchy[i];
            }
        }

        distType = tensorType with { Shape = shape };
        return true;
    }

    public static bool IsDistributable(TensorType tensorType, ReadOnlySpan<SBP> ndsbp, Placement placement)
    {
        if (!tensorType.Shape.IsFixed)
        {
            return false;
        }

        var divisors = GetDivisors(new DistributedType(tensorType, new IRArray<SBP>(ndsbp.ToArray()), placement));
        return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : IsDivideBy(tensorType.Shape[p.axis].FixedValue, p.d));
    }

    public static IReadOnlyList<int> GetDivisors(DistributedType distributedType)
    {
        var shape = distributedType.TensorType.Shape.ToValueArray();
        var divisors = Enumerable.Repeat(0, shape.Length).ToArray();
        for (int i = 0; i < distributedType.NdSBP.Count; i++)
        {
            if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
            {
                if (divisors[axis] == 0)
                {
                    divisors[axis] = 1;
                }

                divisors[axis] *= distributedType.Placement.Hierarchy[i];
            }
        }

        return divisors;
    }

public static bool TryGetDividedTensorType(DistributedType distributedType, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TensorType tensorType)
{
tensorType = null;
var divisors = GetDivisors(distributedType);
if (divisors.Select((d, i) => (d, i)).All(p => p.d == 0 || IsDivideExactly(distributedType.TensorType.Shape[p.i].FixedValue, p.d)))
{
tensorType = new TensorType(distributedType.TensorType.DType, distributedType.TensorType.Shape.Zip(divisors).Select(p => p.Second == 0 ? p.First.FixedValue : p.First.FixedValue / p.Second).ToArray());
return true;
}

return false;
}

public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedType)
{
var shape = distributedType.TensorType.Shape.ToValueArray();
var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List<int>()).ToArray();
var ids = distributedType.Placement.Name.Select(c => new Var(c + "id", TensorType.Scalar(DataTypes.Int32))).ToArray();
var hierarchyStrides = TensorUtilities.GetStrides(distributedType.Placement.Hierarchy.ToArray());
for (int i = 0; i < distributedType.NdSBP.Count; i++)
{
if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
{
hierarchies[axis].Add(i);
}
}

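        // Each split axis is divided evenly across its shards; the shard with
        // the highest hierarchy index on that axis absorbs the remainder.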
return hierarchies.Select((divs, axis) =>
{
Expr dim;
if (divs.Any())
{
var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray());
var (res, rem) = Math.DivRem(shape[axis], divsor);
dim = IR.F.Math.Select(
TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1),
res,
res + rem);
}
else
{
dim = distributedType.TensorType.Shape[axis].FixedValue;
}
return dim;
}).ToArray();
}

public static bool IsDivideBy(int input, int divisor)
{
if (input >= divisor)
{
return true;
}

return false;
}

public static bool IsDivisible(int input, int divisor)
public static bool IsDivideExactly(int input, int divisor)
{
if (input >= divisor && input % divisor == 0)
{

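To make the new divisor-based checks concrete: GetDivisors returns, per tensor axis, the product of the placement hierarchy sizes that split that axis (0 for an unsplit axis); IsDivideBy only demands at least one element per shard, while IsDivideExactly demands an even split. A small self-contained sketch of that bookkeeping, re-expressed in C++ with hypothetical shape and hierarchy values (the real utility is C#):

#include <cstdio>
#include <vector>

// Sketch of DistributedUtility's divisor bookkeeping. Each entry of
// ndsbp_axis is the tensor axis split at that hierarchy level, or -1 for
// broadcast. All values below are hypothetical.
int main() {
    std::vector<int> shape = {8, 12};     // tensor shape
    std::vector<int> hierarchy = {2, 4};  // placement hierarchy
    std::vector<int> ndsbp_axis = {1, 1}; // S(1) at both levels

    // GetDivisors: per axis, the product of hierarchy sizes splitting it
    // (0 means the axis is not split at all).
    std::vector<int> divisors(shape.size(), 0);
    for (std::size_t i = 0; i < ndsbp_axis.size(); i++) {
        if (int axis = ndsbp_axis[i]; axis >= 0) {
            if (divisors[axis] == 0)
                divisors[axis] = 1;
            divisors[axis] *= hierarchy[i];
        }
    }
    // divisors == {0, 8}: axis 1 is split across 2 * 4 = 8 shards.

    for (std::size_t axis = 0; axis < shape.size(); axis++) {
        if (divisors[axis] == 0)
            continue;
        bool divide_by = shape[axis] >= divisors[axis];               // IsDivideBy
        bool exactly = divide_by && shape[axis] % divisors[axis] == 0; // IsDivideExactly
        std::printf("axis %zu: divide_by=%d exactly=%d\n", axis, divide_by, exactly);
    }
    return 0;
}

Under these values axis 1 passes IsDivideBy (12 >= 8) but fails IsDivideExactly (12 % 8 != 0), so TryGetDividedTensorType returns false and the non-uniform path applies: TryGetNonUniformDividedShape gives seven shards of size 1 and a last shard of size 5 along that axis.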