diff --git a/lib/bap_disasm/bap_disasm_rec.ml b/lib/bap_disasm/bap_disasm_rec.ml index 1855b297a..f54b6a0e2 100644 --- a/lib/bap_disasm/bap_disasm_rec.ml +++ b/lib/bap_disasm/bap_disasm_rec.ml @@ -123,9 +123,9 @@ type stage1 = { type stage2 = { stage1 : stage1; - addrs : mem Addrs.t; (* table of blocks *) - succs : dests Addrs.t; - preds : addr list Addrs.t; + addrs : mem Addrs.t; (* table of blocks *) + succs : dests Addrs.t; + preds : addr list Addrs.t; disasm : mem -> decoded list; } @@ -161,13 +161,14 @@ let ok_nil = function | Ok xs -> xs | Error _ -> [] -let is_jump s mem insn = +let is_terminator s mem insn = Dis.Insn.is insn `May_affect_control_flow || - has_jump (ok_nil (s.lift mem insn)) + has_jump (ok_nil (s.lift mem insn)) || + Set.mem s.inits (Addr.succ (Memory.max_addr mem)) let update s mem insn dests : stage1 = let s = { s with visited = Visited.add_insn s.visited mem } in - if is_jump s mem insn then + if is_terminator s mem insn then let () = update_dests s mem dests in let roots = List.(filter_map ~f:fst dests |> rev_append s.roots) in { s with roots } diff --git a/lib/bap_disasm/bap_disasm_reconstructor.ml b/lib/bap_disasm/bap_disasm_reconstructor.ml index 94413d6bc..3294b7f2d 100644 --- a/lib/bap_disasm/bap_disasm_reconstructor.ml +++ b/lib/bap_disasm/bap_disasm_reconstructor.ml @@ -71,43 +71,78 @@ let is_unresolved blk cfg = deg = 0 || (deg = 1 && is_fall (Seq.hd_exn (Cfg.Node.outputs blk cfg))) -let add_callnames syms name cfg blk = - if is_call blk then +let add_call symtab blk name label = + Symtab.add_call symtab blk name label + +let add_unresolved syms name cfg blk = + if is_unresolved blk cfg then let call_addr = terminator_addr blk in - if is_unresolved blk cfg then - Symtab.add_call_name syms blk (name call_addr) - else - Seq.fold ~init:syms (Cfg.Node.outputs blk cfg) - ~f:(fun syms e -> - if is_fall e then syms - else - Cfg.Edge.dst e |> Block.addr |> name |> - Symtab.add_call_name syms blk) + add_call syms blk (name call_addr) `Fall else syms let collect name cfg roots = Seq.fold (Cfg.nodes cfg) ~init:(Block.Set.empty, Symtab.empty) - ~f:(fun (entries,syms) blk -> - Set.union entries (entries_of_block cfg roots blk), - add_callnames syms name cfg blk) - -let reconstruct name roots prog = - let roots = Addr.Set.of_list roots in - let entries,syms = collect name prog roots in - let is_call e = Set.mem entries (Cfg.Edge.dst e) in - let rec add cfg node = - let cfg = Cfg.Node.insert node cfg in - Seq.fold (Cfg.Node.outputs node prog) ~init:cfg - ~f:(fun cfg edge -> - if is_call edge then cfg + ~f:(fun (entries, syms) blk -> + let entries' = entries_of_block cfg roots blk in + Set.union entries entries', add_unresolved syms name cfg blk) + +let reachable cfg from = + let rec loop nodes node = + Seq.fold (Cfg.Node.outputs node cfg) + ~init:(Set.add nodes node) + ~f:(fun nodes edge -> + if Set.mem nodes ( Cfg.Edge.dst edge) then nodes + else loop nodes (Cfg.Edge.dst edge)) in + loop Block.Set.empty from + +let sub roots prog start = + let is_call e = Set.mem roots (Cfg.Edge.dst e) in + let update_inputs node init = + Seq.fold ~init (Cfg.Node.inputs node prog) ~f:Set.add in + let rec loop cfg inputs node = + Seq.fold (Cfg.Node.outputs node prog) + ~init:(cfg, update_inputs node inputs) + ~f:(fun (cfg, inputs) edge -> + if is_call edge then cfg,inputs else let cfg' = Cfg.Edge.insert edge cfg in - if Cfg.Node.mem (Cfg.Edge.dst edge) cfg then cfg' - else add cfg' (Cfg.Edge.dst edge)) in - Set.fold entries ~init:syms ~f:(fun syms entry -> - let name = name (Block.addr entry) in - let cfg = add Cfg.empty entry in - Symtab.add_symbol syms (name,entry,cfg)) + if Cfg.Node.mem (Cfg.Edge.dst edge) cfg then cfg',inputs + else loop cfg' inputs (Cfg.Edge.dst edge)) in + let cfg = Cfg.Node.insert start Cfg.empty in + loop cfg Cfg.Edge.Set.empty start + +let edges_of_seq s = Seq.fold s ~init:Cfg.Edge.Set.empty ~f:Set.add + +let reconstruct name initial_roots prog = + let (--) = Set.diff in + let update_symtab syms cfg entry inputs = + let name = name (Block.addr entry) in + let syms = Symtab.add_symbol syms (name,entry,cfg) in + Set.fold inputs ~init:syms ~f:(fun syms e -> + add_call syms (Cfg.Edge.src e) name (Cfg.Edge.label e)) in + let remove_node cfg n = Cfg.Node.remove n cfg in + let remove_reachable cfg from = + let reachable = reachable cfg from in + let cfg = Set.fold reachable ~init:cfg ~f:remove_node in + remove_node cfg from in + let collect_destinations edges = + Set.fold edges ~init:Block.Set.empty ~f:(fun bs e -> + Set.add bs (Cfg.Edge.dst e)) in + let rec loop known_roots syms = function + | [] -> syms,known_roots + | root :: roots -> + let self_inputs = edges_of_seq (Cfg.Node.inputs root prog) in + let cfg, inputs = sub known_roots prog root in + let edges = edges_of_seq (Cfg.edges cfg) in + let calls = collect_destinations (inputs -- edges -- self_inputs) in + let cfg = Set.fold calls ~init:cfg ~f:remove_reachable in + let known = Set.union known_roots calls in + let syms = update_symtab syms cfg root self_inputs in + let syms,known = loop known syms (Set.to_list calls) in + loop known syms roots in + let initial_roots = Addr.Set.of_list initial_roots in + let roots,syms = collect name prog initial_roots in + fst @@ loop roots syms (Set.to_list roots) let of_blocks syms = let reconstruct (cfg : cfg) = diff --git a/lib/bap_disasm/bap_disasm_symtab.ml b/lib/bap_disasm/bap_disasm_symtab.ml index 96bf479eb..34c97423c 100644 --- a/lib/bap_disasm/bap_disasm_symtab.ml +++ b/lib/bap_disasm/bap_disasm_symtab.ml @@ -12,6 +12,7 @@ module Insn = Bap_disasm_insn type block = Block.t [@@deriving compare, sexp_of] +type edge = Block.edge [@@deriving compare, sexp_of] type cfg = Cfg.t [@@deriving compare] @@ -29,11 +30,10 @@ type t = { addrs : fn Addr.Map.t; names : fn String.Map.t; memory : fn Memmap.t; - callnames : string Addr.Map.t; + callees : (string * edge) list Addr.Map.t; } [@@deriving sexp_of] - let compare t1 t2 = Addr.Map.compare Fn.compare t1.addrs t2.addrs @@ -47,7 +47,7 @@ let empty = { addrs = Addr.Map.empty; names = String.Map.empty; memory = Memmap.empty; - callnames = Addr.Map.empty; + callees = Addr.Map.empty; } let merge m1 m2 = @@ -58,8 +58,9 @@ let filter_mem mem name entry = Memmap.filter mem ~f:(fun (n,e,_) -> not(String.(name = n) || Block.(entry = e))) -let filter_callnames name = - Map.filter ~f:( fun name' -> String.(name <> name')) +let filter_callees name callees = + Map.map callees + ~f:(List.filter ~f:(fun (name',_) -> String.(name <> name'))) let remove t (name,entry,_) : t = if Map.mem t.addrs (Block.addr entry) then @@ -67,7 +68,7 @@ let remove t (name,entry,_) : t = names = Map.remove t.names name; addrs = Map.remove t.addrs (Block.addr entry); memory = filter_mem t.memory name entry; - callnames = filter_callnames name t.callnames + callees = filter_callees name t.callees; } else t @@ -96,7 +97,10 @@ let name_of_fn = fst let entry_of_fn = snd let span fn = span fn |> Memmap.map ~f:(fun _ -> ()) -let add_call_name t b name = - { t with callnames = Map.set t.callnames (Block.addr b) name } +let add_call t b name edge = + {t with callees = Map.add_multi t.callees (Block.addr b) (name,edge)} -let find_call_name t addr = Map.find t.callnames addr +let enum_calls t addr = + match Map.find t.callees addr with + | None -> [] + | Some callees -> callees diff --git a/lib/bap_disasm/bap_disasm_symtab.mli b/lib/bap_disasm/bap_disasm_symtab.mli index 48dfbdb3c..9fc9b4d5b 100644 --- a/lib/bap_disasm/bap_disasm_symtab.mli +++ b/lib/bap_disasm/bap_disasm_symtab.mli @@ -3,6 +3,7 @@ open Bap_types.Std open Image_internal_std type block = Bap_disasm_block.t +type edge = Bap_disasm_block.edge type cfg = Bap_disasm_rec.Cfg.t type t [@@deriving compare, sexp_of] @@ -21,8 +22,10 @@ val intersecting : t -> mem -> fn list val to_sequence : t -> fn seq val span : fn -> unit memmap -(* remembers a call to a function from the given block *) -val add_call_name : t -> block -> string -> t +(** [add_call symtab block name edge] remembers a call to a function + [name] from the given block with [edge] *) +val add_call : t -> block -> string -> edge -> t -(* finds if there are any calls from the given block *) -val find_call_name : t -> addr -> string option +(** [enum_calls t addr] returns a list of calls from a block with + the given [addr] *) +val enum_calls : t -> addr -> (string * edge) list diff --git a/lib/bap_sema/bap_sema_lift.ml b/lib/bap_sema/bap_sema_lift.ml index 5ac2820bd..72e58aca7 100644 --- a/lib/bap_sema/bap_sema_lift.ml +++ b/lib/bap_sema/bap_sema_lift.ml @@ -127,6 +127,7 @@ let linear_of_stmt ?addr return insn stmt : linear list = Label finish :: [] in linearize stmt + let lift_insn ?addr fall init insn = List.fold (Insn.bil insn) ~init ~f:(fun init stmt -> List.fold (linear_of_stmt ?addr fall insn stmt) ~init @@ -152,8 +153,39 @@ let is_conditional_jump jmp = Insn.(may affect_control_flow) jmp && has_jump_under_condition (Insn.bil jmp) -let blk cfg block : blk term list = - let fall_label = label_of_fall cfg block in +let has_called block addr = + let finder = + object inherit [unit] Stmt.finder + method! enter_jmp e r = + match e with + | Bil.Int a when Addr.(a = addr) -> r.return (Some ()) + | _ -> r + end in + Bil.exists finder (Insn.bil (Block.terminator block)) + +let fall_of_symtab symtab block = + Option.( + symtab >>= fun symtab -> + match Symtab.enum_calls symtab (Block.addr block) with + | [] -> None + | calls -> + List.find_map calls + ~f:(fun (n,e) -> Option.some_if (e = `Fall) n) >>= fun name -> + Symtab.find_by_name symtab name >>= fun (_,entry,_) -> + Option.some_if Block.(block <> entry) entry >>= fun callee -> + let addr = Block.addr callee in + Option.some_if (not (has_called block addr)) () >>= fun () -> + let bldr = Ir_blk.Builder.create () in + let call = Call.create ~target:(Label.indirect Bil.(int addr)) () in + let () = Ir_blk.Builder.add_jmp bldr (Ir_jmp.create_call call) in + Some (Ir_blk.Builder.result bldr)) + +let blk ?symtab cfg block : blk term list = + let fall_to_fn = fall_of_symtab symtab block in + let fall_label = + match label_of_fall cfg block, fall_to_fn with + | None, Some b -> Some (Label.direct (Term.tid b)) + | fall_label,_ -> fall_label in List.fold (Block.insns block) ~init:([],Ir_blk.Builder.create ()) ~f:(fun init (mem,insn) -> let addr = Memory.min_addr mem in @@ -167,7 +199,10 @@ let blk cfg block : blk term list = | Some dst -> Some (`Jmp (Ir_jmp.create_goto dst)) in Option.iter fall ~f:(Ir_blk.Builder.add_elt b); let b = Ir_blk.Builder.result b in - List.rev (b::bs) |> function + let blocks = match fall_to_fn with + | None -> b :: bs + | Some b' -> b' :: b :: bs in + List.rev blocks |> function | [] -> assert false | b::bs -> Term.set_attr b address (Block.addr block) :: bs @@ -214,14 +249,14 @@ let remove_false_jmps blk = let unbound _ = true -let lift_sub entry cfg = +let lift_sub ?symtab entry cfg = let addrs = Addr.Table.create () in let recons acc b = let addr = Block.addr b in - let blks = blk cfg b in + let blks = blk ?symtab cfg b in Option.iter (List.hd blks) ~f:(fun blk -> Hashtbl.add_exn addrs ~key:addr ~data:(Term.tid blk)); - acc @ blks in + acc @ blks in let blocks = Graphlib.reverse_postorder_traverse (module Cfg) ~start:entry cfg in let blks = Seq.fold blocks ~init:[] ~f:recons in @@ -248,12 +283,11 @@ let indirect_target jmp = let is_indirect_call jmp = Option.is_some (indirect_target jmp) let with_address t ~f ~default = - match Term.get_attr t address with - | None -> default - | Some a -> f a + Option.value_map ~default ~f (Term.get_attr t address) -let find_call_name symtab blk = - with_address blk ~default:None ~f:(Symtab.find_call_name symtab) +let with_address_opt t ~f ~default = + let g a = Option.value (f a) ~default in + with_address t ~f:g ~default let update_unresolved symtab unresolved exts sub = let iter cls t ~f = Term.to_sequence cls t |> Seq.iter ~f in @@ -261,10 +295,11 @@ let update_unresolved symtab unresolved exts sub = Option.is_some (Symtab.find_by_name symtab name) in let is_known a = Option.is_some (Symtab.find_by_start symtab a) in let is_unknown name = not (symbol_exists name) in - let add_external name = - Hashtbl.update exts name ~f:(function - | None -> create_synthetic name - | Some x -> x) in + let add_external (name,_) = + if is_unknown name then + Hashtbl.update exts name ~f:(function + | None -> create_synthetic name + | Some x -> x) in iter blk_t sub ~f:(fun blk -> iter jmp_t blk ~f:(fun jmp -> match indirect_target jmp with @@ -273,24 +308,24 @@ let update_unresolved symtab unresolved exts sub = | _ -> with_address blk ~default:() ~f:(fun addr -> Hash_set.add unresolved addr; - match Symtab.find_call_name symtab addr with - | Some name when is_unknown name -> add_external name - | _ -> ()))) + Symtab.enum_calls symtab addr |> + List.iter ~f:add_external))) let resolve_indirect symtab exts blk jmp = let update_target tar = + Option.some @@ match Ir_jmp.kind jmp with | Call c -> Ir_jmp.with_kind jmp (Call (Call.with_target c tar)) | _ -> jmp in - match find_call_name symtab blk with - | None -> jmp - | Some name -> + let resolve_name (name,_) = match Symtab.find_by_name symtab name with | Some (_,b,_) -> update_target (Indirect (Int (Block.addr b))) - | None -> - match Hashtbl.find exts name with + | _ -> match Hashtbl.find exts name with | Some s -> update_target (Direct (Term.tid s)) - | None -> jmp + | None -> None in + with_address_opt blk ~default:jmp ~f:(fun addr -> + Symtab.enum_calls symtab addr |> + List.find_map ~f:resolve_name) let program symtab = let b = Ir_program.Builder.create () in @@ -299,7 +334,7 @@ let program symtab = let unresolved = Addr.Hash_set.create () in Seq.iter (Symtab.to_sequence symtab) ~f:(fun (name,entry,cfg) -> let addr = Block.addr entry in - let sub = lift_sub entry cfg in + let sub = lift_sub ~symtab entry cfg in Ir_program.Builder.add_sub b (Ir_sub.with_name sub name); Tid.set_name (Term.tid sub) name; Hashtbl.add_exn addrs ~key:addr ~data:(Term.tid sub); @@ -318,7 +353,8 @@ let program symtab = else j in resolve_jmp ~local:false addrs j))) -let sub = lift_sub +let sub = lift_sub ?symtab:None +let blk = blk ?symtab:None let insn insn = lift_insn None ([], Ir_blk.Builder.create ()) insn |>