Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support >2GB of Tensor data in training checkpoint #20077

Merged
merged 27 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
45e20fb
Add ability to store initializer data in an external file.
skottmckay Mar 26, 2024
3187fba
Make functions references to avoid copy of non-trivial lambda.
skottmckay Mar 26, 2024
1bed68c
Address PR comments.
skottmckay Mar 27, 2024
383a00f
fixed some small bugs + wrote first working(?) unit test
carzh Apr 4, 2024
e97f9a6
added code that converts flatbuffers with external data to ort tensor…
carzh Apr 9, 2024
4094363
fixed reader segfault + updated tests
carzh Apr 9, 2024
bca6849
moved function to try to prevent safeint build failure in the pipelines
carzh Apr 9, 2024
fc6af96
finished data verification for test for LoadOrtTensorOrtFormat and tr…
carzh Apr 10, 2024
6b9459e
oops lintrunner
carzh Apr 10, 2024
bb07efa
fixed bug with creating uint8 tensorproto
carzh Apr 10, 2024
321a924
attempting to resolve build errors in pipelines
carzh Apr 10, 2024
65d6a5f
applied some of the suggestions
carzh Apr 12, 2024
dd04354
applied suggestions, added changes to save checkpoint from checkpoint…
carzh Apr 17, 2024
972859a
fixed checkpoint test + fixed some minor checkpoint bugs
carzh Apr 18, 2024
adf3f33
export_model_for_inferencing with external data file
carzh Apr 18, 2024
7f4fc7f
added some suggestions + working export_for_inferencing test with ext…
carzh Apr 19, 2024
87f8199
Merge remote-tracking branch 'origin/main' into scottsbranch
carzh Apr 19, 2024
cd0ded0
added platform-specific error messages
carzh Apr 19, 2024
f7cf4c3
Address PR comments.
skottmckay Apr 19, 2024
70b5625
Remove temp debug output
skottmckay Apr 19, 2024
738727d
Make tests more robust.
skottmckay Apr 19, 2024
641dd45
Fix wasm build error
skottmckay Apr 19, 2024
4e40740
Remove duplicate fbs generated file. Use the default '.fbs.h' file fo…
skottmckay Apr 19, 2024
5824779
Add manual #include path edit to generated file
skottmckay Apr 19, 2024
7fa65ea
added some suggestions
carzh Apr 20, 2024
48e0db3
Address PR comments
skottmckay Apr 20, 2024
903fb18
Address PR comment
skottmckay Apr 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,16 @@ file(GLOB onnxruntime_test_common_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/common/logging/*.h"
)

file(GLOB onnxruntime_test_quantiztion_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_test_quantization_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/quantization/*.cc"
"${TEST_SRC_DIR}/quantization/*.h"
)

file(GLOB onnxruntime_test_flatbuffers_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/flatbuffers/*.cc"
"${TEST_SRC_DIR}/flatbuffers/*.h"
)

if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)

file(GLOB onnxruntime_test_ir_src CONFIGURE_DEPENDS
Expand Down Expand Up @@ -767,7 +772,8 @@ if(NOT IOS)
endif()

set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src}
${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src})
${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantization_src}
${onnxruntime_test_flatbuffers_src})

if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/flatbuffers/flatbuffers_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,5 +315,4 @@ bool IsOrtFormatModelBytes(const void* bytes, int num_bytes) {
return num_bytes > 8 && // check buffer is large enough to contain identifier so we don't read random memory
fbs::InferenceSessionBufferHasIdentifier(bytes);
}

} // namespace onnxruntime::fbs::utils
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@
class ArgType(object):
INPUT = 0
OUTPUT = 1

Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class ArgTypeAndIndex(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsArgTypeAndIndex(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = ArgTypeAndIndex()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsArgTypeAndIndex(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def ArgTypeAndIndexBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

Expand All @@ -38,7 +42,26 @@ def Index(self):
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0

def ArgTypeAndIndexStart(builder): builder.StartObject(2)
def ArgTypeAndIndexAddArgType(builder, argType): builder.PrependInt8Slot(0, argType, 0)
def ArgTypeAndIndexAddIndex(builder, index): builder.PrependUint32Slot(1, index, 0)
def ArgTypeAndIndexEnd(builder): return builder.EndObject()
def ArgTypeAndIndexStart(builder):
builder.StartObject(2)

def Start(builder):
ArgTypeAndIndexStart(builder)

def ArgTypeAndIndexAddArgType(builder, argType):
builder.PrependInt8Slot(0, argType, 0)

def AddArgType(builder, argType):
ArgTypeAndIndexAddArgType(builder, argType)

def ArgTypeAndIndexAddIndex(builder, index):
builder.PrependUint32Slot(1, index, 0)

def AddIndex(builder, index):
ArgTypeAndIndexAddIndex(builder, index)

def ArgTypeAndIndexEnd(builder):
return builder.EndObject()

def End(builder):
return ArgTypeAndIndexEnd(builder)
145 changes: 124 additions & 21 deletions onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class Attribute(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsAttribute(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = Attribute()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsAttribute(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def AttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

Expand Down Expand Up @@ -212,23 +216,122 @@ def GraphsIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
return o == 0

def AttributeStart(builder): builder.StartObject(13)
def AttributeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def AttributeAddDocString(builder, docString): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)
def AttributeAddType(builder, type): builder.PrependInt32Slot(2, type, 0)
def AttributeAddF(builder, f): builder.PrependFloat32Slot(3, f, 0.0)
def AttributeAddI(builder, i): builder.PrependInt64Slot(4, i, 0)
def AttributeAddS(builder, s): builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(s), 0)
def AttributeAddT(builder, t): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(t), 0)
def AttributeAddG(builder, g): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(g), 0)
def AttributeAddFloats(builder, floats): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(floats), 0)
def AttributeStartFloatsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeAddInts(builder, ints): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(ints), 0)
def AttributeStartIntsVector(builder, numElems): return builder.StartVector(8, numElems, 8)
def AttributeAddStrings(builder, strings): builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(strings), 0)
def AttributeStartStringsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeAddTensors(builder, tensors): builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)
def AttributeStartTensorsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeAddGraphs(builder, graphs): builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(graphs), 0)
def AttributeStartGraphsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def AttributeEnd(builder): return builder.EndObject()
def AttributeStart(builder):
builder.StartObject(13)

def Start(builder):
AttributeStart(builder)

def AttributeAddName(builder, name):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)

def AddName(builder, name):
AttributeAddName(builder, name)

def AttributeAddDocString(builder, docString):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)

def AddDocString(builder, docString):
AttributeAddDocString(builder, docString)

def AttributeAddType(builder, type):
builder.PrependInt32Slot(2, type, 0)

def AddType(builder, type):
AttributeAddType(builder, type)

def AttributeAddF(builder, f):
builder.PrependFloat32Slot(3, f, 0.0)

def AddF(builder, f):
AttributeAddF(builder, f)

def AttributeAddI(builder, i):
builder.PrependInt64Slot(4, i, 0)

def AddI(builder, i):
AttributeAddI(builder, i)

def AttributeAddS(builder, s):
builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(s), 0)

def AddS(builder, s):
AttributeAddS(builder, s)

def AttributeAddT(builder, t):
builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(t), 0)

def AddT(builder, t):
AttributeAddT(builder, t)

def AttributeAddG(builder, g):
builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(g), 0)

def AddG(builder, g):
AttributeAddG(builder, g)

def AttributeAddFloats(builder, floats):
builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(floats), 0)

def AddFloats(builder, floats):
AttributeAddFloats(builder, floats)

def AttributeStartFloatsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartFloatsVector(builder, numElems: int) -> int:
return AttributeStartFloatsVector(builder, numElems)

def AttributeAddInts(builder, ints):
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(ints), 0)

def AddInts(builder, ints):
AttributeAddInts(builder, ints)

def AttributeStartIntsVector(builder, numElems):
return builder.StartVector(8, numElems, 8)

def StartIntsVector(builder, numElems: int) -> int:
return AttributeStartIntsVector(builder, numElems)

def AttributeAddStrings(builder, strings):
builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(strings), 0)

def AddStrings(builder, strings):
AttributeAddStrings(builder, strings)

def AttributeStartStringsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartStringsVector(builder, numElems: int) -> int:
return AttributeStartStringsVector(builder, numElems)

def AttributeAddTensors(builder, tensors):
builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)

def AddTensors(builder, tensors):
AttributeAddTensors(builder, tensors)

def AttributeStartTensorsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartTensorsVector(builder, numElems: int) -> int:
return AttributeStartTensorsVector(builder, numElems)

def AttributeAddGraphs(builder, graphs):
builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(graphs), 0)

def AddGraphs(builder, graphs):
AttributeAddGraphs(builder, graphs)

def AttributeStartGraphsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartGraphsVector(builder, numElems: int) -> int:
return AttributeStartGraphsVector(builder, numElems)

def AttributeEnd(builder):
return builder.EndObject()

def End(builder):
return AttributeEnd(builder)
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ class AttributeType(object):
GRAPHS = 10
SPARSE_TENSOR = 11
SPARSE_TENSORS = 12

54 changes: 46 additions & 8 deletions onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/Checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class Checkpoint(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsCheckpoint(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = Checkpoint()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsCheckpoint(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def CheckpointBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

Expand Down Expand Up @@ -78,10 +82,44 @@ def PropertyBag(self):
return obj
return None

def CheckpointStart(builder): builder.StartObject(4)
def CheckpointAddVersion(builder, version): builder.PrependInt32Slot(0, version, 0)
def CheckpointAddModuleState(builder, moduleState): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)
def CheckpointAddOptimizerGroups(builder, optimizerGroups): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)
def CheckpointStartOptimizerGroupsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def CheckpointAddPropertyBag(builder, propertyBag): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)
def CheckpointEnd(builder): return builder.EndObject()
def CheckpointStart(builder):
builder.StartObject(4)

def Start(builder):
CheckpointStart(builder)

def CheckpointAddVersion(builder, version):
builder.PrependInt32Slot(0, version, 0)

def AddVersion(builder, version):
CheckpointAddVersion(builder, version)

def CheckpointAddModuleState(builder, moduleState):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)

def AddModuleState(builder, moduleState):
CheckpointAddModuleState(builder, moduleState)

def CheckpointAddOptimizerGroups(builder, optimizerGroups):
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)

def AddOptimizerGroups(builder, optimizerGroups):
CheckpointAddOptimizerGroups(builder, optimizerGroups)

def CheckpointStartOptimizerGroupsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartOptimizerGroupsVector(builder, numElems: int) -> int:
return CheckpointStartOptimizerGroupsVector(builder, numElems)

def CheckpointAddPropertyBag(builder, propertyBag):
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)

def AddPropertyBag(builder, propertyBag):
CheckpointAddPropertyBag(builder, propertyBag)

def CheckpointEnd(builder):
return builder.EndObject()

def End(builder):
return CheckpointEnd(builder)
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ class DeprecatedKernelCreateInfos(object):
__slots__ = ['_tab']

@classmethod
def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset):
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = DeprecatedKernelCreateInfos()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def DeprecatedKernelCreateInfosBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

Expand Down Expand Up @@ -79,9 +83,38 @@ def KernelDefHashesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0

def DeprecatedKernelCreateInfosStart(builder): builder.StartObject(2)
def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0)
def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0)
def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems): return builder.StartVector(8, numElems, 8)
def DeprecatedKernelCreateInfosEnd(builder): return builder.EndObject()
def DeprecatedKernelCreateInfosStart(builder):
builder.StartObject(2)

def Start(builder):
DeprecatedKernelCreateInfosStart(builder)

def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0)

def AddNodeIndices(builder, nodeIndices):
DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices)

def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def StartNodeIndicesVector(builder, numElems: int) -> int:
return DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems)

def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0)

def AddKernelDefHashes(builder, kernelDefHashes):
DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes)

def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems):
return builder.StartVector(8, numElems, 8)

def StartKernelDefHashesVector(builder, numElems: int) -> int:
return DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems)

def DeprecatedKernelCreateInfosEnd(builder):
return builder.EndObject()

def End(builder):
return DeprecatedKernelCreateInfosEnd(builder)
Loading
Loading