Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extern functions #96

Merged
merged 12 commits into from
Mar 7, 2020
Merged
6 changes: 4 additions & 2 deletions crates/mun/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::rc::Rc;
use std::time::Duration;

use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
use mun_compiler::{host_triple, Config, PathOrInline, Target};
use mun_compiler::{Config, PathOrInline, Target};
use mun_runtime::{invoke_fn, ReturnTypeReflection, Runtime, RuntimeBuilder};

fn main() -> Result<(), failure::Error> {
Expand Down Expand Up @@ -139,7 +139,9 @@ fn compiler_options(matches: &ArgMatches) -> Result<mun_compiler::CompilerOption
Ok(mun_compiler::CompilerOptions {
input: PathOrInline::Path(matches.value_of("INPUT").unwrap().into()), // Safe because its a required arg
config: Config {
target: Target::search(matches.value_of("target").unwrap_or_else(|| host_triple()))?,
target: matches
.value_of("target")
.map_or_else(Target::host_target, Target::search)?,
optimization_lvl,
out_dir: None,
},
Expand Down
7 changes: 6 additions & 1 deletion crates/mun_codegen/src/code_gen/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,12 @@ fn gen_struct_info<D: IrDatabase>(
let (field_offsets, _) = gen_u16_array(module, field_offsets);

let field_sizes = fields.iter().map(|field| {
target_data.get_store_size(&db.type_ir(field.ty(db), CodeGenParams { is_extern: false }))
target_data.get_store_size(&db.type_ir(
field.ty(db),
CodeGenParams {
make_marshallable: false,
},
))
});
let (field_sizes, _) = gen_u16_array(module, field_sizes);

Expand Down
9 changes: 7 additions & 2 deletions crates/mun_codegen/src/ir/adt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructTyp
.iter()
.map(|field| {
let field_type = field.ty(db);
try_convert_any_to_basic(db.type_ir(field_type, CodeGenParams { is_extern: false }))
.expect("could not convert field type")
try_convert_any_to_basic(db.type_ir(
field_type,
CodeGenParams {
make_marshallable: false,
},
))
.expect("could not convert field type")
})
.collect();

Expand Down
51 changes: 39 additions & 12 deletions crates/mun_codegen/src/ir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Pat::Bind { name } => {
let builder = self.new_alloca_builder();
let pat_ty = self.infer[pat].clone();
let ty = try_convert_any_to_basic(
self.db
.type_ir(pat_ty.clone(), CodeGenParams { is_extern: false }),
)
let ty = try_convert_any_to_basic(self.db.type_ir(
pat_ty.clone(),
CodeGenParams {
make_marshallable: false,
},
))
.expect("expected basic type");
let ptr = builder.build_alloca(ty, &name.to_string());
self.pat_to_local.insert(pat, ptr);
Expand Down Expand Up @@ -470,7 +472,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
self.builder.build_load(value.into_pointer_value(), "deref")
}
hir::StructMemoryKind::Value => {
if self.params.is_extern {
if self.params.make_marshallable {
self.builder.build_load(value.into_pointer_value(), "deref")
} else {
value
Expand Down Expand Up @@ -513,8 +515,8 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
let lhs_type = self.infer[lhs].clone();
let rhs_type = self.infer[rhs].clone();
match lhs_type.as_simple() {
Some(TypeCtor::Float) => self.gen_binary_op_float(lhs, rhs, op),
Some(TypeCtor::Int) => self.gen_binary_op_int(lhs, rhs, op),
Some(TypeCtor::Float(_ty)) => self.gen_binary_op_float(lhs, rhs, op),
Some(TypeCtor::Int(ty)) => self.gen_binary_op_int(lhs, rhs, op, ty.signedness),
_ => unimplemented!(
"unimplemented operation {0}op{1}",
lhs_type.display(self.db),
Expand Down Expand Up @@ -588,6 +590,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
lhs_expr: ExprId,
rhs_expr: ExprId,
op: BinaryOp,
signedness: hir::Signedness,
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
Expand All @@ -608,19 +611,43 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
CmpOp::Ord {
ordering: Ordering::Less,
strict: false,
} => ("lesseq", IntPredicate::SLE),
} => (
"lesseq",
match signedness {
hir::Signedness::Signed => IntPredicate::SLE,
hir::Signedness::Unsigned => IntPredicate::ULE,
},
),
CmpOp::Ord {
ordering: Ordering::Less,
strict: true,
} => ("less", IntPredicate::SLT),
} => (
"less",
match signedness {
hir::Signedness::Signed => IntPredicate::SLT,
hir::Signedness::Unsigned => IntPredicate::ULT,
},
),
CmpOp::Ord {
ordering: Ordering::Greater,
strict: false,
} => ("greatereq", IntPredicate::SGE),
} => (
"greatereq",
match signedness {
hir::Signedness::Signed => IntPredicate::SGE,
hir::Signedness::Unsigned => IntPredicate::UGE,
},
),
CmpOp::Ord {
ordering: Ordering::Greater,
strict: true,
} => ("greater", IntPredicate::SGT),
} => (
"greater",
match signedness {
hir::Signedness::Signed => IntPredicate::SGT,
hir::Signedness::Unsigned => IntPredicate::UGT,
},
),
};
Some(
self.builder
Expand Down Expand Up @@ -683,7 +710,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {

fn should_use_dispatch_table(&self) -> bool {
// FIXME: When we use the dispatch table, generated wrappers have infinite recursion
!self.params.is_extern
!self.params.make_marshallable
}

/// Generates IR for a function call.
Expand Down
7 changes: 6 additions & 1 deletion crates/mun_codegen/src/ir/dispatch_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,12 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> {
let sig = hir_type.callable_sig(self.db).unwrap();
let ir_type = self
.db
.type_ir(hir_type, CodeGenParams { is_extern: false })
.type_ir(
hir_type,
CodeGenParams {
make_marshallable: false,
},
)
.into_function_type();
let arg_types = sig
.params()
Expand Down
8 changes: 6 additions & 2 deletions crates/mun_codegen/src/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ pub(crate) fn gen_body<'a, 'b, D: IrDatabase>(
llvm_function,
llvm_functions,
dispatch_table,
CodeGenParams { is_extern: false },
CodeGenParams {
make_marshallable: false,
},
);

code_gen.gen_fn_body();
Expand All @@ -81,7 +83,9 @@ pub(crate) fn gen_wrapper_body<'a, 'b, D: IrDatabase>(
llvm_function,
llvm_functions,
dispatch_table,
CodeGenParams { is_extern: true },
CodeGenParams {
make_marshallable: true,
},
);

code_gen.gen_fn_wrapper();
Expand Down
10 changes: 7 additions & 3 deletions crates/mun_codegen/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {
// TODO: Remove once we have more ModuleDef variants
#[allow(clippy::single_match)]
match def {
ModuleDef::Function(f) => {
ModuleDef::Function(f) if !f.is_extern(db) => {
// Collect argument types
let fn_sig = f.ty(db).callable_sig(db).unwrap();
for ty in fn_sig.params().iter() {
Expand All @@ -70,7 +70,9 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {
db,
*f,
&llvm_module,
CodeGenParams { is_extern: false },
CodeGenParams {
make_marshallable: false,
},
);
functions.insert(*f, fun);

Expand All @@ -84,7 +86,9 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {
db,
*f,
&llvm_module,
CodeGenParams { is_extern: true },
CodeGenParams {
make_marshallable: true,
},
);
wrappers.insert(*f, wrapper_fun);

Expand Down
71 changes: 64 additions & 7 deletions crates/mun_codegen/src/ir/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,37 @@ use crate::{
type_info::{TypeGroup, TypeInfo},
CodeGenParams, IrDatabase,
};
use hir::{ApplicationTy, CallableDef, Ty, TypeCtor};
use hir::{ApplicationTy, CallableDef, FloatBitness, FloatTy, IntBitness, IntTy, Ty, TypeCtor};
use inkwell::types::{AnyTypeEnum, BasicType, BasicTypeEnum, StructType};
use inkwell::AddressSpace;
use mun_target::spec::Target;

/// Given a mun type, construct an LLVM IR type
#[rustfmt::skip]
pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty, params: CodeGenParams) -> AnyTypeEnum {
let context = db.context();
match ty {
Ty::Empty => AnyTypeEnum::StructType(context.struct_type(&[], false)),
Ty::Apply(ApplicationTy { ctor, .. }) => match ctor {
TypeCtor::Float => AnyTypeEnum::FloatType(context.f64_type()),
TypeCtor::Int => AnyTypeEnum::IntType(context.i64_type()),
// Float primitives
TypeCtor::Float(fty) => match fty.resolve(&db.target()).bitness {
FloatBitness::X64 => AnyTypeEnum::FloatType(context.f64_type()),
FloatBitness::X32 => AnyTypeEnum::FloatType(context.f32_type()),
_ => unreachable!()
}

// Int primitives
TypeCtor::Int(ity) => match ity.resolve(&db.target()).bitness {
IntBitness::X64 => AnyTypeEnum::IntType(context.i64_type()),
IntBitness::X32 => AnyTypeEnum::IntType(context.i32_type()),
IntBitness::X16 => AnyTypeEnum::IntType(context.i16_type()),
IntBitness::X8 => AnyTypeEnum::IntType(context.i8_type()),
_ => unreachable!()
}

// Boolean
TypeCtor::Bool => AnyTypeEnum::IntType(context.bool_type()),

TypeCtor::FnDef(def @ CallableDef::Function(_)) => {
let ty = db.callable_sig(def);
let param_tys: Vec<BasicTypeEnum> = ty
Expand All @@ -40,7 +58,7 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty, params: CodeGenParams) -> A
match s.data(db).memory_kind {
hir::StructMemoryKind::GC => struct_ty.ptr_type(AddressSpace::Generic).into(),
hir::StructMemoryKind::Value => {
if params.is_extern {
if params.make_marshallable {
struct_ty.ptr_type(AddressSpace::Generic).into()
} else {
struct_ty.into()
Expand All @@ -59,7 +77,12 @@ pub fn struct_ty_query(db: &impl IrDatabase, s: hir::Struct) -> StructType {
let name = s.name(db).to_string();
for field in s.fields(db).iter() {
// Ensure that salsa's cached value incorporates the struct fields
let _field_type_ir = db.type_ir(field.ty(db), CodeGenParams { is_extern: false });
let _field_type_ir = db.type_ir(
field.ty(db),
CodeGenParams {
make_marshallable: false,
},
);
}

db.context().opaque_struct_type(&name)
Expand All @@ -69,12 +92,46 @@ pub fn struct_ty_query(db: &impl IrDatabase, s: hir::Struct) -> StructType {
pub fn type_info_query(db: &impl IrDatabase, ty: Ty) -> TypeInfo {
match ty {
Ty::Apply(ctor) => match ctor.ctor {
TypeCtor::Float => TypeInfo::new("core::float", TypeGroup::FundamentalTypes),
TypeCtor::Int => TypeInfo::new("core::int", TypeGroup::FundamentalTypes),
TypeCtor::Float(ty) => TypeInfo::new(
format!("core::{}", ty.resolve(&db.target())),
TypeGroup::FundamentalTypes,
),
TypeCtor::Int(ty) => TypeInfo::new(
format!("core::{}", ty.resolve(&db.target())),
TypeGroup::FundamentalTypes,
),
TypeCtor::Bool => TypeInfo::new("core::bool", TypeGroup::FundamentalTypes),
TypeCtor::Struct(s) => TypeInfo::new(s.name(db).to_string(), TypeGroup::StructTypes(s)),
_ => unreachable!("{:?} unhandled", ctor),
},
_ => unreachable!("{:?} unhandled", ty),
}
}

trait ResolveBitness {
fn resolve(&self, _target: &Target) -> Self;
}

impl ResolveBitness for FloatTy {
fn resolve(&self, _target: &Target) -> Self {
let bitness = match self.bitness {
FloatBitness::Undefined => FloatBitness::X64,
bitness => bitness,
};
FloatTy { bitness }
}
}

impl ResolveBitness for IntTy {
fn resolve(&self, _target: &Target) -> Self {
let bitness = match self.bitness {
IntBitness::Undefined => IntBitness::X64,
IntBitness::Xsize => IntBitness::X64,
bitness => bitness,
};
IntTy {
bitness,
signedness: self.signedness,
}
}
}
2 changes: 1 addition & 1 deletion crates/mun_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ pub use crate::{
pub struct CodeGenParams {
/// Whether generated code should support extern function calls.
/// This allows function parameters with `struct(value)` types to be marshalled.
is_extern: bool,
make_marshallable: bool,
}
18 changes: 18 additions & 0 deletions crates/mun_codegen/src/snapshots/test__extern_fn.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
source: crates/mun_codegen/src/test.rs
expression: "extern fn add(a:int, b:int): int;\nfn main() {\n add(3,4);\n}"
---
; ModuleID = 'main.mun'
source_filename = "main.mun"

%DispatchTable = type { i64 (i64, i64)* }

@dispatchTable = global %DispatchTable zeroinitializer

define void @main() {
body:
%add_ptr = load i64 (i64, i64)*, i64 (i64, i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0)
%add = call i64 %add_ptr(i64 3, i64 4)
ret void
}

Loading