From 945799cc49d539bcc375271d95fc96ea43b7ec62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com>
Date: Tue, 8 Aug 2023 19:55:42 +0800
Subject: [PATCH] add unary parallel

---
Review notes (this area is ignored by `git am`; not part of the commit message):
 * thread_start/thread_end are declared as function-pointer members, like
   every other entry of nncase_mt_t; a plain function declaration is not a
   valid struct member in C and could not be called through nncase_mt->.
 * The generated call passes &_message: _message is a struct value and
   cannot be cast to void *.
 * msg_fields is materialised with ToArray() so that the struct declaration,
   the unpacking code and the initializer all list the same fields, even if
   Visit(expr.Body) memoises further symbols in between.
 * NOTE(review): leading whitespace, generic type arguments
   (ExprFunctor<CSymbol, Unit>, Dictionary<Expr, CSymbol>, ...), the
   `/// <inheritdoc/>` context line and the header name after "#include"
   were lost before this patch reached review and are restored from usage
   inside the patch - confirm against the tree before applying.
 * NOTE(review): whitespace inside string literals (e.g. " return 0;\n")
   may have been collapsed as well; it is kept exactly as received.

 .../CodeGen/CSourceBuiltn.cs                  |  3 +
 .../CodeGen/CSourceConvertVisitor.cs          | 88 +++++++++++++++++--
 .../Passes/Tile/SingleCPUFusionConverter.cs   |  4 +-
 3 files changed, 87 insertions(+), 8 deletions(-)

diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs
index b2e0a6fcb3..a945b05958 100644
--- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs
+++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs
@@ -72,6 +72,9 @@ public static class CSourceBuiltn
     bool (*bool_binary_and)(bool, bool);
     bool (*bool_binary_or)(bool, bool);
     bool (*bool_binary_xor)(bool, bool);
+    // parallel
+    void *(*thread_start)(void *(*callable)(void *args), void *user);
+    void (*thread_end)();
 } nncase_mt_t;";
 
     public const string Include = @"#include <stdint.h>
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs
index bc715c4f2f..2803ebba94 100644
--- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs
@@ -27,8 +27,8 @@ internal struct IndentScope : IDisposable
     public IndentScope(StringBuilder sb)
     {
         _initialized = true;
+        _originalWriter = _writer.Value;
         _writer.Value = new IndentWriter(sb);
-        _originalWriter = null;
     }
 
     public IndentScope()
@@ -70,6 +70,12 @@ public CSymbol(string type, string name)
     public string Name { get; }
 
     public override string ToString() => $"{Type} {Name}";
+
+    public static IReadOnlyList<CSymbol> Builtns => new CSymbol[] {
+        new CSymbol("nncase_mt_t*", "nncase_mt"),
+        new CSymbol("uint8_t*", "data"),
+        new CSymbol("const uint8_t*", "rdata"),
+    };
 }
 
 internal sealed class IndentWriter : StringWriter
@@ -100,16 +106,22 @@ internal sealed class CSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
 {
     private readonly Dictionary<Expr, CSymbol> _exprMemo;
     private readonly StringBuilder _implBuilder;
+    private readonly StringBuilder _declBuilder;
+    private readonly StringWriter _declWriter;
 
     public CSourceConvertVisitor()
     {
         _implBuilder = new StringBuilder();
+        _declBuilder = new StringBuilder();
+        _declWriter = new StringWriter(_declBuilder);
         _exprMemo = new(ReferenceEqualityComparer.Instance);
     }
 
+    public PrimFunction VisitEntry => (TIR.PrimFunction)VisitRoot!;
+
     public FunctionCSource GetFunctionCSource()
     {
-        return new(_exprMemo[VisitRoot!].Type + ";", _implBuilder.ToString());
+        return new(_declBuilder.ToString(), _implBuilder.ToString());
     }
 
     /// <inheritdoc/>
@@ -127,6 +139,9 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr)
 
         var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})";
 
+        _declWriter.WriteLine(type + ";");
+        _declWriter.WriteLine();
+
         using (var scope = new IndentScope(_implBuilder))
         {
             // 1. Function signature
@@ -142,7 +157,8 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr)
             IndentScope.Writer.IndWrite("}\n");
         }
 
-        symbol = new(type, new(expr.Name));
+        var ctype = $"void (*{expr.Name})({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})";
+        symbol = new(ctype, expr.Name);
         _exprMemo.Add(expr, symbol);
         return symbol;
     }
@@ -294,10 +310,51 @@ protected override CSymbol VisitFor(For expr)
         // 1. For Loop signature
         var loopVar = Visit(expr.LoopVar);
         IndentScope.Writer.IndWrite($"for ({loopVar.Type} {loopVar.Name} = {Visit(expr.Domain.Start).Name}; {loopVar.Name} < {Visit(expr.Domain.Stop).Name}; {loopVar.Name}+={Visit(expr.Domain.Step).Name}) {{\n");
-        using (_ = new IndentScope())
+
+        if (expr.Mode == LoopMode.Parallel)
+        {
+            // Snapshot the symbols visible so far; they become the fields of the thread message struct.
+            var msg_fields = _exprMemo.Where(p => p.Key is MemSpan or TIR.Buffer or Var).Select(p => p.Value).Concat(CSymbol.Builtns).ToArray();
+            var msg_type = DeclThreadMessageStruct(msg_fields);
+
+            using (new IndentScope(_declBuilder))
+            {
+                IndentScope.Writer.IndWrite($"void *{VisitEntry.Name}_inner(void *args) {{\n");
+                using (new IndentScope())
+                {
+                    IndentScope.Writer.IndWrite($"{msg_type}* _message = ({msg_type}*)args;\n");
+                    foreach (var sym in msg_fields)
+                    {
+                        IndentScope.Writer.IndWrite($"{sym.Type} {sym.Name} = _message->{sym.Name};\n");
+                    }
+
+                    Visit(expr.Body);
+                }
+
+                IndentScope.Writer.IndWrite(" return 0;\n");
+                IndentScope.Writer.IndWrite("}\n");
+            }
+
+            using (new IndentScope())
+            {
+                IndentScope.Writer.IndWrite($"{msg_type} _message = {{\n");
+                foreach (var sym in msg_fields)
+                {
+                    IndentScope.Writer.IndWrite($".{sym.Name} = {sym.Name},\n");
+                }
+
+                IndentScope.Writer.IndWrite("};\n");
+
+                IndentScope.Writer.IndWrite($"nncase_mt->thread_start({VisitEntry.Name}_inner, (void *)&_message);\n");
+            }
+        }
+        else
         {
-            // 2. For Body
-            Visit(expr.Body);
+            using (_ = new IndentScope())
+            {
+                // 2. For Body
+                Visit(expr.Body);
+            }
         }
 
         // 3. For closing
@@ -360,4 +417,23 @@ protected override CSymbol VisitIfThenElse(IfThenElse expr)
         _exprMemo.Add(expr, symbol);
         return symbol;
     }
+
+    private string DeclThreadMessageStruct(IEnumerable<CSymbol> keyValues)
+    {
+        var type = $"{VisitEntry.Name}_thread_message_t";
+        _declWriter.WriteLine("typedef struct {");
+        foreach (var sym in keyValues)
+        {
+            if (sym.Name == string.Empty)
+            {
+                throw new InvalidOperationException("empty name");
+            }
+
+            _declWriter.WriteLine(" " + sym.Type + " " + sym.Name + ";");
+        }
+
+        _declWriter.WriteLine($"}} {type};");
+        _declWriter.WriteLine();
+        return type;
+    }
 }
diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs
index bd45b072ca..2dc177c607 100644
--- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs
+++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs
@@ -86,7 +86,7 @@ private void GenerateMatMul(ReadOnlySpan<Buffer> arguments, Buffer ret, Call exp
             var lhs = arguments[0];
             var rhs = arguments[1];
 
-            var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
+            var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
             var loopVars = loops.Select(f => f.loopVar).ToArray();
             var stmt = T.Serial(out var m, (0, lhs.Dimensions[^2])).Body(
                 T.Serial(out var n, (0, rhs.Dimensions[^1])).Body(
@@ -108,7 +108,7 @@ private void GenerateMatMul(ReadOnlySpan<Buffer> arguments, Buffer ret, Call exp
         private void GenerateUnary(Unary unary, ReadOnlySpan<Buffer> arguments, Buffer ret)
         {
             var input = arguments[Unary.Input.Index];
-            var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
+            var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
             var loopVars = loops.Select(f => f.loopVar).ToArray();
             Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Unary(unary.UnaryOp, T.BufferLoad(input, loopVars)));
             var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build());