diff --git a/tool/microkit/src/main.rs b/tool/microkit/src/main.rs index 4d12ccebb..a80fbe3db 100644 --- a/tool/microkit/src/main.rs +++ b/tool/microkit/src/main.rs @@ -1569,8 +1569,9 @@ fn build_system( let pd_endpoint_names: Vec = system .protection_domains .iter() - .filter(|pd| pd.needs_ep()) - .map(|pd| format!("EP: PD={}", pd.name)) + .enumerate() + .filter(|(idx, pd)| pd.needs_ep(*idx, &system.channels)) + .map(|(_, pd)| format!("EP: PD={}", pd.name)) .collect(); let endpoint_names = [vec![format!("EP: Monitor Fault")], pd_endpoint_names].concat(); // Reply objects @@ -1593,8 +1594,9 @@ fn build_system( system .protection_domains .iter() - .map(|pd| { - if pd.needs_ep() { + .enumerate() + .map(|(idx, pd)| { + if pd.needs_ep(idx, &system.channels) { let obj = &endpoint_objs[1..][i]; i += 1; Some(obj) @@ -2157,7 +2159,7 @@ fn build_system( // Minting in the address space for (idx, pd) in system.protection_domains.iter().enumerate() { - let obj = if pd.needs_ep() { + let obj = if pd.needs_ep(idx, &system.channels) { pd_endpoint_objs[idx].unwrap() } else { ¬ification_objs[idx] @@ -2344,89 +2346,55 @@ fn build_system( } for cc in &system.channels { - let pd_a = &system.protection_domains[cc.pd_a]; - let pd_b = &system.protection_domains[cc.pd_b]; - let pd_a_cnode_obj = cnode_objs_by_pd[pd_a]; - let pd_b_cnode_obj = cnode_objs_by_pd[pd_b]; - let pd_a_notification_obj = ¬ification_objs[cc.pd_a]; - let pd_b_notification_obj = ¬ification_objs[cc.pd_b]; - - // Set up the notification caps - let pd_a_cap_idx = BASE_OUTPUT_NOTIFICATION_CAP + cc.id_a; - let pd_a_badge = 1 << cc.id_b; - assert!(pd_a_cap_idx < PD_CAP_SIZE); - system_invocations.push(Invocation::new( - config, - InvocationArgs::CnodeMint { - cnode: pd_a_cnode_obj.cap_addr, - dest_index: pd_a_cap_idx, - dest_depth: PD_CAP_BITS, - src_root: root_cnode_cap, - src_obj: pd_b_notification_obj.cap_addr, - src_depth: config.cap_address_bits, - rights: Rights::All as u64, // FIXME: Check rights - badge: pd_a_badge, - }, - )); + for (send, recv) in [(&cc.end_a, &cc.end_b), (&cc.end_b, &cc.end_a)] { + let send_pd = &system.protection_domains[send.pd]; + let send_cnode_obj = cnode_objs_by_pd[send_pd]; + let recv_notification_obj = ¬ification_objs[recv.pd]; - let pd_b_cap_idx = BASE_OUTPUT_NOTIFICATION_CAP + cc.id_b; - let pd_b_badge = 1 << cc.id_a; - assert!(pd_b_cap_idx < PD_CAP_SIZE); - system_invocations.push(Invocation::new( - config, - InvocationArgs::CnodeMint { - cnode: pd_b_cnode_obj.cap_addr, - dest_index: pd_b_cap_idx, - dest_depth: PD_CAP_BITS, - src_root: root_cnode_cap, - src_obj: pd_a_notification_obj.cap_addr, - src_depth: config.cap_address_bits, - rights: Rights::All as u64, // FIXME: Check rights - badge: pd_b_badge, - }, - )); + if send.can_notify_other { + let send_cap_idx = BASE_OUTPUT_NOTIFICATION_CAP + send.id; + assert!(send_cap_idx < PD_CAP_SIZE); + // receiver sees the sender's badge. + let send_badge = 1 << recv.id; - // Set up the endpoint caps - if pd_b.pp { - let pd_a_cap_idx = BASE_OUTPUT_ENDPOINT_CAP + cc.id_a; - let pd_a_badge = PPC_BADGE | cc.id_b; - let pd_b_endpoint_obj = pd_endpoint_objs[cc.pd_b].unwrap(); - assert!(pd_a_cap_idx < PD_CAP_SIZE); + system_invocations.push(Invocation::new( + config, + InvocationArgs::CnodeMint { + cnode: send_cnode_obj.cap_addr, + dest_index: send_cap_idx, + dest_depth: PD_CAP_BITS, + src_root: root_cnode_cap, + src_obj: recv_notification_obj.cap_addr, + src_depth: config.cap_address_bits, + rights: Rights::All as u64, // FIXME: Check rights + badge: send_badge, + }, + )); + } - system_invocations.push(Invocation::new( - config, - InvocationArgs::CnodeMint { - cnode: pd_a_cnode_obj.cap_addr, - dest_index: pd_a_cap_idx, - dest_depth: PD_CAP_BITS, - src_root: root_cnode_cap, - src_obj: pd_b_endpoint_obj.cap_addr, - src_depth: config.cap_address_bits, - rights: Rights::All as u64, // FIXME: Check rights - badge: pd_a_badge, - }, - )); - } + if send.can_ppcall_other { + let send_cap_idx = BASE_OUTPUT_ENDPOINT_CAP + send.id; + assert!(send_cap_idx < PD_CAP_SIZE); + // receiver sees the sender's badge. + let send_badge = PPC_BADGE | recv.id; - if pd_a.pp { - let pd_b_cap_idx = BASE_OUTPUT_ENDPOINT_CAP + cc.id_b; - let pd_b_badge = PPC_BADGE | cc.id_a; - let pd_a_endpoint_obj = pd_endpoint_objs[cc.pd_a].unwrap(); - assert!(pd_b_cap_idx < PD_CAP_SIZE); + let recv_endpoint_obj = + pd_endpoint_objs[recv.pd].expect("endpoint object to exist"); - system_invocations.push(Invocation::new( - config, - InvocationArgs::CnodeMint { - cnode: pd_b_cnode_obj.cap_addr, - dest_index: pd_b_cap_idx, - dest_depth: PD_CAP_BITS, - src_root: root_cnode_cap, - src_obj: pd_a_endpoint_obj.cap_addr, - src_depth: config.cap_address_bits, - rights: Rights::All as u64, // FIXME: Check rights - badge: pd_b_badge, - }, - )); + system_invocations.push(Invocation::new( + config, + InvocationArgs::CnodeMint { + cnode: send_cnode_obj.cap_addr, + dest_index: send_cap_idx, + dest_depth: PD_CAP_BITS, + src_root: root_cnode_cap, + src_obj: recv_endpoint_obj.cap_addr, + src_depth: config.cap_address_bits, + rights: Rights::All as u64, // FIXME: Check rights + badge: send_badge, + }, + )); + } } } diff --git a/tool/microkit/src/sdf.rs b/tool/microkit/src/sdf.rs index cb260b7d3..567449de8 100644 --- a/tool/microkit/src/sdf.rs +++ b/tool/microkit/src/sdf.rs @@ -127,12 +127,18 @@ pub struct SysSetVar { pub kind: SysSetVarKind, } +#[derive(Debug, Clone)] +pub struct ChannelEnd { + pub pd: usize, + pub id: u64, + pub can_notify_other: bool, + pub can_ppcall_other: bool, +} + #[derive(Debug)] pub struct Channel { - pub pd_a: usize, - pub id_a: u64, - pub pd_b: usize, - pub id_b: u64, + pub end_a: ChannelEnd, + pub end_b: ChannelEnd, } #[derive(Debug, PartialEq, Eq, Hash)] @@ -143,7 +149,6 @@ pub struct ProtectionDomain { pub priority: u8, pub budget: u64, pub period: u64, - pub pp: bool, pub passive: bool, pub stack_size: u64, pub smc: bool, @@ -270,8 +275,13 @@ impl SysMap { } impl ProtectionDomain { - pub fn needs_ep(&self) -> bool { - self.pp || self.has_children || self.virtual_machine.is_some() + pub fn needs_ep(&self, self_id: usize, channels: &[Channel]) -> bool { + self.has_children + || self.virtual_machine.is_some() + || channels.iter().any(|channel| { + (channel.end_a.can_ppcall_other && channel.end_b.pd == self_id) + || (channel.end_b.can_ppcall_other && channel.end_a.pd == self_id) + }) } fn from_xml( @@ -283,7 +293,6 @@ impl ProtectionDomain { let mut attrs = vec![ "name", "priority", - "pp", "budget", "period", "passive", @@ -330,21 +339,6 @@ impl ProtectionDomain { )); } - let pp = if let Some(xml_pp) = node.attribute("pp") { - match str_to_bool(xml_pp) { - Some(val) => val, - None => { - return Err(value_error( - xml_sdf, - node, - "pp must be 'true' or 'false'".to_string(), - )) - } - } - } else { - false - }; - let passive = if let Some(xml_passive) = node.attribute("passive") { match str_to_bool(xml_passive) { Some(val) => val, @@ -590,7 +584,6 @@ impl ProtectionDomain { priority: priority as u8, budget, period, - pp, passive, stack_size, smc, @@ -780,6 +773,71 @@ impl SysMemoryRegion { } } +impl ChannelEnd { + fn from_xml<'a>( + xml_sdf: &'a XmlSystemDescription, + node: &'a roxmltree::Node, + pds: &[ProtectionDomain], + ) -> Result { + let node_name = node.tag_name().name(); + if node_name != "end" { + let pos = xml_sdf.doc.text_pos_at(node.range().start); + return Err(format!( + "Error: invalid XML element '{}': {}", + node_name, + loc_string(xml_sdf, pos) + )); + } + + check_attributes(xml_sdf, &node, &["pd", "id", "pp", "notify"])?; + let end_pd = checked_lookup(xml_sdf, &node, "pd")?; + let end_id = checked_lookup(xml_sdf, &node, "id")? + .parse::() + .unwrap(); + + if end_id > PD_MAX_ID as i64 { + return Err(value_error( + xml_sdf, + &node, + format!("id must be < {}", PD_MAX_ID + 1), + )); + } + + if end_id < 0 { + return Err(value_error(xml_sdf, &node, "id must be >= 0".to_string())); + } + + let can_notify_other = node + .attribute("notify") + .map(str_to_bool) + .unwrap_or(Some(true)) + .ok_or_else(|| { + value_error(xml_sdf, node, "notify must be 'true'/'false'".to_string()) + })?; + + let can_ppcall_other = node + .attribute("pp") + .map(str_to_bool) + .unwrap_or(Some(false)) + .ok_or_else(|| value_error(xml_sdf, node, "pp must be 'true'/'false'".to_string()))?; + + if let Some(pd_idx) = pds.iter().position(|pd| pd.name == end_pd) { + return Ok(ChannelEnd { + pd: pd_idx, + id: end_id.try_into().unwrap(), + can_notify_other, + can_ppcall_other, + }); + } else { + return Err(value_error( + xml_sdf, + &node, + format!("invalid PD name '{end_pd}'"), + )); + } + } +} + impl Channel { /// It should be noted that this function assumes that `pds` is populated /// with all the Protection Domains that could potentially be connected with @@ -791,70 +849,22 @@ impl Channel { ) -> Result { check_attributes(xml_sdf, node, &[])?; - let mut ends: Vec<(usize, u64)> = Vec::new(); - for child in node.children() { - if !child.is_element() { - continue; - } - - let child_name = child.tag_name().name(); - match child_name { - "end" => { - check_attributes(xml_sdf, &child, &["pd", "id"])?; - let end_pd = checked_lookup(xml_sdf, &child, "pd")?; - let end_id = checked_lookup(xml_sdf, &child, "id")? - .parse::() - .unwrap(); - - if end_id > PD_MAX_ID as i64 { - return Err(value_error( - xml_sdf, - &child, - format!("id must be < {}", PD_MAX_ID + 1), - )); - } - - if end_id < 0 { - return Err(value_error(xml_sdf, &child, "id must be >= 0".to_string())); - } - - if let Some(pd_idx) = pds.iter().position(|pd| pd.name == end_pd) { - ends.push((pd_idx, end_id as u64)) - } else { - return Err(value_error( - xml_sdf, - &child, - format!("invalid PD name '{end_pd}'"), - )); - } - } - _ => { - let pos = xml_sdf.doc.text_pos_at(node.range().start); - return Err(format!( - "Error: invalid XML element '{}': {}", - child_name, - loc_string(xml_sdf, pos) - )); - } - } - } - - if ends.len() != 2 { + let [ref end_a, ref end_b] = node + .children() + .filter(|child| child.is_element()) + .map(|node| ChannelEnd::from_xml(xml_sdf, &node, pds)) + .collect::, _>>()?[..] + else { return Err(value_error( xml_sdf, node, "exactly two end elements must be specified".to_string(), )); - } - - let (pd_a, id_a) = ends[0]; - let (pd_b, id_b) = ends[1]; + }; Ok(Channel { - pd_a, - id_a, - pd_b, - id_b, + end_a: end_a.clone(), + end_b: end_b.clone(), }) } } @@ -1166,24 +1176,24 @@ pub fn parse(filename: &str, xml: &str, config: &Config) -> Result