Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[On-Device-Training] Upgrade Flatbuffers to Support 2GB+ Checkpoints. #19770

Merged
merged 17 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "6df40a2471737b27271bdd9b900ab5f3aec746c7",
"commitHash": "0100f6a5779831fa7a651e4b67ef389a8752bd9b",
"repositoryUrl": "https://github.com/google/flatbuffers.git"
},
"comments": "flatbuffers"
Expand Down
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b3132
# Until the 3.4.1 release this is the best option we have.
# Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744
eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a
flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf
flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ FetchContent_Declare(
URL ${DEP_URL_flatbuffers}
URL_HASH SHA1=${DEP_SHA1_flatbuffers}
PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND}
FIND_PACKAGE_ARGS 1.12.0...<2.0.0 NAMES Flatbuffers
FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers
)

# Download a protoc binary from Internet if needed
Expand Down
40 changes: 8 additions & 32 deletions cmake/patches/flatbuffers/flatbuffers.patch
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,11 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3987eac9..5e5462f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -223,7 +223,7 @@ elseif(CMAKE_COMPILER_IS_GNUCXX)
"${CMAKE_CXX_FLAGS} -std=c++0x")
endif(CYGWIN)
set(CMAKE_CXX_FLAGS
- "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow")
+ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Wextra -Werror=shadow -Wno-error=stringop-overflow")
set(FLATBUFFERS_PRIVATE_CXX_FLAGS "-Wold-style-cast")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.4)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp
index 55b8439b..dc03e8a8 100644
--- a/src/idl_gen_rust.cpp
+++ b/src/idl_gen_rust.cpp
@@ -406,7 +406,8 @@ class RustGenerator : public BaseGenerator {
// example: f(A, D::E) -> super::D::E
// does not include leaf object (typically a struct type).

- size_t i = 0;
+ // fix unused but set variable warning
+ //size_t i = 0;
std::stringstream stream;

auto s = src->components.begin();
@@ -417,7 +418,7 @@ class RustGenerator : public BaseGenerator {
if (*s != *d) { break; }
++s;
++d;
- ++i;
+ //++i;
}

for (; s != src->components.end(); ++s) { stream << "super::"; }
@@ -279,5 +279,5 @@
# Append FLATBUFFERS_CXX_FLAGS to CMAKE_CXX_FLAGS.
if(DEFINED FLATBUFFERS_CXX_FLAGS)
message(STATUS "extend CXX_FLAGS with ${FLATBUFFERS_CXX_FLAGS}")
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow")
endif()
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
11 changes: 11 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,19 @@
#pragma warning(pop)
#endif

#if defined(__GNUC__)
#pragma GCC diagnostic push

#ifdef HAS_SHORTEN_64_TO_32
#pragma GCC diagnostic ignored "-Wshorten-64-to-32"
#endif

#include "flatbuffers/flatbuffers.h"

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
#include "core/common/gsl.h"

#include "core/common/common.h"
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/flatbuffers/schema/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ e.g.
- /build/Linux/Debug/_deps/flatbuffers-build/flatc

It is possible to use another flatc as well, e.g., from a separate installation. Note that ONNX Runtime uses
FlatBuffers 1.12.
FlatBuffers 23.5.26.

To update the flatbuffers schemas and generated files:
1. Modify [the ORT file format schema](ort.fbs) or [training checkpoint schema](ort_training_checkpoint.fbs).
Expand Down
54 changes: 27 additions & 27 deletions onnxruntime/core/flatbuffers/schema/ort.fbs.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,8 @@ struct DimensionValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_DIM_TYPE) &&
VerifyField<int64_t>(verifier, VT_DIM_VALUE) &&
VerifyField<int8_t>(verifier, VT_DIM_TYPE, 1) &&
VerifyField<int64_t>(verifier, VT_DIM_VALUE, 8) &&
VerifyOffset(verifier, VT_DIM_PARAM) &&
verifier.VerifyString(dim_param()) &&
verifier.EndTable();
Expand Down Expand Up @@ -634,7 +634,7 @@ struct TensorTypeAndShape FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_ELEM_TYPE) &&
VerifyField<int32_t>(verifier, VT_ELEM_TYPE, 4) &&
VerifyOffset(verifier, VT_SHAPE) &&
verifier.VerifyTable(shape()) &&
verifier.EndTable();
Expand Down Expand Up @@ -687,7 +687,7 @@ struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_KEY_TYPE) &&
VerifyField<int32_t>(verifier, VT_KEY_TYPE, 4) &&
VerifyOffset(verifier, VT_VALUE_TYPE) &&
verifier.VerifyTable(value_type()) &&
verifier.EndTable();
Expand Down Expand Up @@ -787,7 +787,7 @@ struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_NODE_INDEX) &&
VerifyField<uint32_t>(verifier, VT_NODE_INDEX, 4) &&
VerifyOffset(verifier, VT_INPUT_EDGES) &&
verifier.VerifyVector(input_edges()) &&
VerifyOffset(verifier, VT_OUTPUT_EDGES) &&
Expand Down Expand Up @@ -911,11 +911,11 @@ struct Node FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyString(doc_string()) &&
VerifyOffset(verifier, VT_DOMAIN) &&
verifier.VerifyString(domain()) &&
VerifyField<int32_t>(verifier, VT_SINCE_VERSION) &&
VerifyField<uint32_t>(verifier, VT_INDEX) &&
VerifyField<int32_t>(verifier, VT_SINCE_VERSION, 4) &&
VerifyField<uint32_t>(verifier, VT_INDEX, 4) &&
VerifyOffset(verifier, VT_OP_TYPE) &&
verifier.VerifyString(op_type()) &&
VerifyField<int32_t>(verifier, VT_TYPE) &&
VerifyField<int32_t>(verifier, VT_TYPE, 4) &&
VerifyOffset(verifier, VT_EXECUTION_PROVIDER_TYPE) &&
verifier.VerifyString(execution_provider_type()) &&
VerifyOffset(verifier, VT_INPUTS) &&
Expand Down Expand Up @@ -1174,7 +1174,7 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_DENOTATION) &&
verifier.VerifyString(denotation()) &&
VerifyField<uint8_t>(verifier, VT_VALUE_TYPE) &&
VerifyField<uint8_t>(verifier, VT_VALUE_TYPE, 1) &&
VerifyOffset(verifier, VT_VALUE) &&
VerifyTypeInfoValue(verifier, value(), value_type()) &&
verifier.EndTable();
Expand Down Expand Up @@ -1259,7 +1259,7 @@ struct OperatorSetId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_DOMAIN) &&
verifier.VerifyString(domain()) &&
VerifyField<int64_t>(verifier, VT_VERSION) &&
VerifyField<int64_t>(verifier, VT_VERSION, 8) &&
verifier.EndTable();
}
};
Expand Down Expand Up @@ -1343,7 +1343,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyString(doc_string()) &&
VerifyOffset(verifier, VT_DIMS) &&
verifier.VerifyVector(dims()) &&
VerifyField<int32_t>(verifier, VT_DATA_TYPE) &&
VerifyField<int32_t>(verifier, VT_DATA_TYPE, 4) &&
VerifyOffset(verifier, VT_RAW_DATA) &&
verifier.VerifyVector(raw_data()) &&
VerifyOffset(verifier, VT_STRING_DATA) &&
Expand Down Expand Up @@ -1568,9 +1568,9 @@ struct Attribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyString(name()) &&
VerifyOffset(verifier, VT_DOC_STRING) &&
verifier.VerifyString(doc_string()) &&
VerifyField<int32_t>(verifier, VT_TYPE) &&
VerifyField<float>(verifier, VT_F) &&
VerifyField<int64_t>(verifier, VT_I) &&
VerifyField<int32_t>(verifier, VT_TYPE, 4) &&
VerifyField<float>(verifier, VT_F, 4) &&
VerifyField<int64_t>(verifier, VT_I, 8) &&
VerifyOffset(verifier, VT_S) &&
verifier.VerifyString(s()) &&
VerifyOffset(verifier, VT_T) &&
Expand Down Expand Up @@ -1759,12 +1759,12 @@ struct NodesToOptimizeIndices FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NODE_INDICES) &&
verifier.VerifyVector(node_indices()) &&
VerifyField<uint32_t>(verifier, VT_NUM_INPUTS) &&
VerifyField<uint32_t>(verifier, VT_NUM_OUTPUTS) &&
VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_INPUT) &&
VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_OUTPUT) &&
VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_INPUTS) &&
VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_OUTPUTS) &&
VerifyField<uint32_t>(verifier, VT_NUM_INPUTS, 4) &&
VerifyField<uint32_t>(verifier, VT_NUM_OUTPUTS, 4) &&
VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_INPUT, 1) &&
VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_OUTPUT, 1) &&
VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_INPUTS, 4) &&
VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_OUTPUTS, 4) &&
verifier.EndTable();
}
};
Expand Down Expand Up @@ -1862,8 +1862,8 @@ struct DeprecatedNodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private fla
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_NODE_INDEX) &&
VerifyField<uint64_t>(verifier, VT_KERNEL_DEF_HASH) &&
VerifyField<uint32_t>(verifier, VT_NODE_INDEX, 4) &&
VerifyField<uint64_t>(verifier, VT_KERNEL_DEF_HASH, 8) &&
verifier.EndTable();
}
};
Expand Down Expand Up @@ -2161,7 +2161,7 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_NODES) &&
verifier.VerifyVector(nodes()) &&
verifier.VerifyVectorOfTables(nodes()) &&
VerifyField<uint32_t>(verifier, VT_MAX_NODE_INDEX) &&
VerifyField<uint32_t>(verifier, VT_MAX_NODE_INDEX, 4) &&
VerifyOffset(verifier, VT_NODE_EDGES) &&
verifier.VerifyVector(node_edges()) &&
verifier.VerifyVectorOfTables(node_edges()) &&
Expand Down Expand Up @@ -2390,7 +2390,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int64_t>(verifier, VT_IR_VERSION) &&
VerifyField<int64_t>(verifier, VT_IR_VERSION, 8) &&
VerifyOffset(verifier, VT_OPSET_IMPORT) &&
verifier.VerifyVector(opset_import()) &&
verifier.VerifyVectorOfTables(opset_import()) &&
Expand All @@ -2400,7 +2400,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyString(producer_version()) &&
VerifyOffset(verifier, VT_DOMAIN) &&
verifier.VerifyString(domain()) &&
VerifyField<int64_t>(verifier, VT_MODEL_VERSION) &&
VerifyField<int64_t>(verifier, VT_MODEL_VERSION, 8) &&
VerifyOffset(verifier, VT_DOC_STRING) &&
verifier.VerifyString(doc_string()) &&
VerifyOffset(verifier, VT_GRAPH) &&
Expand Down Expand Up @@ -2740,8 +2740,8 @@ struct ArgTypeAndIndex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_ARG_TYPE) &&
VerifyField<uint32_t>(verifier, VT_INDEX) &&
VerifyField<int8_t>(verifier, VT_ARG_TYPE, 1) &&
VerifyField<uint32_t>(verifier, VT_INDEX, 4) &&
verifier.EndTable();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_FROZEN_PARAMS) &&
verifier.VerifyVector(frozen_params()) &&
verifier.VerifyVectorOfTables(frozen_params()) &&
VerifyField<uint8_t>(verifier, VT_IS_NOMINAL_STATE) &&
VerifyField<uint8_t>(verifier, VT_IS_NOMINAL_STATE, 1) &&
verifier.EndTable();
}
};
Expand Down Expand Up @@ -206,8 +206,8 @@ struct OptimizerGroup FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_GROUP_NAME) &&
verifier.VerifyString(group_name()) &&
VerifyField<int64_t>(verifier, VT_STEP) &&
VerifyField<float>(verifier, VT_INITIAL_LEARNING_RATE) &&
VerifyField<int64_t>(verifier, VT_STEP, 8) &&
VerifyField<float>(verifier, VT_INITIAL_LEARNING_RATE, 4) &&
VerifyOffset(verifier, VT_OPTIMIZER_STATES) &&
verifier.VerifyVector(optimizer_states()) &&
verifier.VerifyVectorOfTables(optimizer_states()) &&
Expand Down Expand Up @@ -289,7 +289,7 @@ struct IntProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NAME) &&
verifier.VerifyString(name()) &&
VerifyField<int64_t>(verifier, VT_VALUE) &&
VerifyField<int64_t>(verifier, VT_VALUE, 8) &&
verifier.EndTable();
}
};
Expand Down Expand Up @@ -353,7 +353,7 @@ struct FloatProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NAME) &&
verifier.VerifyString(name()) &&
VerifyField<float>(verifier, VT_VALUE) &&
VerifyField<float>(verifier, VT_VALUE, 4) &&
verifier.EndTable();
}
};
Expand Down Expand Up @@ -572,7 +572,7 @@ struct Checkpoint FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_VERSION) &&
VerifyField<int32_t>(verifier, VT_VERSION, 4) &&
VerifyOffset(verifier, VT_MODULE_STATE) &&
verifier.VerifyTable(module_state()) &&
VerifyOffset(verifier, VT_OPTIMIZER_GROUPS) &&
Expand Down
Loading
Loading