Skip to content

Commit

Permalink
Support >2GB of Tensor data in training checkpoint (microsoft#20077)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add ability to store initializer data in an external file.
Update training checkpoint code to use external file if data > ~2GB.

I don't see a way for the flatbuffers 64-bit offsets to be used, as they
don't support storing 'table' types with 64-bit offsets (and our Tensor
is a 'table' type not a simple struct).


https://github.com/google/flatbuffers/blob/0cfb7eb80b05c058e19e50fb575263908e601469/tests/64bit/test_64bit.fbs#L38-L39

Allowing a Tensor to have its raw_data in an external file should
hopefully work with the least friction. As it's an extra field, it's
backwards compatible.

Please feel free to suggest alternative approaches. 

Side note: the diffs in the generated *.fbs.h files are unexpectedly
large. Maybe they weren't re-generated when the new flatbuffers version
was checked in. I updated them by running
`python .\compile_schema.py -f <build output dir>\_deps\flatbuffers-build\Debug\flatc.exe`
from onnxruntime\core\flatbuffers\schema, which I thought was the correct
way, but maybe that's out of date.

I think you can ignore all the diffs in the generated files and just
worry about the changes to the .fbs files in
onnxruntime/core/flatbuffers/schema. Basically start at the bottom of
the files changed and work up as all the 'real' diffs are there.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: carzh <[email protected]>
  • Loading branch information
2 people authored and Ted Themistokleous committed May 7, 2024
1 parent 72de8a9 commit c3ecdef
Show file tree
Hide file tree
Showing 64 changed files with 3,873 additions and 1,363 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ exclude_patterns = [
'js/**',
'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**', # Contains data chunks
'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
'onnxruntime/test/flatbuffers/*.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting
Expand Down
10 changes: 8 additions & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,16 @@ file(GLOB onnxruntime_test_common_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/common/logging/*.h"
)

file(GLOB onnxruntime_test_quantiztion_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_test_quantization_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/quantization/*.cc"
"${TEST_SRC_DIR}/quantization/*.h"
)

file(GLOB onnxruntime_test_flatbuffers_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/flatbuffers/*.cc"
"${TEST_SRC_DIR}/flatbuffers/*.h"
)

if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)

file(GLOB onnxruntime_test_ir_src CONFIGURE_DEPENDS
Expand Down Expand Up @@ -767,7 +772,8 @@ if(NOT IOS)
endif()

set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src}
${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src})
${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantization_src}
${onnxruntime_test_flatbuffers_src})

if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/flatbuffers/flatbuffers_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,5 +315,4 @@ bool IsOrtFormatModelBytes(const void* bytes, int num_bytes) {
return num_bytes > 8 && // check buffer is large enough to contain identifier so we don't read random memory
fbs::InferenceSessionBufferHasIdentifier(bytes);
}

} // namespace onnxruntime::fbs::utils
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@
class ArgType(object):
INPUT = 0
OUTPUT = 1

Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class ArgTypeAndIndex(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsArgTypeAndIndex(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = ArgTypeAndIndex()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsArgTypeAndIndex(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def ArgTypeAndIndexBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

Expand All @@ -38,7 +42,26 @@ def Index(self):
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0

def ArgTypeAndIndexStart(builder): builder.StartObject(2)
def ArgTypeAndIndexAddArgType(builder, argType): builder.PrependInt8Slot(0, argType, 0)
def ArgTypeAndIndexAddIndex(builder, index): builder.PrependUint32Slot(1, index, 0)
def ArgTypeAndIndexEnd(builder): return builder.EndObject()
def ArgTypeAndIndexStart(builder):
builder.StartObject(2)

def Start(builder):
ArgTypeAndIndexStart(builder)

def ArgTypeAndIndexAddArgType(builder, argType):
builder.PrependInt8Slot(0, argType, 0)

def AddArgType(builder, argType):
ArgTypeAndIndexAddArgType(builder, argType)

def ArgTypeAndIndexAddIndex(builder, index):
builder.PrependUint32Slot(1, index, 0)

def AddIndex(builder, index):
ArgTypeAndIndexAddIndex(builder, index)

def ArgTypeAndIndexEnd(builder):
return builder.EndObject()

def End(builder):
return ArgTypeAndIndexEnd(builder)
145 changes: 124 additions & 21 deletions onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class Attribute(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsAttribute(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = Attribute()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsAttribute(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def AttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

Expand Down Expand Up @@ -212,23 +216,122 @@ def GraphsIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
return o == 0

def AttributeStart(builder): builder.StartObject(13)
def AttributeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def AttributeAddDocString(builder, docString): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)
def AttributeAddType(builder, type): builder.PrependInt32Slot(2, type, 0)
def AttributeAddF(builder, f): builder.PrependFloat32Slot(3, f, 0.0)
def AttributeAddI(builder, i): builder.PrependInt64Slot(4, i, 0)
def AttributeAddS(builder, s): builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(s), 0)
def AttributeAddT(builder, t): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(t), 0)
def AttributeAddG(builder, g): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(g), 0)
def AttributeAddFloats(builder, floats): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(floats), 0)
def AttributeStartFloatsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeAddInts(builder, ints): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(ints), 0)
def AttributeStartIntsVector(builder, numElems): return builder.StartVector(8, numElems, 8)
def AttributeAddStrings(builder, strings): builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(strings), 0)
def AttributeStartStringsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeAddTensors(builder, tensors): builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)
def AttributeStartTensorsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeAddGraphs(builder, graphs): builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(graphs), 0)
def AttributeStartGraphsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeEnd(builder): return builder.EndObject()
def AttributeStart(builder):
builder.StartObject(13)

def Start(builder):
AttributeStart(builder)

def AttributeAddName(builder, name):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)

def AddName(builder, name):
AttributeAddName(builder, name)

def AttributeAddDocString(builder, docString):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)

def AddDocString(builder, docString):
AttributeAddDocString(builder, docString)

def AttributeAddType(builder, type):
builder.PrependInt32Slot(2, type, 0)

def AddType(builder, type):
AttributeAddType(builder, type)

def AttributeAddF(builder, f):
builder.PrependFloat32Slot(3, f, 0.0)

def AddF(builder, f):
AttributeAddF(builder, f)

def AttributeAddI(builder, i):
builder.PrependInt64Slot(4, i, 0)

def AddI(builder, i):
AttributeAddI(builder, i)

def AttributeAddS(builder, s):
builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(s), 0)

def AddS(builder, s):
AttributeAddS(builder, s)

def AttributeAddT(builder, t):
builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(t), 0)

def AddT(builder, t):
AttributeAddT(builder, t)

def AttributeAddG(builder, g):
builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(g), 0)

def AddG(builder, g):
AttributeAddG(builder, g)

def AttributeAddFloats(builder, floats):
builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(floats), 0)

def AddFloats(builder, floats):
AttributeAddFloats(builder, floats)

def AttributeStartFloatsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartFloatsVector(builder, numElems: int) -> int:
return AttributeStartFloatsVector(builder, numElems)

def AttributeAddInts(builder, ints):
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(ints), 0)

def AddInts(builder, ints):
AttributeAddInts(builder, ints)

def AttributeStartIntsVector(builder, numElems):
return builder.StartVector(8, numElems, 8)

def StartIntsVector(builder, numElems: int) -> int:
return AttributeStartIntsVector(builder, numElems)

def AttributeAddStrings(builder, strings):
builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(strings), 0)

def AddStrings(builder, strings):
AttributeAddStrings(builder, strings)

def AttributeStartStringsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartStringsVector(builder, numElems: int) -> int:
return AttributeStartStringsVector(builder, numElems)

def AttributeAddTensors(builder, tensors):
builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)

def AddTensors(builder, tensors):
AttributeAddTensors(builder, tensors)

def AttributeStartTensorsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartTensorsVector(builder, numElems: int) -> int:
return AttributeStartTensorsVector(builder, numElems)

def AttributeAddGraphs(builder, graphs):
builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(graphs), 0)

def AddGraphs(builder, graphs):
AttributeAddGraphs(builder, graphs)

def AttributeStartGraphsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartGraphsVector(builder, numElems: int) -> int:
return AttributeStartGraphsVector(builder, numElems)

def AttributeEnd(builder):
return builder.EndObject()

def End(builder):
return AttributeEnd(builder)
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ class AttributeType(object):
GRAPHS = 10
SPARSE_TENSOR = 11
SPARSE_TENSORS = 12

54 changes: 46 additions & 8 deletions onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/Checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class Checkpoint(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsCheckpoint(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = Checkpoint()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsCheckpoint(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def CheckpointBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

Expand Down Expand Up @@ -78,10 +82,44 @@ def PropertyBag(self):
return obj
return None

def CheckpointStart(builder): builder.StartObject(4)
def CheckpointAddVersion(builder, version): builder.PrependInt32Slot(0, version, 0)
def CheckpointAddModuleState(builder, moduleState): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)
def CheckpointAddOptimizerGroups(builder, optimizerGroups): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)
def CheckpointStartOptimizerGroupsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def CheckpointAddPropertyBag(builder, propertyBag): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)
def CheckpointEnd(builder): return builder.EndObject()
def CheckpointStart(builder):
builder.StartObject(4)

def Start(builder):
CheckpointStart(builder)

def CheckpointAddVersion(builder, version):
builder.PrependInt32Slot(0, version, 0)

def AddVersion(builder, version):
CheckpointAddVersion(builder, version)

def CheckpointAddModuleState(builder, moduleState):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)

def AddModuleState(builder, moduleState):
CheckpointAddModuleState(builder, moduleState)

def CheckpointAddOptimizerGroups(builder, optimizerGroups):
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)

def AddOptimizerGroups(builder, optimizerGroups):
CheckpointAddOptimizerGroups(builder, optimizerGroups)

def CheckpointStartOptimizerGroupsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartOptimizerGroupsVector(builder, numElems: int) -> int:
return CheckpointStartOptimizerGroupsVector(builder, numElems)

def CheckpointAddPropertyBag(builder, propertyBag):
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)

def AddPropertyBag(builder, propertyBag):
CheckpointAddPropertyBag(builder, propertyBag)

def CheckpointEnd(builder):
return builder.EndObject()

def End(builder):
return CheckpointEnd(builder)
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ class DeprecatedKernelCreateInfos(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = DeprecatedKernelCreateInfos()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def DeprecatedKernelCreateInfosBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

Expand Down Expand Up @@ -79,9 +83,38 @@ def KernelDefHashesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0

def DeprecatedKernelCreateInfosStart(builder): builder.StartObject(2)
def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0)
def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0)
def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems): return builder.StartVector(8, numElems, 8)
def DeprecatedKernelCreateInfosEnd(builder): return builder.EndObject()
def DeprecatedKernelCreateInfosStart(builder):
builder.StartObject(2)

def Start(builder):
DeprecatedKernelCreateInfosStart(builder)

def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0)

def AddNodeIndices(builder, nodeIndices):
DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices)

def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartNodeIndicesVector(builder, numElems: int) -> int:
return DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems)

def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0)

def AddKernelDefHashes(builder, kernelDefHashes):
DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes)

def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems):
return builder.StartVector(8, numElems, 8)

def StartKernelDefHashesVector(builder, numElems: int) -> int:
return DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems)

def DeprecatedKernelCreateInfosEnd(builder):
return builder.EndObject()

def End(builder):
return DeprecatedKernelCreateInfosEnd(builder)
Loading

0 comments on commit c3ecdef

Please sign in to comment.