Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cray OpenSHMEM-X 11 - do not merge #22

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/ffibuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
def build_api(
module="api",
shmem_h="shmem.h",
shmem_ctx_t='...*',
shmem_team_t='...*',
shmem_ctx_t='struct{...;}',
shmem_team_t='struct{...;}',
):
from apicodegen import generate
ffi = cffi.FFI()
Expand All @@ -23,6 +23,12 @@ def build_api(
ffi.cdef(code)
for code in generate():
ffi.cdef(code)
for hdl in ('ctx', 'team'):
ffi.cdef(f"""
bool eq_{hdl}(shmem_{hdl}_t, shmem_{hdl}_t);
uintptr_t {hdl}2id(shmem_{hdl}_t);
shmem_{hdl}_t id2{hdl}(uintptr_t);
""")
ffi.cdef("""
int shmem_alltoallsmem_x(
shmem_team_t team,
Expand Down
1 change: 1 addition & 0 deletions src/libshmem.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

/* --- */

#include "libshmem/hdltypes.h"
#include "libshmem/fallback.h"
#include "libshmem/initfini.h"
#include "libshmem/memalloc.h"
Expand Down
46 changes: 6 additions & 40 deletions src/libshmem/compat/cray.h
Original file line number Diff line number Diff line change
@@ -1,46 +1,12 @@
#ifndef PySHMEM_COMPAT_CRAY_H
#define PySHMEM_COMPAT_CRAY_H

static
void shmem_complexf_sum_to_all(float _Complex *dest, const float _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
float _Complex *pWrk, long *pSync)
{
shmem_float_sum_to_all((float*)dest, (float*)source, 2*nreduce,
PE_start, logPE_stride, PE_size,
(float*)pWrk, pSync);
}

static
void shmem_complexd_sum_to_all(double _Complex *dest, const double _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
double _Complex *pWrk, long *pSync)
{
shmem_double_sum_to_all((double*)dest, (double*)source, 2*nreduce,
PE_start, logPE_stride, PE_size,
(double*)pWrk, pSync);
}

static
void shmem_complexf_prod_to_all(float _Complex *dest, const float _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
float _Complex *pWrk, long *pSync)
{
(void)dest; (void)source; (void)nreduce;
(void)PE_start; (void)logPE_stride;
(void)PE_size; (void)pWrk; (void)pSync;
PySHMEM_UNAVAILABLE;
}
#if CRAY_SHMEM_MAJOR_VERSION == 9
#include "cray09.h"
#endif

static
void shmem_complexd_prod_to_all(double _Complex *dest, const double _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
double _Complex *pWrk, long *pSync)
{
(void)dest; (void)source; (void)nreduce;
(void)PE_start; (void)logPE_stride; (void)PE_size;
(void)pWrk; (void)pSync;
PySHMEM_UNAVAILABLE;
}
#if CRAY_SHMEM_MAJOR_VERSION == 11
#include "cray11.h"
#endif

#endif
41 changes: 41 additions & 0 deletions src/libshmem/compat/cray09.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
static
void shmem_complexf_sum_to_all(float _Complex *dest, const float _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
float _Complex *pWrk, long *pSync)
{
shmem_float_sum_to_all((float*)dest, (float*)source, 2*nreduce,
PE_start, logPE_stride, PE_size,
(float*)pWrk, pSync);
}

static
void shmem_complexd_sum_to_all(double _Complex *dest, const double _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
double _Complex *pWrk, long *pSync)
{
shmem_double_sum_to_all((double*)dest, (double*)source, 2*nreduce,
PE_start, logPE_stride, PE_size,
(double*)pWrk, pSync);
}

static
void shmem_complexf_prod_to_all(float _Complex *dest, const float _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
float _Complex *pWrk, long *pSync)
{
(void)dest; (void)source; (void)nreduce;
(void)PE_start; (void)logPE_stride;
(void)PE_size; (void)pWrk; (void)pSync;
PySHMEM_UNAVAILABLE;
}

static
void shmem_complexd_prod_to_all(double _Complex *dest, const double _Complex *source, int nreduce,
int PE_start, int logPE_stride, int PE_size,
double _Complex *pWrk, long *pSync)
{
(void)dest; (void)source; (void)nreduce;
(void)PE_start; (void)logPE_stride; (void)PE_size;
(void)pWrk; (void)pSync;
PySHMEM_UNAVAILABLE;
}
567 changes: 567 additions & 0 deletions src/libshmem/compat/cray11.h

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions src/libshmem/config/cray.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,27 @@
#ifndef PySHMEM_CONFIG_CRAY_H
#define PySHMEM_CONFIG_CRAY_H

#if CRAY_SHMEM_MAJOR_VERSION >= 11
#define PySHMEM_HAVE_shmem_malloc_with_hints 1
#define PySHMEM_HAVE_shmem_team_t 1
#define PySHMEM_HAVE_SHMEM_CTX_INVALID 1
#define PySHMEM_HAVE_shmem_amo_nbi 1
#define PySHMEM_HAVE_shmem_put_signal 1
#define PySHMEM_HAVE_shmem_signal_fetch 1
#define PySHMEM_HAVE_shmem_signal_wait_until 1
#define PySHMEM_HAVE_shmem_broadcast 1
#define PySHMEM_HAVE_shmem_collect 1
#define PySHMEM_HAVE_shmem_fcollect 1
#define PySHMEM_HAVE_shmem_alltoall 1
#define PySHMEM_HAVE_shmem_alltoalls 1
#define PySHMEM_HAVE_shmem_broadcastmem 1
#define PySHMEM_HAVE_shmem_collectmem 1
#define PySHMEM_HAVE_shmem_fcollectmem 1
#define PySHMEM_HAVE_shmem_alltoallmem 1
#define PySHMEM_HAVE_shmem_alltoallsmem 1
#define PySHMEM_HAVE_shmem_reduce 1
#define PySHMEM_HAVE_shmem_wait_test_many 1
/* #define PySHMEM_HAVE_shmem_pcontrol 1 */
#endif

#endif
7 changes: 7 additions & 0 deletions src/libshmem/hdltypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#define eq_ctx(a, b) ((a) == (b))
#define ctx2id(c) ((uintptr_t)(c))
#define id2ctx(i) ((shmem_ctx_t)(i))

#define eq_team(a, b) ((a) == (b))
#define team2id(t) ((uintptr_t)(t))
#define id2team(i) ((shmem_team_t)i)
22 changes: 11 additions & 11 deletions src/shmem4py/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,15 @@ def __new__(
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Ctx):
return NotImplemented
return self.ob_ctx == other.ob_ctx
return lib.eq_ctx(self.ob_ctx, other.ob_ctx)

def __ne__(self, other: Any) -> bool:
if not isinstance(other, Ctx):
return NotImplemented
return self.ob_ctx != other.ob_ctx
return not lib.eq_ctx(self.ob_ctx, other.ob_ctx)

def __bool__(self) -> bool:
return self.ob_ctx != lib.SHMEM_CTX_INVALID
return not lib.eq_ctx(self.ob_ctx, lib.SHMEM_CTX_INVALID)

def __enter__(self) -> Ctx:
return self
Expand Down Expand Up @@ -331,9 +331,9 @@ def destroy(self) -> None:
return
ctx = self.ob_ctx
self.ob_ctx = lib.SHMEM_CTX_INVALID
if ctx == lib.SHMEM_CTX_DEFAULT:
if lib.eq_ctx(ctx, lib.SHMEM_CTX_DEFAULT):
return
if ctx == lib.SHMEM_CTX_INVALID:
if lib.eq_ctx(ctx, lib.SHMEM_CTX_INVALID):
return
lib.shmem_ctx_destroy(ctx)

Expand Down Expand Up @@ -400,15 +400,15 @@ def __new__(
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Team):
return NotImplemented
return self.ob_team == other.ob_team
return lib.eq_team(self.ob_team, other.ob_team)

def __ne__(self, other: Any) -> bool:
if not isinstance(other, Team):
return NotImplemented
return self.ob_team != other.ob_team
return not lib.eq_team(self.ob_team, other.ob_team)

def __bool__(self) -> bool:
return self.ob_team != lib.SHMEM_TEAM_INVALID
return not lib.eq_team(self.ob_team, lib.SHMEM_TEAM_INVALID)

def __enter__(self) -> Team:
return self
Expand All @@ -426,11 +426,11 @@ def destroy(self) -> None:
return
team = self.ob_team
self.ob_team = lib.SHMEM_TEAM_INVALID
if team == lib.SHMEM_TEAM_WORLD:
if lib.eq_team(team, lib.SHMEM_TEAM_WORLD):
return
if team == lib.SHMEM_TEAM_SHARED:
if lib.eq_team(team, lib.SHMEM_TEAM_SHARED):
return
if team == lib.SHMEM_TEAM_INVALID:
if lib.eq_team(team, lib.SHMEM_TEAM_INVALID):
return
lib.shmem_team_destroy(team)

Expand Down