Skip to content

Commit

Permalink
riscv: Add XuanTie matrix extension support
Browse files Browse the repository at this point in the history
This patch is made with the XuanTie matrix spec [1].

How to Enable XuanTie Matrix:
 - Set mxstatus.matrix_enable = 1
 - Add dts with zxtmatrix:
	riscv,isa = "rv64imafdc_xxx_zxtmatrix";

This patch only support context switch, because of huge stack size cost.
TODO:
 - Signal with matrix.
 - Ptrace with matrix.

[1]: https://github.com/T-head-Semi/riscv-matrix-extension-spec

Signed-off-by: Guo Ren <[email protected]>
Signed-off-by: Guo Ren <[email protected]>
  • Loading branch information
guoren83 authored and RevySR committed Jun 29, 2024
1 parent a827454 commit 5c1ea0f
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 2 deletions.
15 changes: 14 additions & 1 deletion arch/riscv/include/asm/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,19 @@
#define SR_VS_CLEAN _AC(0x00000400, UXL)
#define SR_VS_DIRTY _AC(0x00000600, UXL)

#define SR_MS _AC(0x06000000, UXL) /* Matrix Status */
#define SR_MS_OFF _AC(0x00000000, UXL)
#define SR_MS_INITIAL _AC(0x02000000, UXL)
#define SR_MS_CLEAN _AC(0x04000000, UXL)
#define SR_MS_DIRTY _AC(0x06000000, UXL)

#define SR_XS _AC(0x00018000, UXL) /* Extension Status */
#define SR_XS_OFF _AC(0x00000000, UXL)
#define SR_XS_INITIAL _AC(0x00008000, UXL)
#define SR_XS_CLEAN _AC(0x00010000, UXL)
#define SR_XS_DIRTY _AC(0x00018000, UXL)

#define SR_FS_VS (SR_FS | SR_VS) /* Vector and Floating-Point Unit */
#define SR_FS_VS (SR_FS | SR_VS | SR_MS) /* Vector and Floating-Point Unit */

#if __riscv_xlen == 32
#define SR_SD _AC(0x80000000, UXL) /* FS/VS/XS dirty */
Expand Down Expand Up @@ -400,6 +406,13 @@
#define CSR_VTYPE 0xc21
#define CSR_VLENB 0xc22

#define CSR_XMRSTART 0x801
#define CSR_XMCSR 0x802
#define CSR_XMSIZE 0x803
#define CSR_XMLENB 0xcc0
#define CSR_XRLENB 0xcc1
#define CSR_XMISA 0xcc2

#ifdef CONFIG_RISCV_M_MODE
# define CSR_STATUS CSR_MSTATUS
# define CSR_IE CSR_MIE
Expand Down
1 change: 1 addition & 0 deletions arch/riscv/include/asm/hwcap.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#define RISCV_ISA_EXT_ZICSR 40
#define RISCV_ISA_EXT_ZIFENCEI 41
#define RISCV_ISA_EXT_ZIHPM 42
#define RISCV_ISA_EXT_ZXTMATRIX 43

#define RISCV_ISA_EXT_MAX 64

Expand Down
1 change: 1 addition & 0 deletions arch/riscv/include/asm/insn-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
#define RV___RS2(v) __RV_REG(v)

#define RV_OPCODE_MISC_MEM RV_OPCODE(15)
#define RV_OPCODE_MATRIX RV_OPCODE(43)
#define RV_OPCODE_SYSTEM RV_OPCODE(115)

#define HFENCE_VVMA(vaddr, asid) \
Expand Down
1 change: 1 addition & 0 deletions arch/riscv/include/asm/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct thread_struct {
unsigned long bad_cause;
unsigned long vstate_ctrl;
struct __riscv_v_ext_state vstate;
struct __riscv_m_ext_state mstate;
} __attribute__((__aligned__(sizeof(xlen_t))));

/* Whitelist the fstate from the task_struct for hardened usercopy */
Expand Down
2 changes: 2 additions & 0 deletions arch/riscv/include/asm/switch_to.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ do { \
__switch_to_fpu(__prev, __next); \
if (has_vector()) \
__switch_to_vector(__prev, __next); \
if (has_matrix()) \
__switch_to_matrix(__prev, __next); \
((last) = __switch_to(__prev, __next)); \
} while (0)

Expand Down
122 changes: 122 additions & 0 deletions arch/riscv/include/asm/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,89 @@
#include <asm/hwcap.h>
#include <asm/csr.h>
#include <asm/asm.h>
#include <asm/insn-def.h>

static __always_inline bool has_matrix(void)
{
return riscv_has_extension_unlikely(RISCV_ISA_EXT_ZXTMATRIX);
}

static inline void __riscv_m_mstate_clean(struct pt_regs *regs)
{
regs->status = (regs->status & ~SR_MS) | SR_MS_CLEAN;
}

static inline void __riscv_m_mstate_dirty(struct pt_regs *regs)
{
regs->status = (regs->status & ~SR_MS) | SR_MS_DIRTY;
}

static inline void riscv_m_mstate_off(struct pt_regs *regs)
{
regs->status = (regs->status & ~SR_MS) | SR_MS_OFF;
}

static inline void riscv_m_mstate_on(struct pt_regs *regs)
{
regs->status = (regs->status & ~SR_MS) | SR_MS_INITIAL;
}

static inline bool riscv_m_mstate_query(struct pt_regs *regs)
{
return (regs->status & SR_MS) != 0;
}

static __always_inline void riscv_m_enable(void)
{
csr_set(CSR_SSTATUS, SR_MS);
}

static __always_inline void riscv_m_disable(void)
{
csr_clear(CSR_SSTATUS, SR_MS);
}

static __always_inline void __mstate_csr_save(struct __riscv_m_ext_state *dest)
{
asm volatile (
"csrr %0, " __stringify(CSR_XMRSTART) "\n\t"
"csrr %1, " __stringify(CSR_XMCSR) "\n\t"
"csrr %2, " __stringify(CSR_XMSIZE) "\n\t"
: "=r" (dest->xmrstart), "=r" (dest->xmcsr), "=r" (dest->xmsize)
: :);
}

static __always_inline void __mstate_csr_restore(struct __riscv_m_ext_state *src)
{
asm volatile (
"csrw " __stringify(CSR_XMRSTART) ", %0\n\t"
"csrw " __stringify(CSR_XMCSR) ", %1\n\t"
"csrw " __stringify(CSR_XMSIZE) ", %2\n\t"
: : "r" (src->xmrstart), "r" (src->xmcsr), "r" (src->xmsize)
:);
}

static inline void __riscv_m_mstate_save(struct __riscv_m_ext_state *save_to,
void *datap)
{
riscv_m_enable();
__mstate_csr_save(save_to);
asm volatile (
INSN_R(OPCODE_MATRIX, FUNC3(0), FUNC7(21), __RD(0), RS1(%0), __RS2(7))
: : "r" (datap) : "memory");
riscv_m_disable();
}

static inline void __riscv_m_mstate_restore(struct __riscv_m_ext_state *restore_from,
void *datap)
{
riscv_m_enable();
asm volatile (
INSN_R(OPCODE_MATRIX, FUNC3(0), FUNC7(20), __RD(0), RS1(%0), __RS2(7))
: : "r" (datap) : "memory");
__mstate_csr_restore(restore_from);
riscv_m_disable();
}

extern unsigned long riscv_v_vsize;
int riscv_v_setup_vsize(void);
Expand Down Expand Up @@ -173,6 +256,17 @@ static inline void riscv_v_vstate_save(struct task_struct *task,
}
}

static inline void riscv_m_mstate_save(struct task_struct *task,
struct pt_regs *regs)
{
if ((regs->status & SR_MS) == SR_MS_DIRTY) {
struct __riscv_m_ext_state *mstate = &task->thread.mstate;

__riscv_m_mstate_save(mstate, mstate->datap);
__riscv_m_mstate_clean(regs);
}
}

static inline void riscv_v_vstate_restore(struct task_struct *task,
struct pt_regs *regs)
{
Expand All @@ -184,6 +278,27 @@ static inline void riscv_v_vstate_restore(struct task_struct *task,
}
}

static inline void riscv_m_mstate_restore(struct task_struct *task,
struct pt_regs *regs)
{
if ((regs->status & SR_MS) != SR_MS_OFF) {
struct __riscv_m_ext_state *mstate = &task->thread.mstate;

__riscv_m_mstate_restore(mstate, mstate->datap);
__riscv_m_mstate_clean(regs);
}
}

static inline void __switch_to_matrix(struct task_struct *prev,
struct task_struct *next)
{
struct pt_regs *regs;

regs = task_pt_regs(prev);
riscv_m_mstate_save(prev, regs);
riscv_m_mstate_restore(next, task_pt_regs(next));
}

static inline void __switch_to_vector(struct task_struct *prev,
struct task_struct *next)
{
Expand All @@ -203,6 +318,7 @@ struct pt_regs;

static inline int riscv_v_setup_vsize(void) { return -EOPNOTSUPP; }
static __always_inline bool has_vector(void) { return false; }
static __always_inline bool has_matrix(void) { return false; }
static inline bool riscv_v_first_use_handler(struct pt_regs *regs) { return false; }
static inline bool riscv_v_vstate_query(struct pt_regs *regs) { return false; }
static inline bool riscv_v_vstate_ctrl_user_allowed(void) { return false; }
Expand All @@ -214,6 +330,12 @@ static inline bool riscv_v_vstate_ctrl_user_allowed(void) { return false; }
#define riscv_v_vstate_off(regs) do {} while (0)
#define riscv_v_vstate_on(regs) do {} while (0)

#define riscv_m_mstate_save(task, regs) do {} while (0)
#define riscv_m_mstate_restore(task, regs) do {} while (0)
#define __switch_to_matrix(__prev, __next) do {} while (0)
#define riscv_m_mstate_off(regs) do {} while (0)
#define riscv_m_mstate_on(regs) do {} while (0)

#endif /* CONFIG_RISCV_ISA_V */

#endif /* ! __ASM_RISCV_VECTOR_H */
7 changes: 7 additions & 0 deletions arch/riscv/include/uapi/asm/ptrace.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ struct __riscv_v_regset_state {
char vreg[];
};

struct __riscv_m_ext_state {
xlen_t xmrstart;
xlen_t xmcsr;
xlen_t xmsize;
void *datap;
};

/*
* According to spec: The number of bits in a single vector register,
* VLEN >= ELEN, which must be a power of 2, and must be no greater than
Expand Down
1 change: 1 addition & 0 deletions arch/riscv/kernel/cpufeature.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ const struct riscv_isa_ext_data riscv_isa_ext[] = {
__RISCV_ISA_EXT_DATA(svinval, RISCV_ISA_EXT_SVINVAL),
__RISCV_ISA_EXT_DATA(svnapot, RISCV_ISA_EXT_SVNAPOT),
__RISCV_ISA_EXT_DATA(svpbmt, RISCV_ISA_EXT_SVPBMT),
__RISCV_ISA_EXT_DATA(zxtmatrix, RISCV_ISA_EXT_ZXTMATRIX),
};

const size_t riscv_isa_ext_count = ARRAY_SIZE(riscv_isa_ext);
Expand Down
92 changes: 91 additions & 1 deletion arch/riscv/kernel/vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,31 @@ static bool insn_is_vector(u32 insn_buf)
return false;
}

static bool insn_is_matrix(u32 insn_buf)
{
u32 opcode = insn_buf & __INSN_OPCODE_MASK;
u32 csr;

/*
* All M-related instructions, including CSR operations are 4-Byte. So,
* do not handle if the instruction length is not 4-Byte.
*/
if (unlikely(GET_INSN_LENGTH(insn_buf) != 4))
return false;

switch (opcode) {
case 43:
return true;
case RVG_OPCODE_SYSTEM:
csr = RVG_EXTRACT_SYSTEM_CSR(insn_buf);
if ((csr >= CSR_XMRSTART && csr <= CSR_XMSIZE) ||
(csr >= CSR_XMLENB && csr <= CSR_XMISA))
return true;
}

return false;
}

static int riscv_v_thread_zalloc(void)
{
void *datap;
Expand All @@ -94,6 +119,20 @@ static int riscv_v_thread_zalloc(void)
return 0;
}

static int riscv_m_thread_zalloc(void)
{
void *datap;

datap = kzalloc(csr_read(CSR_XMLENB) * 8, GFP_KERNEL);
if (!datap)
return -ENOMEM;

current->thread.mstate.datap = datap;
memset(&current->thread.mstate, 0, offsetof(struct __riscv_m_ext_state,
datap));
return 0;
}

#define VSTATE_CTRL_GET_CUR(x) ((x) & PR_RISCV_V_VSTATE_CTRL_CUR_MASK)
#define VSTATE_CTRL_GET_NEXT(x) (((x) & PR_RISCV_V_VSTATE_CTRL_NEXT_MASK) >> 2)
#define VSTATE_CTRL_MAKE_NEXT(x) (((x) << 2) & PR_RISCV_V_VSTATE_CTRL_NEXT_MASK)
Expand Down Expand Up @@ -131,7 +170,7 @@ bool riscv_v_vstate_ctrl_user_allowed(void)
}
EXPORT_SYMBOL_GPL(riscv_v_vstate_ctrl_user_allowed);

bool riscv_v_first_use_handler(struct pt_regs *regs)
static bool __riscv_v_first_use_handler(struct pt_regs *regs)
{
u32 __user *epc = (u32 __user *)(ulong)regs->epc;
u32 insn = (u32)regs->badaddr;
Expand Down Expand Up @@ -171,6 +210,57 @@ bool riscv_v_first_use_handler(struct pt_regs *regs)
return true;
}

static bool __riscv_m_first_use_handler(struct pt_regs *regs)
{
u32 __user *epc = (u32 __user *)(ulong)regs->epc;
u32 insn = (u32)regs->badaddr;

/* Do not handle if Matrix is not supported, or disabled */
if (!has_matrix())
return false;

/* If Matrix has been enabled then it is not the first-use trap */
if (riscv_m_mstate_query(regs))
return false;

/* Get the instruction */
if (!insn) {
if (__get_user(insn, epc))
return false;
}

/* Filter out non-Matrix instructions */
if (!insn_is_matrix(insn))
return false;

/*
* When datap = NULL, it's the first use.
* When datap != NULL, mrelease makes sstatus.ms=OFF.
*/
if (current->thread.mstate.datap == NULL) {
if (riscv_m_thread_zalloc()) {
force_sig(SIGBUS);
return true;
}
riscv_m_mstate_restore(current, regs);
}

riscv_m_mstate_on(regs);
return true;
}

bool riscv_v_first_use_handler(struct pt_regs *regs)
{
bool ret;

ret = __riscv_v_first_use_handler(regs);

if (!ret)
ret = __riscv_m_first_use_handler(regs);

return ret;
}

void riscv_v_vstate_ctrl_init(struct task_struct *tsk)
{
bool inherit;
Expand Down

0 comments on commit 5c1ea0f

Please sign in to comment.