Skip to content

Commit

Permalink
Setup the flaxlib in C++, using Meson and Nanobind.
Browse files Browse the repository at this point in the history
Not enabling GitHub auto-build of flaxlib yet - gonna do that later.

PiperOrigin-RevId: 696623140
  • Loading branch information
IvyZX authored and Flax Authors committed Nov 14, 2024
1 parent f265a5e commit 72f4971
Show file tree
Hide file tree
Showing 16 changed files with 77 additions and 29 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,10 @@ jobs:
uses: astral-sh/setup-uv@v2
with:
version: "0.3.0"
- name: Setup Rust (flaxlib)
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Install dependencies
run: |
uv sync --extra all --extra testing --extra docs
uv pip install ./flaxlib
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ build/
docs*/**/_autosummary
docs*/_build
docs*/**/tmp
flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects

# used by direnv
.envrc
Expand Down
1 change: 0 additions & 1 deletion flaxlib/README.md

This file was deleted.

15 changes: 0 additions & 15 deletions flaxlib/flaxlib/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
34 changes: 34 additions & 0 deletions flaxlib_src/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# flaxlib

## Build flaxlib from source

Install necessary dependencies to build the C++ based package.

```shell
pip install meson-python ninja build
```

Clone the Flax repository, navigate to the flaxlib source directory.

```shell
git clone [email protected]:google/flax.git
cd flax/flaxlib_src
```

Configure the build.

```shell
mkdir -p subprojects
meson wrap install robin-map
meson wrap install nanobind
meson setup builddir
```

Compile the code. You'll need to run this repeatedly if you modify the source
code. Note that the actual wheel name will differ depending on your system.

```shell
meson compile -C builddir
python -m build . -w
pip install dist/flaxlib-0.0.1-cp311-cp311-macosx_14_0_arm64.whl --force-reinstall
```
File renamed without changes.
14 changes: 14 additions & 0 deletions flaxlib_src/meson.build
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
project(
'flaxlib',
'cpp',
version: '0.0.1',
default_options: ['cpp_std=c++17'],
)
py = import('python').find_installation()
nanobind_dep = dependency('nanobind', static: true)
py.extension_module(
'flaxlib',
sources: ['src/lib.cc'],
dependencies: [nanobind_dep],
install: true,
)
8 changes: 3 additions & 5 deletions flaxlib/pyproject.toml → flaxlib_src/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[build-system]
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
requires = ['meson-python']
build-backend = 'mesonpy'

[project]
name = "flaxlib"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: C++",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
Expand All @@ -15,5 +15,3 @@ dynamic = ["version"]
tests = [
"pytest",
]
[tool.maturin]
features = ["pyo3/extension-module"]
14 changes: 14 additions & 0 deletions flaxlib_src/src/lib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <string>

#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h"

namespace flaxlib {
std::string sum_as_string(int a, int b) {
return std::to_string(a + b);
}

NB_MODULE(flaxlib, m) {
m.def("sum_as_string", &sum_as_string);
}
} // namespace flaxlib
File renamed without changes.
File renamed without changes.
13 changes: 8 additions & 5 deletions tests/flaxlib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import flaxlib

# TODO: Re-enable this test after setting up CI build for flaxlib CC.

class TestFlaxlib(absltest.TestCase):
# from absl.testing import absltest
# import flaxlib

def test_flaxlib(self):
self.assertEqual(flaxlib.sum_as_string(1, 2), '3')

# class TestFlaxlib(absltest.TestCase):

# def test_flaxlib(self):
# self.assertEqual(flaxlib.sum_as_string(1, 2), '3')

0 comments on commit 72f4971

Please sign in to comment.