Skip to content

Commit

Permalink
Allow only serializing external tensors
Browse files Browse the repository at this point in the history
MLDrift's program cache serialization can fail if MLDrift changes
OR the GPU driver changes. It is not currently possible to detect
GPU driver changes on all devices so serializing the program
cache isn't always safe.

Allow serializing external tensors without serializing the program
cache.

PiperOrigin-RevId: 714088597
  • Loading branch information
tf-marissaw authored and copybara-github committed Jan 15, 2025
1 parent 68fb467 commit 54806d6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tflite/delegates/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ TfLiteStatus SerializationEntry::GetData(TfLiteContext* context,
}
}

SerializationEntry Serialization::GetEntryImpl(
const std::string& custom_key, TfLiteContext* context,
const TfLiteDelegateParams* delegate_params) {
uint64_t Serialization::GetFingerprint(
const std::string& model_token, const std::string& custom_key,
TfLiteContext* context, const TfLiteDelegateParams* delegate_params) {
// First incorporate model_token.
// We use Fingerprint64 instead of std::hash, since the latter isn't
// guaranteed to be stable across runs. See b/172237993.
uint64_t fingerprint =
::util::Fingerprint64(model_token_.c_str(), model_token_.size());
::util::Fingerprint64(model_token.c_str(), model_token.size());

// Incorporate custom_key.
const uint64_t custom_str_fingerprint =
Expand Down Expand Up @@ -297,6 +297,14 @@ SerializationEntry Serialization::GetEntryImpl(
partition_data.size() * sizeof(int32_t));
fingerprint = CombineFingerprints(fingerprint, partition_fingerprint);
}
return fingerprint;
}

SerializationEntry Serialization::GetEntryImpl(
const std::string& custom_key, TfLiteContext* context,
const TfLiteDelegateParams* delegate_params) {
uint64_t fingerprint =
GetFingerprint(model_token_, custom_key, context, delegate_params);

// Get a fingerprint-specific lock that is passed to the SerializationKey, to
// ensure noone else gets access to an equivalent SerializationKey.
Expand Down
8 changes: 8 additions & 0 deletions tflite/delegates/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ class Serialization {
Serialization(const Serialization&) = delete;
Serialization& operator=(const Serialization&) = delete;

// Generate a unique fingerprint for the given model_token and custom_key.
// If context and delegate_params are provided, the fingerprint will be
// unique to the given context and delegate_params.
static uint64_t GetFingerprint(
const std::string& model_token, const std::string& custom_key,
TfLiteContext* context = nullptr,
const TfLiteDelegateParams* delegate_params = nullptr);

protected:
SerializationEntry GetEntryImpl(
const std::string& custom_key, TfLiteContext* context = nullptr,
Expand Down

0 comments on commit 54806d6

Please sign in to comment.