Skip to content

Commit

Permalink
Merge pull request stakira#1093 from yqzhishen/ds-tensor-cache
Browse files Browse the repository at this point in the history
[DiffSinger] Add tensor caching system
  • Loading branch information
stakira authored Apr 20, 2024
2 parents e3915b5 + 7dc5500 commit 893204e
Show file tree
Hide file tree
Showing 12 changed files with 429 additions and 29 deletions.
30 changes: 26 additions & 4 deletions OpenUtau.Core/DiffSinger/DiffSingerBasePhonemizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using K4os.Hash.xxHash;
using Serilog;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

using OpenUtau.Api;
using OpenUtau.Core.Ustx;
using OpenUtau.Core.Util;

namespace OpenUtau.Core.DiffSinger
{
Expand All @@ -17,6 +19,8 @@ public abstract class DiffSingerBasePhonemizer : MachineLearningPhonemizer
DsConfig dsConfig;
string rootPath;
float frameMs;
ulong linguisticHash;
ulong durationHash;
InferenceSession linguisticModel;
InferenceSession durationModel;
IG2p g2p;
Expand Down Expand Up @@ -51,14 +55,18 @@ public override void SetSinger(USinger singer) {
//Load models
var linguisticModelPath = Path.Join(rootPath, dsConfig.linguistic);
try {
linguisticModel = new InferenceSession(linguisticModelPath);
var linguisticModelBytes = File.ReadAllBytes(linguisticModelPath);
linguisticHash = XXH64.DigestOf(linguisticModelBytes);
linguisticModel = new InferenceSession(linguisticModelBytes);
} catch (Exception e) {
Log.Error(e, $"failed to load linguistic model from {linguisticModelPath}");
return;
}
var durationModelPath = Path.Join(rootPath, dsConfig.dur);
try {
durationModel = new InferenceSession(durationModelPath);
var durationModelBytes = File.ReadAllBytes(durationModelPath);
durationHash = XXH64.DigestOf(durationModelBytes);
durationModel = new InferenceSession(durationModelBytes);
} catch (Exception e) {
Log.Error(e, $"failed to load duration model from {durationModelPath}");
return;
Expand Down Expand Up @@ -260,7 +268,14 @@ protected override void ProcessPart(Note[][] phrase) {
new DenseTensor<Int64>(word_dur, new int[] { word_dur.Length }, false)
.Reshape(new int[] { 1, word_dur.Length })));
Onnx.VerifyInputNames(linguisticModel, linguisticInputs);
var linguisticOutputs = linguisticModel.Run(linguisticInputs);
var linguisticCache = Preferences.Default.DiffSingerTensorCache
? new DiffSingerCache(linguisticHash, linguisticInputs)
: null;
var linguisticOutputs = linguisticCache?.Load();
if (linguisticOutputs is null) {
linguisticOutputs = linguisticModel.Run(linguisticInputs).Cast<NamedOnnxValue>().ToList();
linguisticCache?.Save(linguisticOutputs);
}
Tensor<float> encoder_out = linguisticOutputs
.Where(o => o.Name == "encoder_out")
.First()
Expand Down Expand Up @@ -291,7 +306,14 @@ protected override void ProcessPart(Note[][] phrase) {
durationInputs.Add(NamedOnnxValue.CreateFromTensor("spk_embed", spkEmbedTensor));
}
Onnx.VerifyInputNames(durationModel, durationInputs);
var durationOutputs = durationModel.Run(durationInputs);
var durationCache = Preferences.Default.DiffSingerTensorCache
? new DiffSingerCache(durationHash, durationInputs)
: null;
var durationOutputs = durationCache?.Load();
if (durationOutputs is null) {
durationOutputs = durationModel.Run(durationInputs).Cast<NamedOnnxValue>().ToList();
durationCache?.Save(durationOutputs);
}
List<double> durationFrames = durationOutputs.First().AsTensor<float>().Select(x=>(double)x).ToList();

//Alignment
Expand Down
311 changes: 311 additions & 0 deletions OpenUtau.Core/DiffSinger/DiffSingerCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using K4os.Hash.xxHash;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Serilog;

namespace OpenUtau.Core.DiffSinger {

public class DiffSingerCache {
private const string FormatHeader = "TENSORCACHE";

private readonly ulong hash;
private readonly string filename;

public ulong Hash => hash;

public DiffSingerCache(ulong identifier, ICollection<NamedOnnxValue> inputs) {
using var stream = new MemoryStream();
using (var writer = new BinaryWriter(stream)) {
writer.Write(identifier);
foreach (var onnxValue in inputs.OrderBy(v => v.Name, StringComparer.InvariantCulture)) {
SerializeNamedOnnxValue(writer, onnxValue);
}
}

hash = XXH64.DigestOf(stream.ToArray());
filename = $"ds-{hash:x16}.tensorcache";
}

public ICollection<NamedOnnxValue>? Load() {
var cachePath = Path.Join(PathManager.Inst.CachePath, filename);
if (!File.Exists(cachePath)) return null;

var result = new List<NamedOnnxValue>();
using var stream = new FileStream(cachePath, FileMode.Open, FileAccess.Read);
using var reader = new BinaryReader(stream);
// header
if (reader.ReadString() != FormatHeader) {
throw new InvalidDataException($"[TensorCache] Unexpected file header in {filename}.");
}
try {
// count
var count = reader.ReadInt32();
for (var i = 0; i < count; ++i) {
// data
result.Add(DeserializeNamedOnnxValue(reader));
}
} catch (Exception e) {
Log.Error(e,
"[TensorCache] Exception encountered when deserializing cache file. Root exception message: {msg}", e.Message);
Delete();
return null;
}

return result;
}

public void Delete() {
var cachePath = Path.Join(PathManager.Inst.CachePath, filename);
if (File.Exists(cachePath)) {
File.Delete(cachePath);
}
}

public void Save(ICollection<NamedOnnxValue> outputs) {
var cachePath = Path.Join(PathManager.Inst.CachePath, filename);
using var stream = new FileStream(cachePath, FileMode.Create, FileAccess.Write);
using var writer = new BinaryWriter(stream);
// header
writer.Write(FormatHeader);
// count
writer.Write(outputs.Count);
foreach (var onnxValue in outputs) {
// data
SerializeNamedOnnxValue(writer, onnxValue);
}
}

private static void SerializeNamedOnnxValue(BinaryWriter writer, NamedOnnxValue namedOnnxValue) {
if (namedOnnxValue.ValueType != OnnxValueType.ONNX_TYPE_TENSOR) {
throw new NotSupportedException(
$"[TensorCache] The only supported ONNX value type is {OnnxValueType.ONNX_TYPE_TENSOR}. Got {namedOnnxValue.ValueType} instead."
);
}
// name
writer.Write(namedOnnxValue.Name);
var tensorBase = (TensorBase) namedOnnxValue.Value;
var elementType = tensorBase.GetTypeInfo().ElementType;
// dtype
writer.Write((int)elementType);
switch (elementType) {
case TensorElementType.Float: {
var tensor = namedOnnxValue.AsTensor<float>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.UInt8: {
var tensor = namedOnnxValue.AsTensor<byte>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Int8: {
var tensor = namedOnnxValue.AsTensor<sbyte>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.UInt16: {
var tensor = namedOnnxValue.AsTensor<ushort>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Int16: {
var tensor = namedOnnxValue.AsTensor<short>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Int32: {
var tensor = namedOnnxValue.AsTensor<int>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Int64: {
var tensor = namedOnnxValue.AsTensor<long>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.String: {
var tensor = namedOnnxValue.AsTensor<string>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Bool: {
var tensor = namedOnnxValue.AsTensor<bool>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Float16: {
var tensor = namedOnnxValue.AsTensor<Float16>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Double: {
var tensor = namedOnnxValue.AsTensor<double>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.UInt32: {
var tensor = namedOnnxValue.AsTensor<uint>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.UInt64: {
var tensor = namedOnnxValue.AsTensor<ulong>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.BFloat16: {
var tensor = namedOnnxValue.AsTensor<BFloat16>();
SerializeTensor(writer, tensor);
break;
}
case TensorElementType.Complex64:
case TensorElementType.Complex128:
case TensorElementType.DataTypeMax:
default:
throw new NotSupportedException($"[TensorCache] Unsupported tensor element type: {elementType}.");
}
}

private static void SerializeTensor<T>(BinaryWriter writer, Tensor<T> tensor) {
if (tensor.IsReversedStride) {
throw new NotSupportedException("[TensorCache] Tensors in reversed strides are not supported.");
}
// rank
writer.Write(tensor.Rank);
// shape
foreach (var dim in tensor.Dimensions) {
writer.Write(dim);
}
// size
var size = (int)tensor.Length;
writer.Write(size);
if (typeof(T) == typeof(string)) {
// string tensor
// data
foreach (var element in tensor.ToArray()) {
writer.Write(element!.ToString());
}
} else {
// numeric tensor
// data
var data = new byte[size * tensor.GetTypeInfo().TypeSize];
Buffer.BlockCopy(tensor.ToArray(), 0, data, 0, data.Length);
writer.Write(data);
}
}

private static NamedOnnxValue DeserializeNamedOnnxValue(BinaryReader reader) {
// name
var name = reader.ReadString();
// dtype
var dtype = (TensorElementType)reader.ReadInt32();
// rank
var rank = reader.ReadInt32();
// shape
int[] shape = new int[rank];
for (var i = 0; i < rank; ++i) {
shape[i] = reader.ReadInt32();
}
// size
var size = reader.ReadInt32();
NamedOnnxValue namedOnnxValue;
switch (dtype) {
case TensorElementType.Float: {
var tensor = DeserializeTensor<float>(reader, size, sizeof(float), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.UInt8: {
var tensor = DeserializeTensor<byte>(reader, size, sizeof(byte), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Int8: {
var tensor = DeserializeTensor<sbyte>(reader, size, sizeof(sbyte), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.UInt16: {
var tensor = DeserializeTensor<ushort>(reader, size, sizeof(ushort), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Int16: {
var tensor = DeserializeTensor<short>(reader, size, sizeof(short), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Int32: {
var tensor = DeserializeTensor<int>(reader, size, sizeof(int), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Int64: {
var tensor = DeserializeTensor<long>(reader, size, sizeof(long), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.String: {
// string tensor
Tensor<string> tensor = new DenseTensor<string>(size);
for (var i = 0; i < size; ++i) {
tensor[i] = reader.ReadString();
}
tensor = tensor.Reshape(shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Bool: {
var tensor = DeserializeTensor<bool>(reader, size, sizeof(bool), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Float16: {
var tensor = DeserializeTensor<Float16>(reader, size, sizeof(ushort), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Double: {
var tensor = DeserializeTensor<double>(reader, size, sizeof(double), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.UInt32: {
var tensor = DeserializeTensor<uint>(reader, size, sizeof(uint), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.UInt64: {
var tensor = DeserializeTensor<ulong>(reader, size, sizeof(ulong), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.BFloat16: {
var tensor = DeserializeTensor<BFloat16>(reader, size, sizeof(ushort), shape);
namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
break;
}
case TensorElementType.Complex64:
case TensorElementType.Complex128:
case TensorElementType.DataTypeMax:
default:
throw new NotSupportedException($"[TensorCache] Unsupported tensor element type: {dtype}.");
}

return namedOnnxValue;
}

private static Tensor<T> DeserializeTensor<T>(BinaryReader reader, int size, int typeSize, ReadOnlySpan<int> shape)
{
var bytes = reader.ReadBytes(size * typeSize);
var data = new T[size];
Buffer.BlockCopy(bytes, 0, data, 0, bytes.Length);
Tensor<T> tensor = new DenseTensor<T>(data, shape);
return tensor;
}
}
}
Loading

0 comments on commit 893204e

Please sign in to comment.