diff --git a/source/dlc.js b/source/dlc.js index 56ab54564b..51c41cb4d7 100644 --- a/source/dlc.js +++ b/source/dlc.js @@ -11,17 +11,16 @@ dlc.ModelFactory = class { async open(context, target) { await context.require('./dlc-schema'); dlc.schema = flatbuffers.get('dlc').dlc; - const format = target.entries.has('model') ? 'DLC' : 'DLC Weights'; - target.read(); + await target.read(context); const metadata = await context.metadata('dlc-metadata.json'); - return new dlc.Model(metadata, format, target); + return new dlc.Model(metadata, target); } }; dlc.Model = class { - constructor(metadata, format, target) { - this.format = format; + constructor(metadata, target) { + this.format = target.format; this.metadata = []; if (target.metadata.size > 0) { const version = target.metadata.get('model-version'); @@ -260,303 +259,338 @@ dlc.Container = class { static open(context) { const entries = context.entries('zip'); - if (entries.size > 0) { - if (entries.has('model') || entries.has('model.params')) { - return new dlc.Container(entries); - } + if (entries.has('model') || entries.has('model.params')) { + return new dlc.Container(entries.get('model'), entries.get('model.params'), entries.get('dlc.metadata')); } const stream = context.stream; switch (dlc.Container._signature(stream).split('.').pop()) { case 'NETD': - return new dlc.Container(new Map([ [ 'model', stream ] ])); + return new dlc.Container(stream, undefined, undefined); case 'NETP': - return new dlc.Container(new Map([ [ 'model.params', stream ] ])); + return new dlc.Container(undefined, stream, undefined); case 'NR64': - return new dlc.Container(new Map([ [ 'model', stream ] ])); + return new dlc.Container(undefined, stream, undefined); default: return null; } } - constructor(entries) { - this.entries = entries; + constructor(model, params, metadata) { + this._model = model; + this._params = params; + this._metadata = metadata; + } + + async read(context) { + const request = async (context, name) => { + try { + return await context.request(name, null); + } catch (error) { + return null; + } + }; + if (this._model === undefined) { + this._model = await request(context, 'model'); + } + if (this._params === undefined) { + this._params = await request(context, 'model.params'); + } + if (this._metadata === undefined) { + this._metadata = await request(context, 'dlc.metadata'); + } this.graphs = []; this.metadata = new Map(); + if (this._model) { + this.format = 'DLC'; + const stream = this._model; + delete this._model; + const signature = dlc.Container._signature(stream); + switch (signature) { + case '2': { + throw new dlc.Error("File contains undocumented DLC v2 data."); + } + case '3.NETD': + case 'NETD': { + this.version = 3; + this.graphs = dlc.Container._model3(stream, signature); + break; + } + case '4.NETD': { + this.version = 4; + this.graphs = dlc.Container._model4(stream); + break; + } + default: { + const buffer = stream.peek(Math.min(stream.length, 16)); + const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join(''); + throw new dlc.Error("File contains undocumented '" + content + "' data."); + } + } + } + if (this._params) { + this.format = this.format || 'DLC Weights'; + const stream = this._params; + delete this._params; + const signature = dlc.Container._signature(stream); + switch (signature) { + case '2': { + throw new dlc.Error("File contains undocumented DLC v2 data."); + } + case '3.NETP': + case 'NETP': { + this.version = this.graphs.length > 0 ? this.version : 3; + this.graphs = dlc.Container._params3(stream, signature, this.graphs); + break; + } + case '4.NETP': { + dlc.Container._params4(stream, this.graphs); + break; + } + case '4.NR64': { + throw new dlc.Error("File contains undocumented 'NR64' params data."); + } + default: { + const buffer = stream.peek(Math.min(stream.length, 16)); + const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join(''); + throw new dlc.Error("File contains undocumented '" + content + "' data."); + } + } + } + if (this._metadata) { + const stream = this._metadata; + delete this._metadata; + const reader = text.Reader.open(stream); + for (;;) { + const line = reader.read(); + if (line === undefined) { + break; + } + const index = line.indexOf('='); + if (index === -1) { + break; + } + const key = line.substring(0, index); + const value = line.substring(index + 1); + this.metadata.set(key, value); + } + } } - read() { - if (this.entries) { - const entries = this.entries; - delete this.entries; - if (entries.has('model')) { - const stream = entries.get('model'); - const signature = dlc.Container._signature(stream); - switch (signature) { - case '2': { - throw new dlc.Error("File contains undocumented DLC v2 data."); - } - case '3.NETD': - case 'NETD': { - this.version = 3; - let model = null; - try { - const buffer = new Uint8Array(signature === 'NETD' ? stream.peek() : stream.peek().subarray(8)); - const reader = flatbuffers.BinaryReader.open(buffer); - model = dlc.schema.v3.Model.decode(reader, reader.root); - model.version = 3; - } catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new dlc.Error('File format is not dlc.v1.NETD (' + message.replace(/\.$/, '') + ').'); - } - model.tensors = []; - this.graphs.push(model); - const updateAttribute = (attr) => { - switch (attr.type) { - case 1: return [ 'boolean', attr.bool_value ]; - case 2: return [ 'int32', attr.int32_value ]; - case 3: return [ 'uint32', attr.uint32_value ]; - case 4: return [ 'float32', attr.float32_value ]; - case 5: return [ 'string', attr.string_value ]; - case 7: return [ 'byte[]', Array.from(attr.byte_list) ]; - case 8: return [ 'int32[]', Array.from(attr.int32_list) ]; - case 9: return [ 'float32[]', Array.from(attr.float32_list) ]; - case 11: { - const obj = {}; - let index = 0; - let list = true; - for (const attribute of attr.attributes) { - const entry = updateAttribute(attribute); - obj[attribute.name] = entry[1]; - list = list && index.toString() === attribute.name; - index++; - } - return list ? [ '', Object.values(obj) ] : [ '', obj ]; - } - default: - throw new dlc.Error("Unsupported attribute type '" + attr.type + "'."); - } - }; - for (const node of model.nodes) { - for (const attribute of node.attributes) { - const entry = updateAttribute(attribute); - attribute.type = entry[0]; - attribute.data = entry[1]; - } - } - break; - } - case '4.NETD': { - this.version = 4; - let model = null; - try { - const buffer = new Uint8Array(stream.peek().subarray(8)); - const reader = flatbuffers.BinaryReader.open(buffer); - model = dlc.schema.v4.Model.decode(reader, reader.root); - } catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new dlc.Error('File format is not dlc.v4.NETD (' + message.replace(/\.$/, '') + ').'); - } - this.graphs = model.graphs; - const dataType = (value) => { - switch (value) { - case 0x0032: return 'int32'; - case 0x0108: return 'int8'; - case 0x0132: return 'int32'; - case 0x0232: return 'float32'; - case 0x0308: return 'qint8'; - case 0x0332: return 'qint32'; - case 0x0408: return 'uint8'; - case 0x0416: return 'uint16'; - case 0x0508: return 'boolean'; - default: throw new dlc.Error("Unsupported data type '" + JSON.stringify(value) + "'."); - } - }; - const updateTensor = (tensor) => { - tensor.dtype = dataType(tensor.dtype); - tensor.output_dtype = dataType(tensor.output_dtype); - }; - for (const graph of this.graphs) { - for (const node of graph.nodes) { - for (const attribute of node.attributes) { - switch (attribute.kind) { - case 0: { - const value = attribute.value; - switch (value.kind) { - case 0x7fffffff: - attribute.data = value.string_value; - attribute.type = 'string'; - break; - case 0x0032: - attribute.data = value.int32_value; - break; - case 0x0108: - attribute.data = value.int32_value; - attribute.type = 'int8'; - break; - case 0x0132: - attribute.data = value.int32_value; - attribute.type = 'int32'; - break; - case 0x0232: - attribute.data = value.float32_value; - attribute.type = 'float32'; - break; - case 0x0508: - attribute.data = value.int32_value !== 0; - attribute.type = 'boolean'; - break; - default: - throw new dlc.Error("Unknown attribute value kind '" + value.kind + "'."); - } - break; - } - case 1: { - const tensor = attribute.tensor; - updateTensor(tensor); - attribute.type = 'tensor'; - attribute.data = tensor; - break; - } - default: { - throw new dlc.Error("Unknown attribute kind '" + attribute.kind + "'."); - } - } - } - } - for (const tensor of graph.tensors) { - updateTensor(tensor); - } - } - break; - } - case '4.NR64': { - throw new dlc.Error("File contains undocumented 'NR64' data."); - } - default: { - const buffer = stream.peek(Math.min(stream.length, 16)); - const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join(''); - throw new dlc.Error("File contains undocumented '" + content + "' data."); + static _model3(stream, signature) { + let model = null; + try { + const buffer = new Uint8Array(signature === 'NETD' ? stream.peek() : stream.peek().subarray(8)); + const reader = flatbuffers.BinaryReader.open(buffer); + model = dlc.schema.v3.Model.decode(reader, reader.root); + } catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new dlc.Error('File format is not dlc.v1.NETD (' + message.replace(/\.$/, '') + ').'); + } + model.tensors = []; + const updateAttribute = (attr) => { + switch (attr.type) { + case 1: return [ 'boolean', attr.bool_value ]; + case 2: return [ 'int32', attr.int32_value ]; + case 3: return [ 'uint32', attr.uint32_value ]; + case 4: return [ 'float32', attr.float32_value ]; + case 5: return [ 'string', attr.string_value ]; + case 7: return [ 'byte[]', Array.from(attr.byte_list) ]; + case 8: return [ 'int32[]', Array.from(attr.int32_list) ]; + case 9: return [ 'float32[]', Array.from(attr.float32_list) ]; + case 11: { + const obj = {}; + let index = 0; + let list = true; + for (const attribute of attr.attributes) { + const name = attribute.name; + const entry = updateAttribute(attribute); + obj[name] = entry[1]; + list = list && index.toString() === attribute.name; + index++; } + return list ? [ '', Object.values(obj) ] : [ '', obj ]; } + default: + throw new dlc.Error("Unsupported attribute type '" + attr.type + "'."); } - if (entries.has('model.params')) { - const stream = entries.get('model.params'); - const signature = dlc.Container._signature(stream); - switch (signature) { - case '2': { - throw new dlc.Error("File contains undocumented DLC v2 data."); - } - case '3.NETP': - case 'NETP': { - let params = null; - try { - const buffer = new Uint8Array(signature === 'NETP' ? stream.peek() : stream.peek().subarray(8)); - const reader = flatbuffers.BinaryReader.open(buffer); - params = dlc.schema.v3.ModelParameters.decode(reader, reader.root); - } catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new dlc.Error('File format is not dlc.v1.NETP (' + message.replace(/\.$/, '') + ').'); - } - if (this.graphs.length === 0) { - this.version = 3; - const graph = new dlc.schema.v3.ModelParameters(); - graph.nodes = new Array(params.nodes.length); - graph.tensors = []; - for (let i = 0; i < graph.nodes.length; i++) { - const node = new dlc.schema.v3.Node(); - node.type = 'Weights'; - node.name = params.nodes[i].name; - node.inputs = []; - node.outputs = []; - node.attributes = []; - graph.nodes[i] = node; - } - this.graphs.push(graph); - } - const graph = this.graphs[0]; - const dataType = (value) => { - switch (value) { - case null: return '?'; - case 6: return 'uint8'; - case 9: return 'float32'; + }; + for (const node of model.nodes) { + for (const attribute of node.attributes) { + const entry = updateAttribute(attribute); + attribute.type = entry[0]; + attribute.data = entry[1]; + } + } + return [ model ]; + } + + static _model4(stream) { + let model = null; + try { + const buffer = new Uint8Array(stream.peek().subarray(8)); + const reader = flatbuffers.BinaryReader.open(buffer); + model = dlc.schema.v4.Model.decode(reader, reader.root); + } catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new dlc.Error('File format is not dlc.v4.NETD (' + message.replace(/\.$/, '') + ').'); + } + const dataType = (value) => { + switch (value) { + case 0x0032: return 'int32'; + case 0x0108: return 'int8'; + case 0x0132: return 'int32'; + case 0x0232: return 'float32'; + case 0x0308: return 'qint8'; + case 0x0332: return 'qint32'; + case 0x0408: return 'uint8'; + case 0x0416: return 'uint16'; + case 0x0508: return 'boolean'; + default: throw new dlc.Error("Unsupported data type '" + JSON.stringify(value) + "'."); + } + }; + const updateTensor = (tensor) => { + tensor.dtype = dataType(tensor.dtype); + tensor.output_dtype = dataType(tensor.output_dtype); + }; + for (const graph of model.graphs) { + for (const node of graph.nodes) { + for (const attribute of node.attributes) { + switch (attribute.kind) { + case 0: { + const value = attribute.value; + switch (value.kind) { + case 0x7fffffff: + attribute.data = value.string_value; + attribute.type = 'string'; + break; + case 0x0032: + attribute.data = value.int32_value; + break; + case 0x0108: + attribute.data = value.int32_value; + attribute.type = 'int8'; + break; + case 0x0132: + attribute.data = value.int32_value; + attribute.type = 'int32'; + break; + case 0x0232: + attribute.data = value.float32_value; + attribute.type = 'float32'; + break; + case 0x0508: + attribute.data = value.int32_value !== 0; + attribute.type = 'boolean'; + break; default: - throw new dlc.Error("Unsupported data type '" + JSON.stringify(value) + "'."); - } - }; - const weights = new Map(params.nodes.map((node) => [ node.name, node.weights ])); - for (const node of graph.nodes) { - if (weights.has(node.name)) { - const tensors = weights.get(node.name); - for (const tensor of tensors) { - tensor.data.dtype = dataType(tensor.data.dtype); - } - node.weights = tensors; + throw new dlc.Error("Unknown attribute value kind '" + value.kind + "'."); } + break; } - break; - } - case '4.NETP': { - let params = null; - try { - const buffer = new Uint8Array(stream.peek().subarray(8)); - const reader = flatbuffers.BinaryReader.open(buffer); - params = dlc.schema.v4.ModelParameters.decode(reader, reader.root); - } catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new dlc.Error('File format is not dlc.v2.NETP (' + message.replace(/\.$/, '') + ').'); - } - if (this.graphs.length === 0) { - throw new dlc.Error('Model definition not available.'); + case 1: { + const tensor = attribute.tensor; + updateTensor(tensor); + attribute.type = 'tensor'; + attribute.data = tensor; + break; } - const weights = new Map(params.graphs.map((graph) => [ graph.name, graph ])); - for (const graph of this.graphs) { - const params = weights.get(graph.name); - const tensors = new Map(params.tensors.map((tensor) => [ tensor.name, tensor ])); - for (const tensor of graph.tensors) { - if (tensor.location === 4) { - tensor.data = tensors.get(tensor.name).bytes; - } - } - for (let i = 0; i < graph.nodes.length; i++) { - const node = graph.nodes[i]; - const tensors = new Map(params.nodes[i].tensors.map((tensor) => [ tensor.name, tensor ])); - for (const attribute of node.attributes) { - const tensor = attribute.tensor; - if (tensor) { - tensor.data = tensors.get(tensor.name).bytes; - } - } - } + default: { + throw new dlc.Error("Unknown attribute kind '" + attribute.kind + "'."); } - break; - } - default: { - const buffer = stream.peek(Math.min(stream.length, 16)); - const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join(''); - throw new dlc.Error("File contains undocumented '" + content + "' data."); } } } - if (entries.has('dlc.metadata')) { - const stream = entries.get('dlc.metadata'); - const reader = text.Reader.open(stream); - for (;;) { - const line = reader.read(); - if (line === undefined) { - break; - } - const index = line.indexOf('='); - if (index === -1) { - break; + for (const tensor of graph.tensors) { + updateTensor(tensor); + } + } + return model.graphs; + } + + static _params3(stream, signature, graphs) { + let params = null; + try { + const buffer = new Uint8Array(signature === 'NETP' ? stream.peek() : stream.peek().subarray(8)); + const reader = flatbuffers.BinaryReader.open(buffer); + params = dlc.schema.v3.ModelParameters.decode(reader, reader.root); + } catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new dlc.Error('File format is not dlc.v1.NETP (' + message.replace(/\.$/, '') + ').'); + } + if (graphs.length === 0) { + const graph = new dlc.schema.v3.ModelParameters(); + graph.nodes = new Array(params.nodes.length); + graph.tensors = []; + for (let i = 0; i < graph.nodes.length; i++) { + const node = new dlc.schema.v3.Node(); + node.type = 'Weights'; + node.name = params.nodes[i].name; + node.inputs = []; + node.outputs = []; + node.attributes = []; + graph.nodes[i] = node; + } + graphs.push(graph); + } + const graph = graphs[0]; + const dataType = (value) => { + switch (value) { + case null: return '?'; + case 6: return 'uint8'; + case 9: return 'float32'; + default: + throw new dlc.Error("Unsupported data type '" + JSON.stringify(value) + "'."); + } + }; + const weights = new Map(params.nodes.map((node) => [ node.name, node.weights ])); + for (const node of graph.nodes) { + if (weights.has(node.name)) { + const tensors = weights.get(node.name); + for (const tensor of tensors) { + tensor.data.dtype = dataType(tensor.data.dtype); + } + node.weights = tensors; + } + } + return graphs; + } + + static _params4(stream, graphs) { + let params = null; + try { + const buffer = new Uint8Array(stream.peek().subarray(8)); + const reader = flatbuffers.BinaryReader.open(buffer); + params = dlc.schema.v4.ModelParameters.decode(reader, reader.root); + } catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new dlc.Error('File format is not dlc.v2.NETP (' + message.replace(/\.$/, '') + ').'); + } + if (graphs.length === 0) { + throw new dlc.Error('Model definition not available.'); + } + const weights = new Map(params.graphs.map((graph) => [ graph.name, graph ])); + for (const graph of graphs) { + const params = weights.get(graph.name); + const tensors = new Map(params.tensors.map((tensor) => [ tensor.name, tensor ])); + for (const tensor of graph.tensors) { + if (tensor.location === 4) { + tensor.data = tensors.get(tensor.name).bytes; + } + } + for (let i = 0; i < graph.nodes.length; i++) { + const node = graph.nodes[i]; + const tensors = new Map(params.nodes[i].tensors.map((tensor) => [ tensor.name, tensor ])); + for (const attribute of node.attributes) { + const tensor = attribute.tensor; + if (tensor) { + tensor.data = tensors.get(tensor.name).bytes; } - const key = line.substring(0, index); - const value = line.substring(index + 1); - this.metadata.set(key, value); } } } } + static _signature(stream) { if (stream) { const buffer = stream.peek(Math.min(stream.length, 16));