Skip to content

Commit

Permalink
use _continue_interrupt_flag to reduce binary size
Browse files Browse the repository at this point in the history
  • Loading branch information
romancardenas committed Apr 12, 2024
1 parent 52d5185 commit e3aa1c2
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 103 deletions.
196 changes: 151 additions & 45 deletions riscv-rt/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,122 @@ pub fn loop_global_asm(input: TokenStream) -> TokenStream {
res.parse().unwrap()
}

#[derive(Clone, Copy)]
enum RiscvArch {
Rv32,
Rv64,
}

const TRAP_SIZE: usize = 16;

#[rustfmt::skip]
const TRAP_FRAME: [&str; TRAP_SIZE] = [
"ra",
"t0",
"t1",
"t2",
"t3",
"t4",
"t5",
"t6",
"a0",
"a1",
"a2",
"a3",
"a4",
"a5",
"a6",
"a7",
];

fn store_trap<T: FnMut(&str) -> bool>(arch: RiscvArch, mut filter: T) -> String {
let (width, store) = match arch {
RiscvArch::Rv32 => (4, "sw"),
RiscvArch::Rv64 => (8, "sd"),
};
let mut stores = Vec::new();
for (i, reg) in TRAP_FRAME
.iter()
.enumerate()
.filter(|(_, &reg)| filter(reg))
{
stores.push(format!("{store} {reg}, {i}*{width}(sp)"));
}
stores.join("\n")
}

fn load_trap(arch: RiscvArch) -> String {
let (width, load) = match arch {
RiscvArch::Rv32 => (4, "lw"),
RiscvArch::Rv64 => (8, "ld"),
};
let mut loads = Vec::new();
for (i, reg) in TRAP_FRAME.iter().enumerate() {
loads.push(format!("{load} {reg}, {i}*{width}(sp)"));
}
loads.join("\n")
}

#[proc_macro]
pub fn weak_start_trap_riscv32(_input: TokenStream) -> TokenStream {
weak_start_trap(RiscvArch::Rv32)
}

#[proc_macro]
pub fn weak_start_trap_riscv64(_input: TokenStream) -> TokenStream {
weak_start_trap(RiscvArch::Rv64)
}

fn weak_start_trap(arch: RiscvArch) -> TokenStream {
let width = match arch {
RiscvArch::Rv32 => 4,
RiscvArch::Rv64 => 8,
};
// ensure we do not break that sp is 16-byte aligned
if (TRAP_SIZE * width) % 16 != 0 {
return parse::Error::new(Span::call_site(), "Trap frame size must be 16-byte aligned")
.to_compile_error()
.into();
}
let store = store_trap(arch, |_| true);
let load = load_trap(arch);

#[cfg(feature = "s-mode")]
let ret = "sret";
#[cfg(not(feature = "s-mode"))]
let ret = "mret";

let instructions: proc_macro2::TokenStream = format!(
"
core::arch::global_asm!(
\".section .trap, \\\"ax\\\"
.align {width}
.weak _start_trap
_start_trap:
addi sp, sp, - {TRAP_SIZE} * {width}
{store}
add a0, sp, zero
jal ra, _start_trap_rust
{load}
addi sp, sp, {TRAP_SIZE} * {width}
{ret}
\");"
)
.parse()
.unwrap();

#[cfg(feature = "v-trap")]
let v_trap = v_trap::continue_interrupt_trap(arch);
#[cfg(not(feature = "v-trap"))]
let v_trap = proc_macro2::TokenStream::new();

quote!(
#instructions
#v_trap
)
.into()
}

#[proc_macro_attribute]
pub fn interrupt_riscv32(args: TokenStream, input: TokenStream) -> TokenStream {
interrupt(args, input, RiscvArch::Rv32)
Expand Down Expand Up @@ -376,7 +487,7 @@ fn interrupt(args: TokenStream, input: TokenStream, _arch: RiscvArch) -> TokenSt
#[cfg(not(feature = "v-trap"))]
let start_trap = proc_macro2::TokenStream::new();
#[cfg(feature = "v-trap")]
let start_trap = v_trap::start_interrupt_trap_asm(ident, _arch);
let start_trap = v_trap::start_interrupt_trap(ident, _arch);

quote!(
#start_trap
Expand All @@ -390,45 +501,41 @@ fn interrupt(args: TokenStream, input: TokenStream, _arch: RiscvArch) -> TokenSt
mod v_trap {
use super::*;

const TRAP_SIZE: usize = 16;

#[rustfmt::skip]
const TRAP_FRAME: [&str; TRAP_SIZE] = [
"ra",
"t0",
"t1",
"t2",
"t3",
"t4",
"t5",
"t6",
"a0",
"a1",
"a2",
"a3",
"a4",
"a5",
"a6",
"a7",
];

pub(crate) fn start_interrupt_trap_asm(
pub(crate) fn start_interrupt_trap(
ident: &syn::Ident,
arch: RiscvArch,
) -> proc_macro2::TokenStream {
let function = ident.to_string();
let (width, store, load) = match arch {
RiscvArch::Rv32 => (4, "sw", "lw"),
RiscvArch::Rv64 => (8, "sd", "ld"),
let interrupt = ident.to_string();
let width = match arch {
RiscvArch::Rv32 => 4,
RiscvArch::Rv64 => 8,
};
let store = store_trap(arch, |r| r == "a0");

let (mut stores, mut loads) = (Vec::new(), Vec::new());
for (i, r) in TRAP_FRAME.iter().enumerate() {
stores.push(format!(" {store} {r}, {i}*{width}(sp)"));
loads.push(format!(" {load} {r}, {i}*{width}(sp)"));
}
let store = stores.join("\n");
let load = loads.join("\n");
let instructions = format!(
"
core::arch::global_asm!(
\".section .trap, \\\"ax\\\"
.align {width}
.global _start_{interrupt}_trap
_start_{interrupt}_trap:
addi sp, sp, -{TRAP_SIZE} * {width} // allocate space for trap frame
{store} // store trap partially (only register a0)
la a0, {interrupt} // load interrupt handler address into a0
j _continue_interrupt_trap // jump to common part of interrupt trap
\");"
);

instructions.parse().unwrap()
}

pub(crate) fn continue_interrupt_trap(arch: RiscvArch) -> proc_macro2::TokenStream {
let width = match arch {
RiscvArch::Rv32 => 4,
RiscvArch::Rv64 => 8,
};
let store = store_trap(arch, |reg| reg != "a0");
let load = load_trap(arch);

#[cfg(feature = "s-mode")]
let ret = "sret";
Expand All @@ -439,16 +546,15 @@ mod v_trap {
"
core::arch::global_asm!(
\".section .trap, \\\"ax\\\"
.align {width}
.global _start_{function}_trap
_start_{function}_trap:
addi sp, sp, - {TRAP_SIZE} * {width}
{store}
call {function}
{load}
addi sp, sp, {TRAP_SIZE} * {width}
{ret}\"
);"
.align {width} // TODO is this necessary?
.global _continue_interrupt_trap
_continue_interrupt_trap:
{store} // store trap partially (all registers except a0)
jalr ra, a0, 0 // jump to corresponding interrupt handler (address stored in a0)
{load} // restore trap frame
addi sp, sp, {TRAP_SIZE} * {width} // deallocate space for trap frame
{ret} // return from interrupt
\");"
);

instructions.parse().unwrap()
Expand Down
61 changes: 3 additions & 58 deletions riscv-rt/src/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,65 +277,10 @@ _pre_init_trap:
j _pre_init_trap",
);

/// Trap entry point (_start_trap). It saves caller saved registers, calls
/// _start_trap_rust, restores caller saved registers and then returns.
///
/// # Usage
///
/// The macro takes 5 arguments:
/// - `$STORE`: the instruction used to store a register in the stack (e.g. `sd` for riscv64)
/// - `$LOAD`: the instruction used to load a register from the stack (e.g. `ld` for riscv64)
/// - `$BYTES`: the number of bytes used to store a register (e.g. 8 for riscv64)
/// - `$TRAP_SIZE`: the number of registers to store in the stack (e.g. 32 for all the user registers)
/// - list of tuples of the form `($REG, $LOCATION)`, where:
/// - `$REG`: the register to store/load
/// - `$LOCATION`: the location in the stack where to store/load the register
#[rustfmt::skip]
macro_rules! trap_handler {
($STORE:ident, $LOAD:ident, $BYTES:literal, $TRAP_SIZE:literal, [$(($REG:ident, $LOCATION:literal)),*]) => {
// ensure we do not break that sp is 16-byte aligned
const _: () = assert!(($TRAP_SIZE * $BYTES) % 16 == 0);
global_asm!(
"
.section .trap, \"ax\"
.weak _start_trap
_start_trap:",
// save space for trap handler in stack
concat!("addi sp, sp, -", stringify!($TRAP_SIZE * $BYTES)),
// save registers in the desired order
$(concat!(stringify!($STORE), " ", stringify!($REG), ", ", stringify!($LOCATION * $BYTES), "(sp)"),)*
// call rust trap handler
"add a0, sp, zero
jal ra, _start_trap_rust",
// restore registers in the desired order
$(concat!(stringify!($LOAD), " ", stringify!($REG), ", ", stringify!($LOCATION * $BYTES), "(sp)"),)*
// free stack
concat!("addi sp, sp, ", stringify!($TRAP_SIZE * $BYTES)),
);
cfg_global_asm!(
// return from trap
#[cfg(feature = "s-mode")]
"sret",
#[cfg(not(feature = "s-mode"))]
"mret",
);
};
}

#[rustfmt::skip]
#[cfg(riscv32)]
trap_handler!(
sw, lw, 4, 16,
[(ra, 0), (t0, 1), (t1, 2), (t2, 3), (t3, 4), (t4, 5), (t5, 6), (t6, 7),
(a0, 8), (a1, 9), (a2, 10), (a3, 11), (a4, 12), (a5, 13), (a6, 14), (a7, 15)]
);
#[rustfmt::skip]
riscv_rt_macros::weak_start_trap_riscv32!();
#[cfg(riscv64)]
trap_handler!(
sd, ld, 8, 16,
[(ra, 0), (t0, 1), (t1, 2), (t2, 3), (t3, 4), (t4, 5), (t5, 6), (t6, 7),
(a0, 8), (a1, 9), (a2, 10), (a3, 11), (a4, 12), (a5, 13), (a6, 14), (a7, 15)]
);
riscv_rt_macros::weak_start_trap_riscv64!();

#[cfg(feature = "v-trap")]
cfg_global_asm!(
Expand All @@ -345,7 +290,7 @@ cfg_global_asm!(
.type _vector_table, @function
.option push
.balign 0x100 // TODO check if this is the correct alignment
.balign 0x4 // TODO check if this is the correct alignment
.option norelax
.option norvc
Expand Down

0 comments on commit e3aa1c2

Please sign in to comment.