Skip to content

Commit

Permalink
Cache TensorFlow Metadata proto files (#921)
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored Feb 20, 2024
1 parent 083b586 commit b49795b
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
import sbt._
import sbt.util.CacheImplicits._
import sbtprotoc.ProtocPlugin.ProtobufConfig
import com.github.sbt.git.SbtGit.GitKeys.gitRemoteRepo
import com.typesafe.tools.mima.core._
Expand Down Expand Up @@ -550,6 +551,11 @@ lazy val protobuf = project
description := "Magnolia add-on for Google Protocol Buffer"
)

val tensorflowMetadataSourcesDir =
settingKey[File]("Directory containing TensorFlow metadata proto files")
val tensorflowMetadata =
taskKey[Seq[File]]("Retrieve TensorFlow metadata proto files")

lazy val tensorflow = project
.in(file("tensorflow"))
.dependsOn(
Expand All @@ -570,14 +576,34 @@ lazy val tensorflow = project
// remove compilation warnings for generated java files
javacOptions ~= { _.filterNot(_ == "-Xlint:all") },
// tensorflow metadata protos are not packaged into a jar. Manually extract them as external
Compile / tensorflowMetadataSourcesDir := target.value / s"metadata-$tensorflowMetadataVersion",
Compile / PB.protoSources += target.value / s"metadata-$tensorflowMetadataVersion",
Compile / tensorflowMetadata := {
def work(tensorFlowMetadataVersion: String) = {
val tfMetadata = url(
s"https://github.com/tensorflow/metadata/archive/refs/tags/v$tensorFlowMetadataVersion.zip"
)
IO.unzipURL(tfMetadata, target.value, "*.proto").toSeq
}

val cacheStoreFactory = streams.value.cacheStoreFactory
val root = (Compile / tensorflowMetadataSourcesDir).value
val tracker =
Tracked.inputChanged(cacheStoreFactory.make("input")) { (versionChanged, version: String) =>
val cached = Tracked.outputChanged(cacheStoreFactory.make("output")) {
(outputChanged: Boolean, files: Seq[HashFileInfo]) =>
if (versionChanged || outputChanged) work(version)
else files.map(_.file)
}
cached(() => (root ** "*.proto").get().map(FileInfo.hash(_)))
}

tracker(tensorflowMetadataVersion)
},
Compile / PB.unpackDependencies := {
val tfMetadata = new URL(
s"https://github.com/tensorflow/metadata/archive/refs/tags/v$tensorflowMetadataVersion.zip"
)
val protoFiles = IO.unzipURL(tfMetadata, target.value, _.endsWith(".proto"))
val root = target.value / s"metadata-$tensorflowMetadataVersion"
val metadataDep = ProtocPlugin.UnpackedDependency(protoFiles.toSeq, Seq.empty)
val protoFiles = (Compile / tensorflowMetadata).value
val root = (Compile / tensorflowMetadataSourcesDir).value
val metadataDep = ProtocPlugin.UnpackedDependency(protoFiles, Seq.empty)
val deps = (Compile / PB.unpackDependencies).value
new ProtocPlugin.UnpackedDependencies(deps.mappedFiles ++ Map(root -> metadataDep))
},
Expand Down

0 comments on commit b49795b

Please sign in to comment.