Skip to content

Commit

Permalink
Merge branch 'dev/3.0' into feature/ntt_benchmark_roofline_3
Browse files Browse the repository at this point in the history
  • Loading branch information
guodongliang committed Sep 30, 2024
2 parents bfcbeb1 + e6d95d2 commit 6eb378e
Show file tree
Hide file tree
Showing 76 changed files with 1,134 additions and 192 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/runtime-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ jobs:
shell: bash
run: |
conan install . --build=missing -s build_type=${{matrix.config.buildType}} -pr:a=toolchains/${{matrix.config.name}}.profile.jinja -o "&:runtime=True" -o "&:python=True" -o "&:tests=True"
cmake --preset conan-release
cmake --preset conan-runtime-release
- name: Build & Install
run: |
cmake --build build/${{matrix.config.buildType}} --config ${{matrix.config.buildType}}
Expand Down Expand Up @@ -119,7 +119,7 @@ jobs:
shell: bash
run: |
conan install . --build=missing -s build_type=${{matrix.config.buildType}} -pr:h=toolchains/${{matrix.config.name}}.profile.jinja -pr:b=toolchains/x86_64-linux.profile.jinja -o "&:runtime=True" -o "&:python=True" -o "&:tests=True"
cmake --preset conan-release
cmake --preset conan-runtime-release
- name: Build & Install
run: |
Expand Down
2 changes: 2 additions & 0 deletions conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def generate(self):
tc.variables['BUILD_TESTING'] = self.options.tests
if self.options.get_safe("python_root", default="") != "":
tc.variables['Python3_ROOT_DIR'] = self.options.python_root
if self.options.runtime:
tc.presets_prefix += "-runtime";
tc.generate()
deps = CMakeDeps(self)
deps.generate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ protected override CSymbol VisitCall(Call expr)
case IR.Tensors.Cast op:
str = $"(({op.NewType.ToC()}){arguments[0].Name})";
break;
case TIR.CPU.Memcopy op:
case TIR.Memcopy op:
IndentScope.Writer.IndWrite($"tensor_copy({arguments[1].Name}, {arguments[0].Name});\n");
break;
case TIR.CPU.Unary op:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ protected override CSymbol VisitCall(Call expr)
};

string str = string.Empty;
if (expr.Target is TIR.CPU.CPUKernelOp xpuOp)
if (expr.Target is Op kop && kop is TIR.CPU.CPUKernelOp or TIR.Memcopy)
{
foreach (var item in expr.Arguments.ToArray().OfType<TIR.Buffer>())
{
Expand All @@ -263,7 +263,7 @@ protected override CSymbol VisitCall(Call expr)
IndentScope.Writer.Write($"auto start_{CallCount} = get_ms_time();\n");
#endif
var args = expr.Arguments.ToArray().OfType<TIR.Buffer>().ToArray();
switch (xpuOp)
switch (kop)
{
case TIR.CPU.Unary unary:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unary.cshtml", new UnaryKernelTemplateModel
Expand Down Expand Up @@ -415,7 +415,7 @@ protected override CSymbol VisitCall(Call expr)
}).Result);

break;
case TIR.CPU.Memcopy copy:
case TIR.Memcopy copy:
IndentScope.Writer.Write($"tensor_copy({Visit(args[0]).Name}, {Visit(args[1]).Name});\n");
break;
case TIR.CPU.Gather gather:
Expand Down Expand Up @@ -477,9 +477,23 @@ protected override CSymbol VisitCall(Call expr)
Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(),
UnaryOp = UnaryOp.Erf,
}).Result);
break;
case TIR.CPU.Compare compare:
{
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Compare.cshtml", new CompareKernelTemplateModel
{
Arguments = args.Select(x => new KernelArgument { Symbol = Visit(x) }).ToArray(),
CompareOp = compare.CompareOp,
}).Result);
}

break;
case TIR.CPU.ScatterND scatterND:
IndentScope.Writer.Write($"scatter_nd({Visit(args[0]).Name}, {Visit(args[1]).Name}, {Visit(args[2]).Name}, {Visit(args[3]).Name});\n");

break;
default:
throw new NotSupportedException(xpuOp.ToString());
throw new NotSupportedException(kop.ToString());
}
#if PROFILE_CALL
IndentScope.Writer.Write($"printf(\"{expr.Target.GetType().Name} cost: %f\\n\", get_ms_time() - start_{CallCount++});\n");
Expand Down
5 changes: 5 additions & 0 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/KernelTemplateModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public class BinaryKernelTemplateModel : KernelTemplateModel
public BinaryOp BinaryOp { get; set; }
}

public class CompareKernelTemplateModel : KernelTemplateModel
{
public CompareOp CompareOp { get; set; }
}

public class TypedKernelTemplateModel<T> : KernelTemplateModel
where T : IR.Op
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@model Nncase.CodeGen.CPU.CompareKernelTemplateModel
@{
string CompareToCFunction(CompareOp op) =>
op switch
{
CompareOp.Equal => "ops::equal",
CompareOp.NotEqual => "ops::not_equal",
CompareOp.LowerThan => "ops::less",
CompareOp.LowerOrEqual => "ops::less_or_equal",
CompareOp.GreaterThan => "ops::greater",
CompareOp.GreaterOrEqual => "ops::greater_or_equal",
_ => throw new NotSupportedException($"Unsupported Compare: {op}."),
};
}
compare<@CompareToCFunction(Model.CompareOp)>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name));
15 changes: 15 additions & 0 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/Boxing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ public IRType Visit(ITypeInferenceContext context, Boxing target)
{
return (IRType)new InvalidType("Same NDSBP");
}
if (inv.NdSBP.Any(sbp => sbp is SBPPartialSum))
{
DistributedUtility.TryGetDividedTensorType(inv, out var inType);
DistributedUtility.TryGetDividedTensorType(outv, out var outType);
var nonPartialSumPos = Enumerable.Range(0, inv.NdSBP.Count).Where(i => inv.NdSBP[i] is not SBPPartialSum);
if (nonPartialSumPos.Any(i => inv.NdSBP[i] is SBPSplit && outv.NdSBP[i] is SBPBroadCast))
{
// TODO: S[i]->S[i] may be a problem
return new InvalidType("Not supported NDSBP");
}
return outv;
}
else
{
return outv;
Expand Down
12 changes: 10 additions & 2 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ public IRType Visit(ITypeInferenceContext context, Binary target)
return TupleType.Void;
}

public MicroKernelInfo Visit(Binary op, AffineDim[] domain, AffineMap[] accessMaps, int[][] bufferShapes, ITargetOptions targetOptions)
public MicroKernelInfo Visit(Binary op, MicroKernelContext context)
{
return new(Enumerable.Repeat(1, domain.Length).ToArray(), Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray(), 128, 128);
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[2] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
5 changes: 4 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/CPUModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

using DryIoc;
using Nncase.Evaluator.Imaging;
using Nncase.Evaluator.Math;
using Nncase.Evaluator.NN;
using Nncase.Evaluator.Tensors;
using Nncase.Hosting;
using Nncase.IR.Tensors;

namespace Nncase.Evaluator.TIR.CPU;

Expand All @@ -18,7 +20,6 @@ public void ConfigureServices(IRegistrator registrator)
{
registrator.RegisterManyInterface<BinaryEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<MatmulEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<MemcopyEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PtrOfEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<SramPtrEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<TensorLoadEvaluator>(reuse: Reuse.Singleton);
Expand Down Expand Up @@ -48,5 +49,7 @@ public void ConfigureServices(IRegistrator registrator)
registrator.RegisterManyInterface<WhereEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<ExpandEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<ErfEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<CompareEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<ScatterNDEvaluator>(reuse: Reuse.Singleton);
}
}
30 changes: 30 additions & 0 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Compare.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class CompareEvaluator : ITypeInferencer<Compare>, IKernelInfoEvaluator<Compare>
{
public IRType Visit(ITypeInferenceContext context, Compare target)
{
return TupleType.Void;
}

public MicroKernelInfo Visit(Compare op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[2] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
12 changes: 10 additions & 2 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@ public sealed class MatmulEvaluator : ITypeInferencer<Matmul>, IKernelInfoEvalua
{
public IRType Visit(ITypeInferenceContext context, Matmul target) => TupleType.Void;

public MicroKernelInfo Visit(Matmul op, AffineDim[] domain, AffineMap[] accessMaps, int[][] bufferShapes, ITargetOptions targetOptions)
public MicroKernelInfo Visit(Matmul op, MicroKernelContext context)
{
return new(Enumerable.Repeat(1, domain.Length).ToArray(), Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray(), 128, 8);
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[2] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read | MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
15 changes: 14 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@
using System.Linq;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;
using Nncase.Utilities;
using OrtKISharp;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class PackEvaluator : ITypeInferencer<Pack>
public sealed class PackEvaluator : ITypeInferencer<Pack>, IKernelInfoEvaluator<Pack>
{
public IRType Visit(ITypeInferenceContext context, Pack target) => TupleType.Void;

public MicroKernelInfo Visit(Pack op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
12 changes: 10 additions & 2 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,16 @@ public sealed class PackedBinaryEvaluator : ITypeInferencer<PackedBinary>, IKern
{
public IRType Visit(ITypeInferenceContext context, PackedBinary target) => TupleType.Void;

public MicroKernelInfo Visit(PackedBinary op, AffineDim[] domain, AffineMap[] accessMaps, int[][] bufferShapes, ITargetOptions targetOptions)
public MicroKernelInfo Visit(PackedBinary op, MicroKernelContext context)
{
return new(Enumerable.Repeat(1, domain.Length).ToArray(), Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray(), 128, 128);
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[2] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
17 changes: 17 additions & 0 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/ScatterND.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class ScatterNDEvaluator : ITypeInferencer<ScatterND>
{
public IRType Visit(ITypeInferenceContext context, ScatterND target)
{
return TupleType.Void;
}
}
27 changes: 9 additions & 18 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,15 @@ public IRType Visit(ITypeInferenceContext context, Swish target)
return TupleType.Void;
}

public MicroKernelInfo Visit(Swish op, AffineDim[] domain, AffineMap[] accessMaps, int[][] bufferShapes, ITargetOptions targetOptions)
public MicroKernelInfo Visit(Swish swish, MicroKernelContext context)
{
var primitives = new int[bufferShapes[0].Length];
var multipliers = new ValueRange<int>[bufferShapes[0].Length];
for (int i = 0; i < bufferShapes[0].Length; i++)
{
if (Utilities.DistributedUtility.IsDivideExactly(bufferShapes[0][i], 4))
{
primitives[i] = 4;
multipliers[i] = new(1, bufferShapes[0][i] / 4);
}
else
{
primitives[i] = 1;
multipliers[i] = new(1, bufferShapes[0][i]);
}
}

return new MicroKernelInfo(primitives, multipliers, 128, 128);
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
27 changes: 9 additions & 18 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,15 @@ public IRType Visit(ITypeInferenceContext context, Unary target)
return TupleType.Void;
}

public MicroKernelInfo Visit(Unary op, AffineDim[] domain, AffineMap[] accessMaps, int[][] bufferShapes, ITargetOptions targetOptions)
public MicroKernelInfo Visit(Unary op, MicroKernelContext context)
{
var primitives = new int[bufferShapes[0].Length];
var multipliers = new ValueRange<int>[bufferShapes[0].Length];
for (int i = 0; i < bufferShapes[0].Length; i++)
{
if (Utilities.DistributedUtility.IsDivideExactly(bufferShapes[0][i], 4))
{
primitives[i] = 4;
multipliers[i] = new(1, bufferShapes[0][i] / 4);
}
else
{
primitives[i] = 1;
multipliers[i] = new(1, bufferShapes[0][i]);
}
}

return new MicroKernelInfo(primitives, multipliers, 128, 128);
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos);
}
}
Loading

0 comments on commit 6eb378e

Please sign in to comment.