Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cffi: typedef shmem_{ctx|team}_t as an opaque struct type #24

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/ffibuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
def build_api(
module="api",
shmem_h="shmem.h",
shmem_ctx_t='...*',
shmem_team_t='...*',
shmem_ctx_t='struct{...;}',
shmem_team_t='struct{...;}',
):
from apicodegen import generate
ffi = cffi.FFI()
Expand All @@ -23,6 +23,12 @@ def build_api(
ffi.cdef(code)
for code in generate():
ffi.cdef(code)
for hdl in ('ctx', 'team'):
ffi.cdef(f"""
bool eq_{hdl}(shmem_{hdl}_t, shmem_{hdl}_t);
uintptr_t {hdl}2id(shmem_{hdl}_t);
shmem_{hdl}_t id2{hdl}(uintptr_t);
""")
ffi.cdef("""
int shmem_alltoallsmem_x(
shmem_team_t team,
Expand Down
1 change: 1 addition & 0 deletions src/libshmem.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

/* --- */

#include "libshmem/hdltypes.h"
#include "libshmem/fallback.h"
#include "libshmem/initfini.h"
#include "libshmem/memalloc.h"
Expand Down
7 changes: 7 additions & 0 deletions src/libshmem/hdltypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#define eq_ctx(a, b) ((a) == (b))
#define ctx2id(c) ((uintptr_t)(c))
#define id2ctx(i) ((shmem_ctx_t)(i))

#define eq_team(a, b) ((a) == (b))
#define team2id(t) ((uintptr_t)(t))
#define id2team(i) ((shmem_team_t)i)
22 changes: 11 additions & 11 deletions src/shmem4py/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,15 @@ def __new__(
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Ctx):
return NotImplemented
return self.ob_ctx == other.ob_ctx
return lib.eq_ctx(self.ob_ctx, other.ob_ctx)

def __ne__(self, other: Any) -> bool:
if not isinstance(other, Ctx):
return NotImplemented
return self.ob_ctx != other.ob_ctx
return not lib.eq_ctx(self.ob_ctx, other.ob_ctx)

def __bool__(self) -> bool:
return self.ob_ctx != lib.SHMEM_CTX_INVALID
return not lib.eq_ctx(self.ob_ctx, lib.SHMEM_CTX_INVALID)

def __enter__(self) -> Ctx:
return self
Expand Down Expand Up @@ -331,9 +331,9 @@ def destroy(self) -> None:
return
ctx = self.ob_ctx
self.ob_ctx = lib.SHMEM_CTX_INVALID
if ctx == lib.SHMEM_CTX_DEFAULT:
if lib.eq_ctx(ctx, lib.SHMEM_CTX_DEFAULT):
return
if ctx == lib.SHMEM_CTX_INVALID:
if lib.eq_ctx(ctx, lib.SHMEM_CTX_INVALID):
return
lib.shmem_ctx_destroy(ctx)

Expand Down Expand Up @@ -400,15 +400,15 @@ def __new__(
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Team):
return NotImplemented
return self.ob_team == other.ob_team
return lib.eq_team(self.ob_team, other.ob_team)

def __ne__(self, other: Any) -> bool:
if not isinstance(other, Team):
return NotImplemented
return self.ob_team != other.ob_team
return not lib.eq_team(self.ob_team, other.ob_team)

def __bool__(self) -> bool:
return self.ob_team != lib.SHMEM_TEAM_INVALID
return not lib.eq_team(self.ob_team, lib.SHMEM_TEAM_INVALID)

def __enter__(self) -> Team:
return self
Expand All @@ -426,11 +426,11 @@ def destroy(self) -> None:
return
team = self.ob_team
self.ob_team = lib.SHMEM_TEAM_INVALID
if team == lib.SHMEM_TEAM_WORLD:
if lib.eq_team(team, lib.SHMEM_TEAM_WORLD):
return
if team == lib.SHMEM_TEAM_SHARED:
if lib.eq_team(team, lib.SHMEM_TEAM_SHARED):
return
if team == lib.SHMEM_TEAM_INVALID:
if lib.eq_team(team, lib.SHMEM_TEAM_INVALID):
return
lib.shmem_team_destroy(team)

Expand Down