Skip to content

Commit

Permalink
add unary parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Aug 8, 2023
1 parent e6bb3fd commit 945799c
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
3 changes: 3 additions & 0 deletions modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public static class CSourceBuiltn
bool (*bool_binary_and)(bool, bool);
bool (*bool_binary_or)(bool, bool);
bool (*bool_binary_xor)(bool, bool);
// paralell
void *thread_start(void *(*callable)(void *args), void *user);
void thread_end();
} nncase_mt_t;";

public const string Include = @"#include <stdbool.h>
Expand Down
88 changes: 82 additions & 6 deletions modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ internal struct IndentScope : IDisposable
public IndentScope(StringBuilder sb)
{
_initialized = true;
_originalWriter = _writer.Value;
_writer.Value = new IndentWriter(sb);
_originalWriter = null;
}

public IndentScope()
Expand Down Expand Up @@ -70,6 +70,12 @@ public CSymbol(string type, string name)
public string Name { get; }

public override string ToString() => $"{Type} {Name}";

public static IReadOnlyList<CSymbol> Builtns => new CSymbol[] {

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-macos

A property should not follow a method

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

A property should not follow a method

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

A property should not follow a method

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

A property should not follow a method

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

A property should not follow a method

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

A property should not follow a method

Check failure on line 74 in modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

A property should not follow a method
new CSymbol("nncase_mt_t*", "nncase_mt"),
new CSymbol("uint8_t*", "data"),
new CSymbol("const uint8_t*", "rdata"),
};
}

internal sealed class IndentWriter : StringWriter
Expand Down Expand Up @@ -100,16 +106,22 @@ internal sealed class CSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
{
private readonly Dictionary<Expr, CSymbol> _exprMemo;
private readonly StringBuilder _implBuilder;
private readonly StringBuilder _declBuilder;
private readonly StringWriter _declWriter;

public CSourceConvertVisitor()
{
_implBuilder = new StringBuilder();
_declBuilder = new StringBuilder();
_declWriter = new StringWriter(_declBuilder);
_exprMemo = new(ReferenceEqualityComparer.Instance);
}

public PrimFunction VisitEntry => (TIR.PrimFunction)VisitRoot!;

public FunctionCSource GetFunctionCSource()
{
return new(_exprMemo[VisitRoot!].Type + ";", _implBuilder.ToString());
return new(_declBuilder.ToString(), _implBuilder.ToString());
}

/// <inheritdoc/>
Expand All @@ -127,6 +139,9 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr)

var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})";

_declWriter.WriteLine(type + ";");
_declWriter.WriteLine();

using (var scope = new IndentScope(_implBuilder))
{
// 1. Function signature
Expand All @@ -142,7 +157,8 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr)
IndentScope.Writer.IndWrite("}\n");
}

symbol = new(type, new(expr.Name));
var ctype = $"void (*{expr.Name})({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})";
symbol = new(ctype, expr.Name);
_exprMemo.Add(expr, symbol);
return symbol;
}
Expand Down Expand Up @@ -294,10 +310,51 @@ protected override CSymbol VisitFor(For expr)
// 1. For Loop signature
var loopVar = Visit(expr.LoopVar);
IndentScope.Writer.IndWrite($"for ({loopVar.Type} {loopVar.Name} = {Visit(expr.Domain.Start).Name}; {loopVar.Name} < {Visit(expr.Domain.Stop).Name}; {loopVar.Name}+={Visit(expr.Domain.Step).Name}) {{\n");
using (_ = new IndentScope())

if (expr.Mode == LoopMode.Parallel)
{
// find the vars will be used and make new struct type.
var msg_fields = _exprMemo.Where(p => p.Key is MemSpan or TIR.Buffer or Var).Select(p => p.Value).Concat(CSymbol.Builtns);
var msg_type = DeclThreadMessageStruct(msg_fields);

using (new IndentScope(_declBuilder))
{
IndentScope.Writer.IndWrite($"void *{VisitEntry.Name}_inner(void *args) {{\n");
using (new IndentScope())
{
IndentScope.Writer.IndWrite($"{msg_type}* _message = ({msg_type}*)args;\n");
foreach (var sym in msg_fields)
{
IndentScope.Writer.IndWrite($"{sym.Type} {sym.Name} = _message->{sym.Name};\n");
}

Visit(expr.Body);
}

IndentScope.Writer.IndWrite(" return 0;\n");
IndentScope.Writer.IndWrite("}\n");
}

using (new IndentScope())
{
IndentScope.Writer.IndWrite($"{msg_type} _message = {{\n");
foreach (var sym in msg_fields)
{
IndentScope.Writer.IndWrite($".{sym.Name} = {sym.Name},\n");
}

IndentScope.Writer.IndWrite("};\n");

IndentScope.Writer.IndWrite($"nncase_mt->thread_start({VisitEntry.Name}_inner, (void *)_message);\n");
}
}
else
{
// 2. For Body
Visit(expr.Body);
using (_ = new IndentScope())
{
// 2. For Body
Visit(expr.Body);
}
}

// 3. For closing
Expand Down Expand Up @@ -360,4 +417,23 @@ protected override CSymbol VisitIfThenElse(IfThenElse expr)
_exprMemo.Add(expr, symbol);
return symbol;
}

private string DeclThreadMessageStruct(IEnumerable<CSymbol> keyValues)
{
var type = $"{VisitEntry.Name}_thread_message_t";
_declWriter.WriteLine("typedef struct {");
foreach (var sym in keyValues)
{
if (sym.Name == string.Empty)
{
throw new InvalidOperationException("empty name");
}

_declWriter.WriteLine(" " + sym.Type + " " + sym.Name + ";");
}

_declWriter.WriteLine($"}} {type};");
_declWriter.WriteLine();
return type;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private void GenerateMatMul(ReadOnlySpan<Buffer> arguments, Buffer ret, Call exp
var lhs = arguments[0];
var rhs = arguments[1];

var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
var loopVars = loops.Select(f => f.loopVar).ToArray();
var stmt = T.Serial(out var m, (0, lhs.Dimensions[^2])).Body(
T.Serial(out var n, (0, rhs.Dimensions[^1])).Body(
Expand All @@ -108,7 +108,7 @@ private void GenerateMatMul(ReadOnlySpan<Buffer> arguments, Buffer ret, Call exp
private void GenerateUnary(Unary unary, ReadOnlySpan<Buffer> arguments, Buffer ret)
{
var input = arguments[Unary.Input.Index];
var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray();
var loopVars = loops.Select(f => f.loopVar).ToArray();
Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Unary(unary.UnaryOp, T.BufferLoad(input, loopVars)));
var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build());
Expand Down

0 comments on commit 945799c

Please sign in to comment.