[Feature Request] Refactor the C++ interface between saved DMFF JAX models and MD engines #173

Open
dingye18 opened this issue Mar 21, 2024 · 0 comments
Labels
enhancement New feature or request


@dingye18
Contributor

Summary

Replace jax2tf with HLOModule in the C++ interface for saved DMFF models.

Motivation

The current C++ interface between saved DMFF models and the MD engine is based on jax2tf.
jax2tf is used to convert the JAX function into a TensorFlow function.
However, jax2tf is an experimental JAX feature and has several limitations for production use:

  1. Limited support for custom calls (see https://github.com/google/jax/tree/jaxlib-v0.4.25/jax/experimental/jax2tf#native-serialization-supports-only-select-custom-calls).
    Observed with JAX 0.4.24 and TF 2.14/2.15.
  2. Unsupported data types, including f64 and s64.


Suggested Solutions

jax-ml/jax#1871 (an older solution).
It lacks documentation, so more exploration is required.
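As a starting point for that exploration, the sketch below shows how a JAX function can be traced and lowered to XLA's compiler IR directly with `jax.jit(...).lower(...)`, with no TensorFlow in the loop, and with float64 enabled (one of the types jax2tf could not handle here). It is illustrative only: `energy` is a hypothetical harmonic stand-in for a DMFF potential, and how a C++ MD engine would load and execute the resulting module (for example through an XLA/PJRT client) is an open assumption that this sketch does not cover.

```python
import jax
import jax.numpy as jnp

# float64 support in JAX must be switched on explicitly, before arrays are created.
jax.config.update("jax_enable_x64", True)


def energy(positions):
    # Hypothetical stand-in for a DMFF potential: a harmonic well at the origin.
    return 0.5 * jnp.sum(positions ** 2)


def energy_and_forces(positions):
    # An MD engine needs both the energy and the forces (the negative gradient).
    e, grad = jax.value_and_grad(energy)(positions)
    return e, -grad


example = jnp.ones((4, 3), dtype=jnp.float64)

# Trace and lower to XLA's compiler IR without going through TensorFlow.
lowered = jax.jit(energy_and_forces).lower(example)

# Textual IR of the whole computation (StableHLO in recent JAX releases).
ir_text = lowered.as_text()

# Compile and run the lowered module from Python as a sanity check.
compiled = lowered.compile()
e, forces = compiled(example)
```

For this input the energy is 6.0 and every force component is -1.0; `ir_text` contains `f64` tensor types, confirming that the double-precision computation survives lowering.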

Further Information, Files, and Links

No response

dingye18 added the enhancement label on Mar 21, 2024
Development

No branches or pull requests

1 participant