Skip to content

Commit

Permalink
enables functions with multiple entry (#926)
Browse files Browse the repository at this point in the history
* enables functions with multiple entry

* refactoring

* solved through symtab data

* minor bug fix

* wip

* two algorithms for IR without repetitions

* finally end up with the last algorithm

* seems everything works together

* refactoring

* fixed oasis file

* fixed minor bug

* added a check that the fall call hasn't been taken yet

* allow instructions to share bytes

* refactoring

* fixed a minor bug

* bug fixing

* check shared edges

* refactoring

* it works

* few optimizations

* refactored

* save all edges

* renamed function
  • Loading branch information
gitoleg authored Apr 2, 2019
1 parent 7268622 commit e598acc
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 75 deletions.
13 changes: 7 additions & 6 deletions lib/bap_disasm/bap_disasm_rec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ type stage1 = {

type stage2 = {
stage1 : stage1;
addrs : mem Addrs.t; (* table of blocks *)
succs : dests Addrs.t;
preds : addr list Addrs.t;
addrs : mem Addrs.t; (* table of blocks *)
succs : dests Addrs.t;
preds : addr list Addrs.t;
disasm : mem -> decoded list;
}

Expand Down Expand Up @@ -161,13 +161,14 @@ let ok_nil = function
| Ok xs -> xs
| Error _ -> []

let is_jump s mem insn =
let is_terminator s mem insn =
Dis.Insn.is insn `May_affect_control_flow ||
has_jump (ok_nil (s.lift mem insn))
has_jump (ok_nil (s.lift mem insn)) ||
Set.mem s.inits (Addr.succ (Memory.max_addr mem))

let update s mem insn dests : stage1 =
let s = { s with visited = Visited.add_insn s.visited mem } in
if is_jump s mem insn then
if is_terminator s mem insn then
let () = update_dests s mem dests in
let roots = List.(filter_map ~f:fst dests |> rev_append s.roots) in
{ s with roots }
Expand Down
95 changes: 65 additions & 30 deletions lib/bap_disasm/bap_disasm_reconstructor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -71,43 +71,78 @@ let is_unresolved blk cfg =
deg = 0 ||
(deg = 1 && is_fall (Seq.hd_exn (Cfg.Node.outputs blk cfg)))

let add_callnames syms name cfg blk =
if is_call blk then
(* [add_call symtab blk name label] records in [symtab] a call from
   [blk] to the function [name] made through an edge labelled
   [label]. A plain alias of [Symtab.add_call]. *)
let add_call = Symtab.add_call

let add_unresolved syms name cfg blk =
if is_unresolved blk cfg then
let call_addr = terminator_addr blk in
if is_unresolved blk cfg then
Symtab.add_call_name syms blk (name call_addr)
else
Seq.fold ~init:syms (Cfg.Node.outputs blk cfg)
~f:(fun syms e ->
if is_fall e then syms
else
Cfg.Edge.dst e |> Block.addr |> name |>
Symtab.add_call_name syms blk)
add_call syms blk (name call_addr) `Fall
else syms

let collect name cfg roots =
Seq.fold (Cfg.nodes cfg) ~init:(Block.Set.empty, Symtab.empty)
~f:(fun (entries,syms) blk ->
Set.union entries (entries_of_block cfg roots blk),
add_callnames syms name cfg blk)

let reconstruct name roots prog =
let roots = Addr.Set.of_list roots in
let entries,syms = collect name prog roots in
let is_call e = Set.mem entries (Cfg.Edge.dst e) in
let rec add cfg node =
let cfg = Cfg.Node.insert node cfg in
Seq.fold (Cfg.Node.outputs node prog) ~init:cfg
~f:(fun cfg edge ->
if is_call edge then cfg
~f:(fun (entries, syms) blk ->
let entries' = entries_of_block cfg roots blk in
Set.union entries entries', add_unresolved syms name cfg blk)

(* [reachable cfg from] is the set of blocks that can be reached in
   [cfg] by following output edges starting at [from]. The starting
   block [from] is always a member of the result. *)
let reachable cfg from =
  let rec visit seen node =
    (* mark the node before descending, so that cycles terminate *)
    let seen = Set.add seen node in
    Cfg.Node.outputs node cfg |>
    Seq.fold ~init:seen ~f:(fun seen edge ->
        let dst = Cfg.Edge.dst edge in
        if Set.mem seen dst then seen else visit seen dst) in
  visit Block.Set.empty from

let sub roots prog start =
let is_call e = Set.mem roots (Cfg.Edge.dst e) in
let update_inputs node init =
Seq.fold ~init (Cfg.Node.inputs node prog) ~f:Set.add in
let rec loop cfg inputs node =
Seq.fold (Cfg.Node.outputs node prog)
~init:(cfg, update_inputs node inputs)
~f:(fun (cfg, inputs) edge ->
if is_call edge then cfg,inputs
else
let cfg' = Cfg.Edge.insert edge cfg in
if Cfg.Node.mem (Cfg.Edge.dst edge) cfg then cfg'
else add cfg' (Cfg.Edge.dst edge)) in
Set.fold entries ~init:syms ~f:(fun syms entry ->
let name = name (Block.addr entry) in
let cfg = add Cfg.empty entry in
Symtab.add_symbol syms (name,entry,cfg))
if Cfg.Node.mem (Cfg.Edge.dst edge) cfg then cfg',inputs
else loop cfg' inputs (Cfg.Edge.dst edge)) in
let cfg = Cfg.Node.insert start Cfg.empty in
loop cfg Cfg.Edge.Set.empty start

(* [edges_of_seq edges] collects a sequence of edges into a set. *)
let edges_of_seq edges =
  Seq.fold edges ~init:Cfg.Edge.Set.empty
    ~f:(fun acc edge -> Set.add acc edge)

(* [reconstruct name initial_roots prog] builds a symbol table from the
   whole-program graph [prog].

   The set of function entries and the initial symbol table are
   produced by [collect] from [initial_roots]. For every entry a
   subgraph is built with [sub]; destinations of edges that enter that
   subgraph from the outside (other than edges entering the entry
   itself) are treated as entries of their own: everything reachable
   from them is cut out of the subgraph and they are processed
   recursively as new roots.

   [name] maps the address of an entry block to the symbol name. *)
let reconstruct name initial_roots prog =
let (--) = Set.diff in
(* registers the function [(name,entry,cfg)] and records a call to it
   for every edge in [inputs], using the edge source as the caller and
   the edge label as the kind of the call *)
let update_symtab syms cfg entry inputs =
let name = name (Block.addr entry) in
let syms = Symtab.add_symbol syms (name,entry,cfg) in
Set.fold inputs ~init:syms ~f:(fun syms e ->
add_call syms (Cfg.Edge.src e) name (Cfg.Edge.label e)) in
let remove_node cfg n = Cfg.Node.remove n cfg in
(* removes from [cfg] the block [from] and all blocks reachable from it *)
let remove_reachable cfg from =
let reachable = reachable cfg from in
let cfg = Set.fold reachable ~init:cfg ~f:remove_node in
remove_node cfg from in
(* the set of destination blocks of the given edges *)
let collect_destinations edges =
Set.fold edges ~init:Block.Set.empty ~f:(fun bs e ->
Set.add bs (Cfg.Edge.dst e)) in
(* processes a worklist of roots, returns the updated symbol table
   together with the set of all roots known so far *)
let rec loop known_roots syms = function
| [] -> syms,known_roots
| root :: roots ->
(* edges of the whole program that enter the root itself *)
let self_inputs = edges_of_seq (Cfg.Node.inputs root prog) in
let cfg, inputs = sub known_roots prog root in
let edges = edges_of_seq (Cfg.edges cfg) in
(* input edges that are neither edges of the subgraph nor inputs of
   the root: their destinations are additional entries *)
let calls = collect_destinations (inputs -- edges -- self_inputs) in
let cfg = Set.fold calls ~init:cfg ~f:remove_reachable in
let known = Set.union known_roots calls in
let syms = update_symtab syms cfg root self_inputs in
(* newly discovered entries are handled before the remaining roots *)
let syms,known = loop known syms (Set.to_list calls) in
loop known syms roots in
let initial_roots = Addr.Set.of_list initial_roots in
let roots,syms = collect name prog initial_roots in
(* only the symbol table is returned, the set of roots is dropped *)
fst @@ loop roots syms (Set.to_list roots)

let of_blocks syms =
let reconstruct (cfg : cfg) =
Expand Down
22 changes: 13 additions & 9 deletions lib/bap_disasm/bap_disasm_symtab.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module Insn = Bap_disasm_insn


type block = Block.t [@@deriving compare, sexp_of]
type edge = Block.edge [@@deriving compare, sexp_of]
type cfg = Cfg.t [@@deriving compare]


Expand All @@ -29,11 +30,10 @@ type t = {
addrs : fn Addr.Map.t;
names : fn String.Map.t;
memory : fn Memmap.t;
callnames : string Addr.Map.t;
callees : (string * edge) list Addr.Map.t;
} [@@deriving sexp_of]



let compare t1 t2 =
Addr.Map.compare Fn.compare t1.addrs t2.addrs

Expand All @@ -47,7 +47,7 @@ let empty = {
addrs = Addr.Map.empty;
names = String.Map.empty;
memory = Memmap.empty;
callnames = Addr.Map.empty;
callees = Addr.Map.empty;
}

let merge m1 m2 =
Expand All @@ -58,16 +58,17 @@ let filter_mem mem name entry =
Memmap.filter mem ~f:(fun (n,e,_) ->
not(String.(name = n) || Block.(entry = e)))

let filter_callnames name =
Map.filter ~f:( fun name' -> String.(name <> name'))
let filter_callees name callees =
Map.map callees
~f:(List.filter ~f:(fun (name',_) -> String.(name <> name')))

let remove t (name,entry,_) : t =
if Map.mem t.addrs (Block.addr entry) then
{
names = Map.remove t.names name;
addrs = Map.remove t.addrs (Block.addr entry);
memory = filter_mem t.memory name entry;
callnames = filter_callnames name t.callnames
callees = filter_callees name t.callees;
}
else t

Expand Down Expand Up @@ -96,7 +97,10 @@ let name_of_fn = fst
let entry_of_fn = snd
let span fn = span fn |> Memmap.map ~f:(fun _ -> ())

let add_call_name t b name =
{ t with callnames = Map.set t.callnames (Block.addr b) name }
(* [add_call t b name edge] appends the pair [(name,edge)] to the
   calls recorded for the address of block [b]; earlier records for
   the same address are kept. *)
let add_call t b name edge =
  let callees =
    Map.add_multi t.callees ~key:(Block.addr b) ~data:(name,edge) in
  {t with callees}

let find_call_name t addr = Map.find t.callnames addr
(* [enum_calls t addr] is the list of [(name,edge)] pairs recorded
   for [addr], or the empty list if nothing was recorded. *)
let enum_calls t addr = Map.find_multi t.callees addr
11 changes: 7 additions & 4 deletions lib/bap_disasm/bap_disasm_symtab.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ open Bap_types.Std
open Image_internal_std

type block = Bap_disasm_block.t
type edge = Bap_disasm_block.edge
type cfg = Bap_disasm_rec.Cfg.t

type t [@@deriving compare, sexp_of]
Expand All @@ -21,8 +22,10 @@ val intersecting : t -> mem -> fn list
val to_sequence : t -> fn seq
val span : fn -> unit memmap

(* remembers a call to a function from the given block *)
val add_call_name : t -> block -> string -> t
(** [add_call symtab block name edge] remembers a call to a function
[name] from the given [block]; [edge] is the label of the edge
through which the call is made. Several calls may be remembered for
the same block. *)
val add_call : t -> block -> string -> edge -> t

(* finds if there are any calls from the given block *)
val find_call_name : t -> addr -> string option
(** [enum_calls t addr] returns a list of calls from a block with
the given [addr]. Each call is a pair of the callee name and the
edge label that were passed to [add_call]. The list is empty if no
calls were remembered for [addr]. *)
val enum_calls : t -> addr -> (string * edge) list
88 changes: 62 additions & 26 deletions lib/bap_sema/bap_sema_lift.ml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ let linear_of_stmt ?addr return insn stmt : linear list =
Label finish :: [] in
linearize stmt


let lift_insn ?addr fall init insn =
List.fold (Insn.bil insn) ~init ~f:(fun init stmt ->
List.fold (linear_of_stmt ?addr fall insn stmt) ~init
Expand All @@ -152,8 +153,39 @@ let is_conditional_jump jmp =
Insn.(may affect_control_flow) jmp &&
has_jump_under_condition (Insn.bil jmp)

let blk cfg block : blk term list =
let fall_label = label_of_fall cfg block in
(* [has_called block addr] is [true] if the BIL code of the
   terminator of [block] contains a jump to the constant address
   [addr]. *)
let has_called block addr =
  let finder = object
    inherit [unit] Stmt.finder
    method! enter_jmp dst search = match dst with
      | Bil.Int target when Addr.(target = addr) ->
        search.return (Some ())
      | _ -> search
  end in
  Block.terminator block |> Insn.bil |> Bil.exists finder

(* [fall_of_symtab symtab block] builds a synthetic IR block with a
   single call, if [symtab] is given and records a [`Fall] call from
   [block] to a known function.

   The result is [None] when no symbol table is given, when there is
   no such record, when the callee is not found by name, when the
   entry of the callee is [block] itself, or when the terminator of
   [block] already jumps to the callee address (see [has_called]). *)
let fall_of_symtab symtab block =
  (* a fresh block that consists of a call to [addr] only *)
  let call_to addr =
    let builder = Ir_blk.Builder.create () in
    let target = Label.indirect Bil.(int addr) in
    let call = Call.create ~target () in
    Ir_blk.Builder.add_jmp builder (Ir_jmp.create_call call);
    Ir_blk.Builder.result builder in
  Option.(
    symtab >>= fun symtab ->
    Symtab.enum_calls symtab (Block.addr block) |>
    List.find_map ~f:(fun (name,label) ->
        some_if (label = `Fall) name) >>= fun name ->
    Symtab.find_by_name symtab name >>= fun (_,entry,_) ->
    some_if Block.(block <> entry) (Block.addr entry) >>= fun addr ->
    (* the block is created only when it is really needed *)
    if has_called block addr then None
    else Some (call_to addr))

let blk ?symtab cfg block : blk term list =
let fall_to_fn = fall_of_symtab symtab block in
let fall_label =
match label_of_fall cfg block, fall_to_fn with
| None, Some b -> Some (Label.direct (Term.tid b))
| fall_label,_ -> fall_label in
List.fold (Block.insns block) ~init:([],Ir_blk.Builder.create ())
~f:(fun init (mem,insn) ->
let addr = Memory.min_addr mem in
Expand All @@ -167,7 +199,10 @@ let blk cfg block : blk term list =
| Some dst -> Some (`Jmp (Ir_jmp.create_goto dst)) in
Option.iter fall ~f:(Ir_blk.Builder.add_elt b);
let b = Ir_blk.Builder.result b in
List.rev (b::bs) |> function
let blocks = match fall_to_fn with
| None -> b :: bs
| Some b' -> b' :: b :: bs in
List.rev blocks |> function
| [] -> assert false
| b::bs -> Term.set_attr b address (Block.addr block) :: bs

Expand Down Expand Up @@ -214,14 +249,14 @@ let remove_false_jmps blk =

let unbound _ = true

let lift_sub entry cfg =
let lift_sub ?symtab entry cfg =
let addrs = Addr.Table.create () in
let recons acc b =
let addr = Block.addr b in
let blks = blk cfg b in
let blks = blk ?symtab cfg b in
Option.iter (List.hd blks) ~f:(fun blk ->
Hashtbl.add_exn addrs ~key:addr ~data:(Term.tid blk));
acc @ blks in
acc @ blks in
let blocks = Graphlib.reverse_postorder_traverse
(module Cfg) ~start:entry cfg in
let blks = Seq.fold blocks ~init:[] ~f:recons in
Expand All @@ -248,23 +283,23 @@ let indirect_target jmp =
let is_indirect_call jmp = Option.is_some (indirect_target jmp)

let with_address t ~f ~default =
match Term.get_attr t address with
| None -> default
| Some a -> f a
Option.value_map ~default ~f (Term.get_attr t address)

let find_call_name symtab blk =
with_address blk ~default:None ~f:(Symtab.find_call_name symtab)
(* [with_address_opt t ~f ~default] applies [f] to the address
   attribute of [t] and unwraps the result. Evaluates to [default]
   if either [t] has no address or [f] returns [None]. *)
let with_address_opt t ~f ~default =
  with_address t ~default ~f:(fun addr ->
      match f addr with
      | Some x -> x
      | None -> default)

let update_unresolved symtab unresolved exts sub =
let iter cls t ~f = Term.to_sequence cls t |> Seq.iter ~f in
let symbol_exists name =
Option.is_some (Symtab.find_by_name symtab name) in
let is_known a = Option.is_some (Symtab.find_by_start symtab a) in
let is_unknown name = not (symbol_exists name) in
let add_external name =
Hashtbl.update exts name ~f:(function
| None -> create_synthetic name
| Some x -> x) in
let add_external (name,_) =
if is_unknown name then
Hashtbl.update exts name ~f:(function
| None -> create_synthetic name
| Some x -> x) in
iter blk_t sub ~f:(fun blk ->
iter jmp_t blk ~f:(fun jmp ->
match indirect_target jmp with
Expand All @@ -273,24 +308,24 @@ let update_unresolved symtab unresolved exts sub =
| _ ->
with_address blk ~default:() ~f:(fun addr ->
Hash_set.add unresolved addr;
match Symtab.find_call_name symtab addr with
| Some name when is_unknown name -> add_external name
| _ -> ())))
Symtab.enum_calls symtab addr |>
List.iter ~f:add_external)))

let resolve_indirect symtab exts blk jmp =
let update_target tar =
Option.some @@
match Ir_jmp.kind jmp with
| Call c -> Ir_jmp.with_kind jmp (Call (Call.with_target c tar))
| _ -> jmp in
match find_call_name symtab blk with
| None -> jmp
| Some name ->
let resolve_name (name,_) =
match Symtab.find_by_name symtab name with
| Some (_,b,_) -> update_target (Indirect (Int (Block.addr b)))
| None ->
match Hashtbl.find exts name with
| _ -> match Hashtbl.find exts name with
| Some s -> update_target (Direct (Term.tid s))
| None -> jmp
| None -> None in
with_address_opt blk ~default:jmp ~f:(fun addr ->
Symtab.enum_calls symtab addr |>
List.find_map ~f:resolve_name)

let program symtab =
let b = Ir_program.Builder.create () in
Expand All @@ -299,7 +334,7 @@ let program symtab =
let unresolved = Addr.Hash_set.create () in
Seq.iter (Symtab.to_sequence symtab) ~f:(fun (name,entry,cfg) ->
let addr = Block.addr entry in
let sub = lift_sub entry cfg in
let sub = lift_sub ~symtab entry cfg in
Ir_program.Builder.add_sub b (Ir_sub.with_name sub name);
Tid.set_name (Term.tid sub) name;
Hashtbl.add_exn addrs ~key:addr ~data:(Term.tid sub);
Expand All @@ -318,7 +353,8 @@ let program symtab =
else j in
resolve_jmp ~local:false addrs j)))

let sub = lift_sub
(* [sub] and [blk] are [lift_sub] and [blk] with the optional
   [symtab] argument fixed to [None], i.e., lifters that do not
   consult a symbol table. *)
let sub = lift_sub ?symtab:None
let blk = blk ?symtab:None

let insn insn =
lift_insn None ([], Ir_blk.Builder.create ()) insn |>
Expand Down

0 comments on commit e598acc

Please sign in to comment.