diff --git a/.gitignore b/.gitignore index 9cf44ed1b..da2f9c426 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ _oasis /plugins/primus_greedy/.merlin /plugins/primus_lisp/primus_lisp_config.ml /setup.exe +/tools/bap_config diff --git a/.merlin b/.merlin index 9f0b52c30..7a50f72a0 100644 --- a/.merlin +++ b/.merlin @@ -13,6 +13,9 @@ B _build B _build/lib/bap B _build/lib/bap_abi B _build/lib/bap_api +B _build/lib/bitvec +B _build/lib/bitvec_order +B _build/lib/bitvec_sexp B _build/lib/bap_bundle B _build/lib/bap_c B _build/lib/bap_config @@ -24,6 +27,7 @@ B _build/lib/bap_image B _build/lib/bap_plugins B _build/lib/bap_sema B _build/lib/bap_types +B _build/lib/bap_core_theory B _build/lib/graphlib B _build/lib/monads B _build/lib/ogre @@ -31,6 +35,7 @@ B _build/lib/bap_primus B _build/lib/bap_llvm B _build/lib/regular B _build/lib/text_tags +B _build/lib/knowledge S lib/bap S lib/bap_abi diff --git a/lib/arm/arm_env.ml b/lib/arm/arm_env.ml index 19772e804..410edb8c4 100644 --- a/lib/arm/arm_env.ml +++ b/lib/arm/arm_env.ml @@ -15,10 +15,6 @@ let reg32 reg = make_register reg reg32_t let spsr = reg32 `SPSR let cpsr = reg32 `CPSR - -(* Memory definition *) -(* let mem = new_var "mem32" (TMem (Reg 32, Reg 8)) *) - (* Arithmetic flags, individually *) let nf = "NF" %: bool_t let zf = "ZF" %: bool_t @@ -85,4 +81,4 @@ let of_reg : reg -> var = function | #ccr_reg as reg -> var_of_ccr reg let new_var name = Var.create name reg32_t -let mem = Var.create "mem" (mem32_t `r32) +let mem = Var.create "mem" (mem32_t `r8) diff --git a/lib/bap/.merlin b/lib/bap/.merlin index 7b54bdc25..bda004da0 100644 --- a/lib/bap/.merlin +++ b/lib/bap/.merlin @@ -6,5 +6,6 @@ B ../../_build/lib/bap_image B ../../_build/lib/bap_disasm B ../../_build/lib/bap_sema B ../../_build/lib/bap_bundle +B ../../_build/lib/knowledge REC diff --git a/lib/bap/bap.mli b/lib/bap/bap.mli index 57538bc49..e89409db5 100644 --- a/lib/bap/bap.mli +++ b/lib/bap/bap.mli @@ -5,6 +5,8 @@ open Monads.Std open 
Regular.Std open Graphlib.Std open Bap_future.Std +open Bap_knowledge +open Bap_core_theory module Std : sig [@@@warning "-D"] @@ -154,7 +156,7 @@ module Std : sig - hashset is available under [Hash_set] name - sexpable and binable interface; - [to_string], [str], [pp], [ppo], [pps] functions - for pretty-printing. + for pretty-printing. It is a convention, that for each type, there is a module with the same name that implements its interface. For example, type @@ -214,17 +216,17 @@ module Std : sig provides interfaces for the memory objects: - {{!Memory}mem} - a contiguous array of bytes, indexed with - absolute addresses; + absolute addresses; - {{!Table} 'a table} - a mapping from a memory regions to - arbitrary data (no duplicates or intersections); + arbitrary data (no duplicates or intersections); - {{!Memmap}a memmap} - a mapping from memory region to arbitrary data with duplicates and intersections allowed, aka segment tree or interval map; - {{!Image}image} - represents a binary object with all its - symbols, segments, sections and other meta information. + symbols, segments, sections and other meta information. The [Image] module uses the plugin system to load binary objects. In order to add new loader, one should implement the @@ -237,10 +239,10 @@ module Std : sig are provided: - {{!Disasm}Disasm} - a regular interface that hides all - complexities, but may not always be very flexible. + complexities, but may not always be very flexible. - {{!Disasm_expert}Disasm_expert} - an expert interface that - provides access to a low-level representation. It is very - flexible and fast, but harder to use. + provides access to a low-level representation. It is very + flexible and fast, but harder to use. 
To disassemble files or data with the regular interface, use one of the following functions: @@ -487,18 +489,19 @@ module Std : sig By default the memory is annotated with the following attributes: - {{!Image.section}section} -- for regions of memory that had a - particular name in the original binary. For example, in ELF, - sections have names that annotate a corresponding memory - region. If project was created from memory object, then the - overall memory will be marked as a ["bap.user"] section. + particular name in the original binary. For example, in ELF, + sections have names that annotate a corresponding memory + region. If project was created from memory object, then the + overall memory will be marked as a ["bap.user"] section. - {{!Image.segment}segment} -- if the binary data was loaded - from a binary format that contains segments, then the - corresponding memory regions are be marked. Segments provide - access to permission information. *) + from a binary format that contains segments, then the + corresponding memory regions are marked. Segments provide + access to permission information. *) (** {1:api BAP API} *) + (** Abstract integral type. This module describes an interface of an integral arithmetic @@ -1011,7 +1014,7 @@ module Std : sig ] [@@deriving variants] type 'a p = 'a constraint 'a = [< all] - [@@deriving bin_io, compare, sexp] + [@@deriving bin_io, compare, sexp] type t = all p [@@deriving bin_io, compare, sexp] @@ -1185,6 +1188,10 @@ module Std : sig (** {2 Constructors} *) + + (** [create v w] creates a word from bitvector [v] of width [w].*) + val create : Bitvec.t -> int -> t + (** [of_string s] parses a bitvector from a string representation defined in section {!bv_string}. 
*) val of_string : string -> t @@ -1242,6 +1249,9 @@ module Std : sig (** {2 Conversions to OCaml built in integer types } *) + (** [to_bitvec x] returns a Bitvec representation of [x] *) + val to_bitvec : t -> Bitvec.t + (** [to_int x] projects [x] in to OCaml [int]. *) val to_int : t -> int Or_error.t @@ -1438,26 +1448,26 @@ module Std : sig none of the nine preinstantiated suits you. @param prefix defines whether or not a number is prefixed: - - [`auto] (default) - a prefix that corresponds to the chosen + - [`auto] (default) - a prefix that corresponds to the chosen format is printed if it is necessary to disambiguate a number from a decimal representation; - - [`base] - a corresponding prefix is always printed; - - [`none] - the prefix is never printed; - - [`this p] - the user specified prefix [p] is always + - [`base] - a corresponding prefix is always printed; + - [`none] - the prefix is never printed; + - [`this p] - the user specified prefix [p] is always printed; @param suffix defines how the suffix should be printed: - - [`none] (default) - the suffix is never printed; - - [`full] - a full suffix that denotes size and signedness + - [`none] (default) - the suffix is never printed; + - [`full] - a full suffix that denotes size and signedness is printed, e.g., [0xDE:32s] is a signed integer modulo [32]. - - [`size] - only the modulo is printed, e.g., [0xDE:32s] is + - [`size] - only the modulo is printed, e.g., [0xDE:32s] is printed as [0xDE:32] @param format defines the textual representation format: - - [hex] (default) - hexadecimal - - [dec] - decimal - - [oct] - octal - - [bin] - binary (0 and 1). + - [hex] (default) - hexadecimal + - [dec] - decimal + - [oct] - octal + - [bin] - binary (0 and 1). 
@param case defines the case of hexadecimal letters *) @@ -1712,11 +1722,11 @@ module Std : sig Bitvector comes with 4 predefined prefix trees: - [Trie.Big.Bits] - big endian prefix tree, where each - token is a bit, and bitvector is tokenized from msb to lsb. + token is a bit, and bitvector is tokenized from msb to lsb. - [Trie.Big.Byte] - big endian prefix tree, where each token - is a byte, and bitvector is tokenized from most significant - byte to less significant + is a byte, and bitvector is tokenized from most significant + byte to less significant - [Trie.Little.Bits] - is a little endian bit tree. @@ -1747,9 +1757,9 @@ module Std : sig (** Shortcut for bitvectors that represent addresses *) module Addr : sig include module type of Bitvector - with type t = addr - and type endian = endian - and type comparator_witness = Bitvector.comparator_witness + with type t = addr + and type endian = endian + and type comparator_witness = Bitvector.comparator_witness (** [memref ?disp ?index ?scale base] mimics a memory reference syntax in gas assembler, [dis(base,index,scale)] @@ -1855,7 +1865,8 @@ module Std : sig | Concat of exp * exp (** concatenate two words *) and typ = | Imm of int (** [Imm n] - n-bit immediate *) - | Mem of addr_size * size (** [Mem (a,t)] memory with a specified addr_size *) + | Mem of addr_size * size (** [Mem (a,t)] memory with a specified addr_size *) + | Unk [@@deriving bin_io, compare, sexp] type stmt = @@ -1886,6 +1897,10 @@ module Std : sig include Printable.S with type t := t include Data.S with type t := t + val domain : stmt list Knowledge.domain + val persistent : stmt list Knowledge.persistent + val slot : (Theory.Program.Semantics.cls, stmt list) Knowledge.slot + (** [printf "%a" pp_binop op] prints a binary operation [op]. 
*) val pp_binop : binop printer @@ -2470,7 +2485,6 @@ module Std : sig type exp = Bil.exp [@@deriving bin_io, compare, sexp] type stmt = Bil.stmt [@@deriving bin_io, compare, sexp] type unop = Bil.unop [@@deriving bin_io, compare, sexp] - (** The type of a BIL expression. Each BIL expression is either an immediate value of a given @@ -2492,6 +2506,7 @@ module Std : sig type t = Bil.typ = | Imm of int | Mem of addr_size * size + | Unk [@@deriving variants] (** type error *) @@ -2631,6 +2646,10 @@ module Std : sig type t = var + val reify : 'a Theory.var -> t + val ident : t -> Theory.Var.ident + val sort : t -> Theory.Value.Sort.Top.t + (** [create ?register ?fresh name typ] creates a variable with a given [name] and [typ]e. @@ -3137,7 +3156,7 @@ module Std : sig (** BIL {{!Bili}interpreter} @deprecated Use the Primus Framework - *) + *) class ['a] bili : ['a] Bili.t [@@deprecated "[since 2018-03] in favor of the Primus Framework"] @@ -3150,13 +3169,13 @@ module Std : sig following sorts of effects: - coeffects - a value of an expression depends on the outside - world, that is further subdivided by the read effect, when an - expression reads a CPU register, and the load effect, when an - expression an expression accesses the memory. + world, that is further subdivided by the read effect, when an + expression reads a CPU register, and the load effect, when an + expression accesses the memory. - effects - a value modifies the state of the world, by either - storing a value in the memory, or by raising a CPU exception - via the division by zero or accessing the memory. + storing a value in the memory, or by raising a CPU exception + via the division by zero or accessing the memory. 
An expression that doesn't have effects or coeffects is idempotent and can be moved arbitrary in a tree, removed or @@ -3235,6 +3254,8 @@ module Std : sig module Exp : sig type t = Bil.exp + val slot : (Theory.Value.cls, exp) KB.slot + (** All visitors provide some information about the current position of the visitor *) class state : object @@ -3448,38 +3469,38 @@ module Std : sig The following code simplification are applied: - constant folding: if an expression can be computed - statically then it is substituted with the result of - computation, e.g., [1 + 2 -> 3] + statically then it is substituted with the result of + computation, e.g., [1 + 2 -> 3] - neutral element elimination: binary operations with one of - the operands being known to be neutral, are substituted with - the other operand, e.g., [x * 1 -> x] + the operands being known to be neutral, are substituted with + the other operand, e.g., [x * 1 -> x] - zero element propagation: binary operations applied to a - zero element are substituted with the zero element, e.g., - [x * 0 -> 0] + zero element are substituted with the zero element, e.g., + [x * 0 -> 0] - symbolic equality reduction: if both branches of a - comparison are syntactically equal then the comparison is - reduced to a boolean constant, e.g., [a = a -> true], - [a < a -> false]. Note, by default a read from a register is - considered as a (co)effect, hence the above transformations - wouldn't be applied, consider passing [~ignore:[Eff.reads]] - if you want such expressions to be reduced. + comparison are syntactically equal then the comparison is + reduced to a boolean constant, e.g., [a = a -> true], + [a < a -> false]. Note, by default a read from a register is + considered as a (co)effect, hence the above transformations + wouldn't be applied, consider passing [~ignore:[Eff.reads]] + if you want such expressions to be reduced. 
- double complement reduction: an odd amount of complement - operations (one and two) are reduced to one complement of - the same sort, e.g., [~~~1 -> ~1] + operations (one and two) are reduced to one complement of + the same sort, e.g., [~~~1 -> ~1] - binary to unary reduction: reduce a subtraction from zero - to the unary negation, e.g., [0 - x -> -x] + to the unary negation, e.g., [0 - x -> -x] - exclusive disjunction reduction: reduces an exclusive - disjunction of syntactically equal expressions to zero, e.g, - [42 ^ 42 -> 0]. Note, by default a read from a register is - considered as a (co)effect, thus [xor eax eax] is not - reduced, consider passing [~ignore:[Eff.reads]] if you want - such expressions to be reduced. + disjunction of syntactically equal expressions to zero, e.g, + [42 ^ 42 -> 0]. Note, by default a read from a register is + considered as a (co)effect, thus [xor eax eax] is not + reduced, consider passing [~ignore:[Eff.reads]] if you want + such expressions to be reduced. @since 1.3 *) @@ -3720,14 +3741,14 @@ module Std : sig language, where expressions have the following properties: - Memory load expressions can be only applied to a memory. This - effectively disallows creation of temporary memory regions, - and requires all store operations to be committed via the - assignment operation. Also, this provides a guarantee, that - store expressions will not occur in integer assignments, jmp - destinations, and conditional expressions, leaving them valid - only in an assignment statement where the rhs has type mem_t. - This is effectively the same as make the [Load] constructor to - have type ([Load (var,exp,endian,size)]). + effectively disallows creation of temporary memory regions, + and requires all store operations to be committed via the + assignment operation. 
Also, this provides a guarantee, that + store expressions will not occur in integer assignments, jmp + destinations, and conditional expressions, leaving them valid + only in an assignment statement where the rhs has type mem_t. + This is effectively the same as make the [Load] constructor to + have type ([Load (var,exp,endian,size)]). - No load or store expressions in the following positions: 1. the right-hand side of the let expression; @@ -3739,11 +3760,11 @@ module Std : sig puts the following restrictions: - No let expressions - new variables can be created only with - the Move instruction. + the Move instruction. - All memory operations have sizes equal to one byte. Thus the - size and endianness can be ignored in analysis. During the - normalization, the following rewrites are performed + size and endianness can be ignored in analysis. During the + normalization, the following rewrites are performed {v let x = in ... x ... => ... ... x[a,el]:n => x[a+n-1] @ ... @ x[a] @@ -3908,6 +3929,8 @@ module Std : sig (** [endian arch] returns a word endianness of the [arch] *) val endian : t -> endian + val slot : (Theory.program, t option) Knowledge.slot + (** [arch] type implements [Regular] interface *) include Regular.S with type t := t end @@ -4123,9 +4146,15 @@ module Std : sig The returned value of type [T.t tag] is a special key that can be used with [create] and [get] functions to pack and unpack values of type [T.t] into [value]. *) - val register : name:literal -> uuid:literal -> + val register : name:string -> uuid:string -> + (module S with type t = 'a) -> 'a tag + + + val register_slot : (Theory.program,'a option) KB.slot -> (module S with type t = 'a) -> 'a tag + val slot : 'a t -> (Theory.program, 'a option) KB.slot + (** [name cons] returns a name of a constructor. 
*) val name : 'a t -> string @@ -4342,7 +4371,7 @@ module Std : sig type jmp [@@deriving bin_io, compare, sexp] type nil [@@deriving bin_io, compare, sexp] - type tid [@@deriving bin_io, compare, sexp] + type tid = Theory.Label.t [@@deriving bin_io, compare, sexp] type call [@@deriving bin_io, compare, sexp] (** target of control transfer *) @@ -4515,7 +4544,7 @@ module Std : sig (** BIR {{!Biri}interpreter} @deprecated Use the Primus Framework - *) + *) class ['a] biri : ['a] Biri.t [@@deprecated "[since 2018-03] in favor of the Primus Framework"] @@ -4634,7 +4663,7 @@ module Std : sig The [data] may not be copied and the returned memory view may reference the same bigstring object. - *) + *) val create : ?pos:int -> (** defaults to [0] *) ?len:int -> (** defaults to full length *) @@ -4642,6 +4671,7 @@ module Std : sig addr -> Bigstring.t -> t Or_error.t + val slot : (Theory.program, mem option) Knowledge.slot (** [of_file endian start name] creates a memory region from file. Takes data stored in a file with the given [name] and maps it @@ -4686,8 +4716,19 @@ module Std : sig (** returns the order of bytes in a word *) val endian : t -> endian - (** [get word_size mem addr] reads memory value from the specified - address. [word_size] default to [`r8] *) + (** [get ?disp ?index ?scale ?addr mem] reads a [scale] sized word from [mem]. + + Parameters mimic the reference syntax in the gas assembler, + e.g., [dis(base,index,scale)] denotes address at [base + index * scale + dis]. + + The size of the returned word is equal to [scale], bytes are read in + the [endian mem] order. 
+ + + @param disp is the base offset and defaults to [0] + @param index defaults to [0] + @param scale defaults to [`r8] + *) val get : ?disp:int -> ?index:int -> ?scale:size -> ?addr:addr -> t -> word Or_error.t (** [m^n] dereferences a byte at address [n] *) @@ -4931,21 +4972,21 @@ module Std : sig summarized below: - [one_to_many] means that a particular region from table [t1] can - span several memory regions from table [t2]. Example: segments - to symbols relation. + span several memory regions from table [t2]. Example: segments + to symbols relation. - [one_to_one] means that for each value of type ['a] there is - exactly one value of type ['b]. This relation should be used with - caution, since it is quantified over _all_ values of type - ['a]. Indeed, it should be used only for cases, when it can be - guaranteed, that it is impossible to create such value of type - ['b], that has no correspondence in table [t2]. Otherwise, - [one_to_maybe_one] relation should be used. Example: llvm - machine code to assembly string relation. + exactly one value of type ['b]. This relation should be used with + caution, since it is quantified over _all_ values of type + ['a]. Indeed, it should be used only for cases, when it can be + guaranteed, that it is impossible to create such value of type + ['b], that has no correspondence in table [t2]. Otherwise, + [one_to_maybe_one] relation should be used. Example: llvm + machine code to assembly string relation. - [one_to_maybe_one] means that for each value in table [t1] there - exists at most one value in table [t2]. Example: function to - symbol relation. + exists at most one value in table [t2]. Example: function to + symbol relation. {3 Examples} @@ -5114,7 +5155,7 @@ module Std : sig accessible for loading images. 
@deprecated Use new Ogre-powered loader interface - *) + *) module Backend : sig (** memory access permissions *) @@ -5316,9 +5357,9 @@ module Std : sig are possible with the following interpretation: - [Ok None] - a loader doesn't know how handle files of this - type. + type. - [Ok (Some doc)] - a loader was able to obtain some - information from the input. + information from the input. - [Error err] - a file was corrupted, according to the loader. *) @@ -5580,8 +5621,8 @@ module Std : sig type disasm (** values of type [insn] represents machine instructions decoded - from the a given piece of memory *) - type insn [@@deriving bin_io, compare, sexp_of] + from a given piece of memory *) + type insn = Theory.Program.Semantics.t [@@deriving bin_io, compare, sexp] (** [block] is a region of memory that is believed to be a basic block of control flow graph to the best of our knowledge. *) @@ -5789,7 +5830,7 @@ module Std : sig applies function [f] to it. Once [f] is evaluated the disassembler is closed with [close] function. *) val with_disasm : - ?debug_level:int -> ?cpu:string -> backend:string -> string -> + ?debug_level:int -> ?cpu:string -> ?backend:string -> string -> f:((empty, empty) t -> 'a Or_error.t) -> 'a Or_error.t (** [create ?debug_level ?cpu ~backend target] creates a @@ -5801,7 +5842,7 @@ module Std : sig [create ~debug_level:3 ~backend:"llvm" "x86_64" ~f:process] *) - val create : ?debug_level:int -> ?cpu:string -> backend:string -> string -> + val create : ?debug_level:int -> ?cpu:string -> ?backend:string -> string -> (empty, empty) t Or_error.t (** [close d] closes a disassembler [d]. *) @@ -5931,6 +5972,8 @@ module Std : sig type ('a,'k) t = ('a,'k) insn + val slot : (Theory.program, full_insn option) Knowledge.slot + (** [sexp_of_t insn] returns a sexp representation of [insn] *) val sexp_of_t : ('a,'k) t -> Sexp.t @@ -5963,6 +6006,7 @@ module Std : sig (** [ops insn] gives an access to [insn]'s operands. 
*) val ops : ('a,'k) t -> op array + end (** Trie maps over instructions *) @@ -6062,7 +6106,16 @@ module Std : sig *) module Insn : sig - type t = insn [@@deriving bin_io, compare, sexp] + type t = Theory.Program.Semantics.t [@@deriving bin_io, compare, sexp] + + module Slot : sig + type 'a t = (Theory.Program.Semantics.cls, 'a) KB.slot + val name : string t + val asm : string t + val ops : op array option t + val delay : int option t + val dests : Set.M(Theory.Label).t option t + end (** {3 Creating} The following functions will create [insn] instances from a lower @@ -6070,6 +6123,9 @@ module Std : sig *) val of_basic : ?bil:bil -> Disasm_expert.Basic.full_insn -> t + (** [empty] is an instruction with no known semantics *) + val empty : t + (** returns backend specific name of instruction *) val name : t -> string @@ -6124,6 +6180,9 @@ module Std : sig (** instruction is a return from a call *) val return : must property + (** the instruction has no fall-through *) + val barrier : must property + (** the instruction may perform a non-regular control flow *) val affect_control_flow : may property @@ -6189,7 +6248,7 @@ module Std : sig The following invariants must be preserved: - there is no known jump in the program, that points to an - instruction that is not a leader of a basic block; + instruction that is not a leader of a basic block; - any jump instruction is a terminator of some basic block; - each basic block consists of at least one instruction. *) @@ -6421,6 +6480,28 @@ module Std : sig (** Disassembled program. An interface for diassembling things. 
*) module Disasm : sig + + module Driver : sig + type state + type insns + + val init : state + val scan : mem -> state -> state knowledge + + val explore : + ?entry:addr -> + ?follow:(addr -> bool knowledge) -> + block:(mem -> insns -> 'n knowledge) -> + node:('n -> 'c -> 'c knowledge) -> + edge:('n -> 'n -> 'c -> 'c knowledge) -> + init:'c -> + state -> 'c knowledge + + val list_insns : ?rev:bool -> insns -> Theory.Label.t list + val execution_order : insns -> Theory.Label.t list knowledge + end + + type t = disasm (** [create cfg] *) @@ -6650,6 +6731,11 @@ module Std : sig (** [create ()] creates a fresh newly term identifier *) val create : unit -> t + val for_name : string -> t + val for_addr : addr -> t + val for_ivec : int -> t + + (** [set_name tid name] associates a [name] with a given term identifier [tid]. Any previous associations are overridden.*) @@ -6875,6 +6961,7 @@ module Std : sig (** [del_attr term attr] deletes attribute [attr] from [term] *) val del_attr : 'a t -> 'b tag -> 'a t + (** {Predefined attributes} *) (** a term was artificially produced from a term with a given tid. *) @@ -7013,6 +7100,7 @@ module Std : sig ?jmp:(jmp term -> 'a) -> 't term -> 'a + val slot : (Theory.Program.Semantics.cls, blk term list) Knowledge.slot end (** Program in Intermediate representation. 
*) @@ -7054,7 +7142,7 @@ module Std : sig (** fixes the result *) val result : t -> program term end - + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -7184,7 +7272,7 @@ module Std : sig (** returns current result *) val result : t -> sub term end - + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -7389,7 +7477,7 @@ module Std : sig (** returns current result *) val result : t -> blk term end - + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -7404,6 +7492,13 @@ module Std : sig type t = def term + val reify : ?tid:tid -> 'a Theory.var -> 'a Theory.value -> t + + val var : t -> unit Theory.var + + val value : t -> unit Theory.value + + (** [create ?tid x exp] creates definition [x := exp] *) val create : ?tid:tid -> var -> exp -> t @@ -7432,6 +7527,8 @@ module Std : sig for more information. *) val free_vars : t -> Var.Set.t + val pp_slots : string list -> Format.formatter -> t -> unit + include Regular.S with type t := t end @@ -7446,19 +7543,37 @@ module Std : sig Jumps are further subdivided into categories: - goto - is a local control transfer instruction. The label - can be only local to subroutine; + can be only local to subroutine; - call - transfer a control to another subroutine. A call - contains a continuation, i.e., a label to which we're hoping - to return after subroutine returns the control to us. Of - course, called subroutine can in general return to another - position, or not to return at all. + contains a continuation, i.e., a label to which we're hoping + to return after subroutine returns the control to us. Of + course, called subroutine can in general return to another + position, or not to return at all. - ret - performs a return from subroutine - int - calls to interrupt subroutine. If interrupt returns, - then continue with the provided label. 
- *) + then continue with the provided label. + *) type t = jmp term + type dst + + + val reify : ?tid:tid -> + ?cnd:Theory.Bool.t Theory.value -> + ?alt:dst -> ?dst:dst -> unit -> t + + + val guard : t -> Theory.Bool.t Theory.value option + val with_guard : t -> Theory.Bool.t Theory.value option -> t + val dst : t -> dst option + val alt : t -> dst option + + val resolved : tid -> dst + val indirect : 'a Theory.Bitv.t Theory.value -> dst + val resolve : dst -> (tid,'a Theory.Bitv.t Theory.value) Either.t + + (** [create ?cond kind] creates a jump of a given kind *) val create : ?tid:tid -> ?cond:exp -> jmp_kind -> t @@ -7504,6 +7619,7 @@ module Std : sig (** updated jump's kind *) val with_kind : t -> jmp_kind -> t + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -7519,6 +7635,15 @@ module Std : sig incoming edge. *) type t = phi term + val reify : ?tid:tid -> + 'a Theory.var -> + (tid * 'a Theory.value) list -> + t + + val var : t -> unit Theory.var + val options : t -> (tid * unit Theory.value) seq + + (** [create var label exp] creates a phi-node that associates a variable [var] with an expression [exp]. This expression should be selected if a control flow enters a block, that owns @@ -7571,6 +7696,7 @@ module Std : sig (** [remove def id] removes definition with a given [id] *) val remove : t -> tid -> t + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -7583,6 +7709,14 @@ module Std : sig type t = arg term + val reify : ?tid:tid -> ?intent:intent -> + 'a Theory.var -> + 'a Theory.value -> t + + val var : t -> unit Theory.var + val value : t -> unit Theory.value + + (** [create ?intent var exp] creates an argument. If intent is not specified it is left unknown. 
*) val create : ?tid:tid -> ?intent:intent -> var -> exp -> t @@ -7865,6 +7999,8 @@ module Std : sig (** symbolizer data type *) type t = symbolizer + val provide : Knowledge.agent -> t -> unit + (** [create fn] creates a symbolizer for a given function *) val create : (addr -> string option) -> t @@ -7887,15 +8023,17 @@ module Std : sig (** [empty] is a symbolizer that knows nothing. *) val empty : t - (** A factory of symbolizers. Use it register and create - symbolizers. *) module Factory : Source.Factory.S with type t = t + [@@deprecated "[since 2019-05] use [provide]"] + end (** Rooter finds starts of functions in the binary. *) module Rooter : sig type t = rooter + val provide : t -> unit + (** [create seq] creates a rooter from a given sequence of addresses *) val create : addr seq -> t @@ -7918,6 +8056,8 @@ module Std : sig (** A factory of rooters. Useful to register custom rooters *) module Factory : Source.Factory.S with type t = t + [@@deprecated "[since 2019-05] use [provide]"] + end (** Brancher is responsible for resolving destinations of branch @@ -7944,7 +8084,10 @@ module Std : sig the instruction [insn], that occupies memory region [mem]. *) val resolve : t -> mem -> full_insn -> dests + val provide : t -> unit + module Factory : Source.Factory.S with type t = t + [@@deprecated "[since 2019-05] use [provide]"] end @@ -8061,6 +8204,17 @@ module Std : sig type project + (**/**) + (* Explicitly undocumented right now, as we will later + republish it as a separate library. + *) + module Toplevel : sig + val set : Knowledge.state -> unit + val reset : unit -> unit + val current : unit -> Knowledge.state + end + (**/**) + (** Disassembled program. Project contains data that we were able to reconstruct during @@ -8078,6 +8232,7 @@ module Std : sig module Project : sig type t = project + type state [@@deriving bin_io] type input (** IO interface to a project data structure. 
*) @@ -8206,6 +8361,7 @@ module Std : sig or it can just provide an empty information. *) val create : + ?state:state -> ?disassembler:string -> ?brancher:brancher source -> ?symbolizer:symbolizer source -> @@ -8216,6 +8372,8 @@ module Std : sig (** [arch project] reveals the architecture of a loaded file *) val arch : t -> arch + val state : t -> state + (** [disasm project] returns results of disassembling *) val disasm : t -> disasm @@ -8259,18 +8417,18 @@ module Std : sig The following substitutions are supported: - [$section{_name,_addr,_min_addr,_max_addr}] - name of region of file - to which it belongs. For example, in ELF this name will - correspond to the section name + to which it belongs. For example, in ELF this name will + correspond to the section name - [$symbol{_name,_addr,_min_addr,_max_addr}] - name or address - of the symbol to which this memory belongs + of the symbol to which this memory belongs - [$asm] - assembler listing of the memory region - [$bil] - BIL code of the tagged memory region - [$block{_name,_addr,_min_addr,_max_addr}] - name or address of a basic - block to which this region belongs + block to which this region belongs - [$min_addr, $addr] - starting address of a memory region @@ -8437,15 +8595,15 @@ module Std : sig (** An error that can occur when loading or running pass. - [Not_loaded name] pass with a given [name] wasn't loaded for - some reason. This is a very unlikely error, indicating - either a logic error in the plugin system implementation or - something very weird, that we didn't expect. + some reason. This is a very unlikely error, indicating + either a logic error in the plugin system implementation or + something very weird, that we didn't expect. - [Not_loaded name] when we tried to load plugin with a given - [name] we failed to find it in our search paths. + [name] we failed to find it in our search paths. - [Runtime_error (name,exn)] when plugin with a given [name] - was run it raised an [exn]. 
+ was run it raised an [exn]. *) type error = diff --git a/lib/bap/bap_project.ml b/lib/bap/bap_project.ml index 9f6c9cd59..062b70a11 100644 --- a/lib/bap/bap_project.ml +++ b/lib/bap/bap_project.ml @@ -1,5 +1,6 @@ open Core_kernel open Regular.Std +open Bap_core_theory open Graphlib.Std open Bap_future.Std open Bap_types.Std @@ -9,29 +10,63 @@ open Bap_sema.Std open Or_error.Monad_infix open Format +module Driver = Bap_disasm_driver + module Event = Bap_event include Bap_self.Create() let find name = FileUtil.which name -module State = struct +module Kernel = struct + open KB.Syntax + module Driver = Bap_disasm_driver + module Calls = Bap_disasm_calls + module Disasm = Disasm_expert.Recursive + type t = { - tids : Tid.Tid_generator.t; - name : Tid.Name_resolver.t; - vars : Var.Id.t; + default : arch; + state : Driver.state; + calls : Calls.t; + } [@@deriving bin_io] + + let empty arch = { + default = arch; + state = Driver.init; + calls = Calls.empty; } + + let update self mem = + Disasm.scan self.default mem self.state >>= fun state -> + Calls.update self.calls state >>| fun calls -> + {self with state; calls} + + let symtab {state; calls} = Symtab.create state calls + let disasm {state} = + Disasm_expert.Recursive.global_cfg state + + module Toplevel = struct + let result = Toplevel.var "result" + let run k = + Toplevel.put result begin + k >>= fun k -> + disasm k >>= fun g -> + symtab k >>| fun s -> g,s,k + end; + Toplevel.get result + end end -type state = State.t +type state = Kernel.t [@@deriving bin_io] type t = { arch : arch; + core : Kernel.t; + disasm : disasm; memory : value memmap; storage : dict; program : program term; symbols : Symtab.t; - state : state; passes : string list; } [@@deriving fields] @@ -47,6 +82,7 @@ module Info = struct let spec,got_spec = Stream.create () end + module Input = struct type result = { arch : arch; @@ -80,12 +116,28 @@ module Input = struct finish; } + let symtab_agent = + let reliability = KB.Agent.authorative in + 
KB.Agent.register + ~reliability + ~desc:"extracts symbols from the image symtab entries" + ~package:"bap.std" + "symtab" + + let provide_image image = + let image_symbols = Symbolizer.of_image image in + let image_roots = Rooter.of_image image in + info "providing rooter and symbolizer from image"; + Symbolizer.provide symtab_agent image_symbols; + Rooter.provide image_roots + let of_image ?loader filename = Image.create ?backend:loader filename >>| fun (img,warns) -> List.iter warns ~f:(fun e -> warning "%a" Error.pp e); let spec = Image.spec img in Signal.send Info.got_img img; Signal.send Info.got_spec spec; + provide_image img; let finish proj = { proj with storage = Dict.set proj.storage Image.specification spec; @@ -173,56 +225,6 @@ let roots rooter = match rooter with | None -> [] | Some r -> Rooter.roots r |> Seq.to_list - -let fresh_state () = State.{ - tids = Tid.Tid_generator.fresh (); - name = Tid.Name_resolver.fresh (); - vars = Var.Id.fresh (); - } - -module MVar = struct - type 'a t = { - mutable value : 'a Or_error.t; - mutable updated : bool; - compare : 'a -> 'a -> int; - } - - let create ?(compare=fun _ _ -> 1) x = - {value=Ok x; updated=true; compare} - let peek x = ok_exn x.value - let read x = x.updated <- false; peek x - let is_updated x = x.updated - let write x v = - if x.compare (ok_exn x.value) v <> 0 then x.updated <- true; - x.value <- Ok v - - let fail x err = - x.value <- Error err; - x.updated <- true - - let ignore x = - Result.iter_error x.value ~f:Error.raise; - x.updated <- false - - let from_source s = - let x = create None in - Stream.observe s (function - | Ok v -> write x (Some v) - | Error e -> fail x e); - x - - let from_optional_source ?(default=fun () -> None) = function - | Some s -> from_source s - | None -> match default () with - | None -> create None - | Some s -> from_source s -end - -let phase_triggered phase mvar = - let trigger = MVar.is_updated mvar in - if trigger then Signal.send phase (MVar.read mvar); - 
trigger - module Cfg = Graphs.Cfg let empty_disasm = Disasm.create Cfg.empty @@ -242,101 +244,51 @@ let union_memory m1 m2 = Memmap.to_sequence m2 |> Seq.fold ~init:m1 ~f:(fun m1 (mem,v) -> Memmap.add m1 mem v) + +let build ?state ~code ~data arch = + let init = match state with + | Some state -> state + | None -> Kernel.empty arch in + let kernel = + Memmap.to_sequence code |> KB.Seq.fold ~init ~f:(fun k (mem,_) -> + Kernel.update k mem) in + let cfg,symbols,core = Kernel.Toplevel.run kernel in + { + core; + disasm = Disasm.create cfg; + program = Program.lift symbols; + symbols; + arch; memory=union_memory code data; + storage = Dict.empty; + passes=[] + } + +let state {core} = core + let create_exn - ?disassembler:backend - ?brancher - ?symbolizer - ?rooter - ?reconstructor + ?state + ?disassembler:_ + ?brancher:_ + ?symbolizer:_ + ?rooter:_ + ?reconstructor:_ (read : input) = - let state = fresh_state () in - let mrooter = - MVar.from_optional_source ~default:Merge.rooter rooter in - let msymbolizer = - MVar.from_optional_source ~default:Merge.symbolizer symbolizer in - let mbrancher = MVar.from_optional_source brancher in - let mreconstructor = MVar.from_optional_source reconstructor in - let cfg = MVar.create ~compare:Cfg.compare Cfg.empty in - let symtab = MVar.create ~compare:Symtab.compare Symtab.empty in - let program = MVar.create ~compare:Program.compare (Program.create ()) in - let task = "loading" in - report_progress ~task ~stage:0 ~total:5 ~note:"reading" (); let {Input.arch; data; code; file; finish} = read () in Signal.send Info.got_file file; Signal.send Info.got_arch arch; Signal.send Info.got_data data; Signal.send Info.got_code code; - let rec loop () = - let updated = MVar.is_updated mbrancher || MVar.is_updated mrooter in - let brancher = MVar.read mbrancher - and rooter = MVar.read mrooter in - let disassemble () = - report_progress ~task ~stage:1 ~note:"disassembling" (); - let run mem = - let dis = - Disasm.With_exn.of_mem ?backend 
?brancher ?rooter arch mem in - Disasm.errors dis |> - List.iter ~f:(fun e -> warning "%a" pp_disasm_error e); - Disasm.cfg dis in - Memmap.to_sequence code |> - Seq.fold ~init:Cfg.empty ~f:(fun cfg (mem,_) -> - Graphlib.union (module Cfg) cfg (run mem)) |> - MVar.write cfg in - if updated then disassemble (); - let is_cfg_updated = phase_triggered Info.got_cfg cfg in - let g = MVar.read cfg in - let reconstruct () = - if is_cfg_updated || MVar.is_updated msymbolizer then - let symbolizer = match MVar.read msymbolizer with - | None -> Symbolizer.empty - | Some s -> s in - let name = Symbolizer.resolve symbolizer in - let syms = - let () = report_progress ~task ~stage:2 ~note:"reconstructing" () in - Reconstructor.(run (default name (roots rooter)) g) in - MVar.write symtab syms in - if is_cfg_updated || MVar.is_updated mreconstructor - then match MVar.read mreconstructor with - | Some r -> - MVar.ignore msymbolizer; - report_progress ~task ~stage:2 ~note:"reconstructing" (); - MVar.write symtab (Reconstructor.run r g) - | None -> reconstruct () - else reconstruct (); - let is_symtab_updated = phase_triggered Info.got_symtab symtab in - if is_symtab_updated - then begin - report_progress ~task ~stage:3 ~note:"lifting" (); - MVar.write program (Program.lift (MVar.read symtab)) - end; - let _ = phase_triggered Info.got_program program in - if MVar.is_updated mrooter || - MVar.is_updated mbrancher || - MVar.is_updated msymbolizer || - MVar.is_updated mreconstructor then loop () - else - let disasm = Disasm.create g in - let program = MVar.read program in - report_progress ~task ~stage:4 ~note:"finishing" (); - finish { - disasm; - program; - symbols = MVar.read symtab; - arch; memory=union_memory code data; - storage = Dict.set Dict.empty filename file; - state; passes=[] - } in - loop () + finish @@ build ?state ~code ~data arch let create - ?disassembler ?brancher ?symbolizer ?rooter ?reconstructor input = + ?state ?disassembler ?brancher ?symbolizer ?rooter 
?reconstructor input = Or_error.try_with ~backtrace:true (fun () -> create_exn - ?disassembler ?brancher ?symbolizer ?rooter ?reconstructor input) + ?state ?disassembler ?brancher ?symbolizer ?rooter ?reconstructor input) -let restore_state {state={State.tids; name}} = - Tid.Tid_generator.store tids; - Tid.Name_resolver.store name +let restore_state _ = + failwith "Project.restore_state: this function should no be used. + Please use the Toplevel module to save/restore the state." let with_memory = Field.fset Fields.memory let with_symbols = Field.fset Fields.symbols @@ -553,18 +505,6 @@ module type S = sig module Factory : Bap_disasm_source.Factory with type t = t end -let register x = - let module S = (val x : S) in - let stream = - Stream.map Info.img ~f:(fun img -> - Or_error.try_with (fun () -> S.of_image img)) in - S.Factory.register "internal" stream - -let () = - register (module Brancher); - register (module Rooter); - register (module Symbolizer) - include Data.Make(struct type nonrec t = t let version = "1.0.0" diff --git a/lib/bap/bap_project.mli b/lib/bap/bap_project.mli index e38395475..3016e0e1b 100644 --- a/lib/bap/bap_project.mli +++ b/lib/bap/bap_project.mli @@ -1,3 +1,5 @@ +open Bap_knowledge + open Core_kernel open Regular.Std open Bap_future.Std @@ -11,9 +13,14 @@ type t type project = t type pass [@@deriving sexp_of] type input +type state [@@deriving bin_io] type second = float +val state : t -> state + + val create : + ?state:state -> ?disassembler:string -> ?brancher:brancher source -> ?symbolizer:symbolizer source -> @@ -71,7 +78,7 @@ module Pass : sig type error = | Unsat_dep of pass * string | Runtime_error of pass * exn - [@@deriving sexp_of] + [@@deriving sexp_of] exception Failed of error [@@deriving sexp] diff --git a/lib/bap_build/bap_build.ml b/lib/bap_build/bap_build.ml index 90c18557c..87d355da8 100644 --- a/lib/bap_build/bap_build.ml +++ b/lib/bap_build/bap_build.ml @@ -1,10 +1,12 @@ module Plugin_rules = struct module Fl = 
Findlib + open Printf open Ocamlbuild_plugin - open Core_kernel module Ocamlbuild = Ocamlbuild_pack + module List = ListLabels + module String = Ocamlbuild_plugin.String let (/) = Pathname.concat let () = @@ -28,7 +30,7 @@ module Plugin_rules = struct let needs_threads ~predicates pkgs = let deps = Fl.package_deep_ancestors predicates pkgs in - List.mem deps "threads" ~equal:String.equal + List.mem ~set:deps "threads" let infer_thread_predicates ~predicates pkg = if needs_threads ~predicates pkg @@ -59,6 +61,7 @@ module Plugin_rules = struct topological_closure ~predicates:(bap_predicates ~native:true) packages + let findlibs ?(native=true) ?(predicates=pkg_predicates ~native) @@ -69,17 +72,17 @@ module Plugin_rules = struct else predicates in let arch,preds = Fl.package_property_2 preds pkg "archive" in let base = Fl.package_directory pkg in - if dynamic && not (List.mem ~equal:Polymorphic_compare.equal preds (`Pred "plugin")) - then raise Caml.Not_found; - String.split ~on:' ' arch |> + if dynamic && not (List.mem ~set:preds (`Pred "plugin")) + then raise Not_found; + String.split_on_char ' ' arch |> List.map ~f:(Fl.resolve_path ~base) - with Caml.Not_found -> [] + with Not_found -> [] let externals pkgs = let interns = interns () in pkgs |> topological_closure ~predicates:(pkg_predicates ~native:true) |> - List.filter ~f:(fun dep -> not (List.mem ~equal:String.equal interns dep)) + List.filter ~f:(fun dep -> not (List.mem ~set:interns dep)) let packages () = externals !Options.ocaml_pkgs @@ -125,21 +128,24 @@ module Plugin_rules = struct | xs -> List.map xs ~f:(fun src -> cp src Pathname.current_dir_name) + + let concat_map xs ~f = List.(concat (map xs ~f)) + let generate_plugins_for_packages () = packages () |> - List.concat_map ~f:(fun name -> - List.concat_map [`native; `byte] ~f:(fun code -> + concat_map ~f:(fun name -> + concat_map [`native; `byte] ~f:(fun code -> generate_plugin_for_package code name)) let make_list_option option = function | [] -> N - | 
xs -> S [A option; A (String.concat ~sep:"," xs)] + | xs -> S [A option; A (String.concat "," xs)] let is_cmx file = Filename.check_suffix file ".cmx" let bundle env = let requires = - packages () |> List.concat_map ~f:(fun pkg -> + packages () |> concat_map ~f:(fun pkg -> findlibs ~dynamic:false pkg |> List.map ~f:(fun path -> let name = path |> @@ -174,7 +180,7 @@ module Plugin_rules = struct rule "bap: cmxs & packages -> bundle" ~deps:["%.cmxs"] ~stamp:"%.bundle" - (fun env _ -> Seq (generate_plugins_for_packages ())) + (fun _ _ -> Seq (generate_plugins_for_packages ())) let register_plugin_rule () = rule "bap: cmxs & cma & bundle -> plugin" @@ -182,8 +188,8 @@ module Plugin_rules = struct ~deps:["%.bundle"; "%.cmxs"; "%.cma"] (fun env _ -> Seq [bundle env; symlink env]) -let pass_pp_to_link_phase () = - pflag ["ocaml"; "link"] "pp" (fun s -> S [A "-pp"; A s]) + let pass_pp_to_link_phase () = + pflag ["ocaml"; "link"] "pp" (fun s -> S [A "-pp"; A s]) let install () = register_cmxs_of_cmxa_rule (); diff --git a/lib/bap_c/bap_c_abi.ml b/lib/bap_c/bap_c_abi.ml index 0bff62f90..280d7d74d 100644 --- a/lib/bap_c/bap_c_abi.ml +++ b/lib/bap_c/bap_c_abi.ml @@ -35,7 +35,7 @@ let arg_intent : ctype -> intent = function type error = [ | `Unknown_interface of string | `Parser_error of string * Error.t -] +] [@@deriving sexp_of] type param = Bap_c_data.t * exp @@ -51,7 +51,7 @@ type t = { } -exception Failed of error +exception Failed of error [@@deriving sexp_of] let fail x = raise (Failed x) let data (size : #Bap_c_size.base) (t : Bap_c_type.t) = diff --git a/lib/bap_core_theory/.merlin b/lib/bap_core_theory/.merlin new file mode 100644 index 000000000..093ce10a1 --- /dev/null +++ b/lib/bap_core_theory/.merlin @@ -0,0 +1,3 @@ +REC +B ../../_build/lib/bap_core_theory +B ../../_build/lib/knowledge diff --git a/lib/bap_core_theory/bap_core_theory.ml b/lib/bap_core_theory/bap_core_theory.ml new file mode 100644 index 000000000..e2e472e69 --- /dev/null +++ 
b/lib/bap_core_theory/bap_core_theory.ml @@ -0,0 +1,68 @@ +open Bap_knowledge + +module KB = Knowledge + +module Theory = struct + module Value = Bap_core_theory_value + module Bool = Value.Bool + module Bitv = Value.Bitv + module Mem = Value.Mem + module Float = Value.Float + module Rmode = Value.Rmode + module Effect = Bap_core_theory_effect + module Var = Bap_core_theory_var + module Program = Bap_core_theory_program + module Label = Program.Label + type program = Program.cls + + + type 'a value = 'a Value.t + type 'a effect = 'a Effect.t + type 'a pure = 'a value knowledge + type 'a eff = 'a effect knowledge + + + type bool = Bool.t pure + type 'a bitv = 'a Bitv.t pure + type ('a,'b) mem = ('a,'b) Mem.t pure + type 'f float = 'f Float.t pure + type rmode = Rmode.t pure + + type data = Effect.Sort.data + type ctrl = Effect.Sort.ctrl + + type ('r,'s) format = ('r,'s) Float.format + + type 'a var = 'a Var.t + type word = Bitvec.t + type label = program Knowledge.Object.t + module type Init = Bap_core_theory_definition.Init + module type Bool = Bap_core_theory_definition.Bool + module type Bitv = Bap_core_theory_definition.Bitv + module type Memory = Bap_core_theory_definition.Memory + module type Effect = Bap_core_theory_definition.Effect + module type Minimal = Bap_core_theory_definition.Minimal + module type Basic = Bap_core_theory_definition.Basic + module type Fbasic = Bap_core_theory_definition.Fbasic + module type Float = Bap_core_theory_definition.Float + module type Trans = Bap_core_theory_definition.Trans + module type Core = Bap_core_theory_definition.Core + + module Basic = struct + module Empty : Basic = Bap_core_theory_empty.Core + module Make = Bap_core_theory_basic.Make + end + + module Core = struct + module Empty : Core = Bap_core_theory_empty.Core + end + + module Manager = Bap_core_theory_manager.Theory + let register = Bap_core_theory_manager.register + + module IEEE754 = Bap_core_theory_IEEE754 + + module Grammar = 
Bap_core_theory_grammar_definition + module Parser = Bap_core_theory_parser + +end diff --git a/lib/bap_core_theory/bap_core_theory.mli b/lib/bap_core_theory/bap_core_theory.mli new file mode 100644 index 000000000..14fb8de24 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory.mli @@ -0,0 +1,723 @@ +open Core_kernel +open Caml.Format + +open Bap_knowledge + +module KB = Knowledge + +module Theory : sig + module Value : sig + type +'a sort + type cls + + type 'a t = (cls,'a sort) KB.cls KB.value + val cls : (cls,unit) KB.cls + + val empty : 'a sort -> 'a t + val sort : 'a t -> 'a sort + + module Sort : sig + type +'a t = 'a sort + type +'a sym + type +'a num + type name + + val sym : name -> 'a sym sort + val int : int -> 'a num sort + val app : 'a sort -> 'b sort -> ('a -> 'b) sort + val (@->) : 'a sort -> 'b sort -> ('a -> 'b) sort + + val value : 'a num sort -> int + val name : 'a sym sort -> name + + val hd : ('a -> 'b) sort -> 'a sort + val tl : ('a -> 'b) sort -> 'b sort + + val refine : name -> unit sort -> 'a t option + + val forget : 'a t -> unit t + + val same : 'a t -> 'b t -> bool + + val pp : formatter -> 'a t -> unit + + module Top : sig + type t = unit sort [@@deriving bin_io, compare, sexp] + include Base.Comparable.S with type t := t + end + + module Name : sig + type t + val declare : ?package:string -> string -> name + include Base.Comparable.S with type t := t + end + end + end + + module Effect : sig + type +'a sort + type cls + + type 'a t = (cls,'a sort) KB.cls KB.value + val cls : (cls,unit) KB.cls + + val empty : 'a sort -> 'a t + val sort : 'a t -> 'a sort + + module Sort : sig + type +'a t = 'a sort + type data = private Data + type ctrl = private Ctrl + + val data : string -> data t + val ctrl : string -> ctrl t + + val top : unit t + val bot : 'a t + + val both : 'a t -> 'a t -> 'a t + val (&&) : 'a t -> 'a t -> 'a t + val union : 'a t list -> 'a t + val join : 'a t list -> 'b t list -> unit t + + val order : 'a t -> 'b t -> 
KB.Order.partial + + + val rreg : data t + val wreg : data t + val rmem : data t + val wmem : data t + val barr : data t + val fall : ctrl t + val jump : ctrl t + val cjmp : ctrl t + end + end + + type 'a value = 'a Value.t + type 'a effect = 'a Effect.t + + module Bool : sig + type t + val t : t Value.sort + val refine : unit Value.sort -> t Value.sort option + end + + + module Bitv : sig + type 'a t + val define : int -> 'a t Value.sort + val refine : unit Value.sort -> 'a t Value.sort option + val size : 'a t Value.sort -> int + end + + module Mem : sig + type ('a,'b) t + val define : 'a Bitv.t Value.sort -> 'b Bitv.t Value.sort -> ('a,'b) t Value.sort + val refine : unit Value.sort -> ('a,'b) t Value.sort option + val keys : ('a,'b) t Value.sort -> 'a Bitv.t Value.sort + val vals : ('a,'b) t Value.sort -> 'b Bitv.t Value.sort + end + + module Float : sig + module Format : sig + type ('r,'s) t + val define : 'r Value.sort -> 's Bitv.t Value.sort -> ('r,'s) t Value.sort + val bits : ('r,'s) t Value.sort -> 's Bitv.t Value.sort + val exp : ('r,'s) t Value.sort -> 'r Value.sort + end + + type ('r,'s) format = ('r,'s) Format.t + type 'f t + + val define : ('r,'s) format Value.sort -> ('r,'s) format t Value.sort + val refine : unit Value.sort -> ('r,'s) format t Value.sort option + val format : ('r,'s) format t Value.sort -> ('r,'s) format Value.sort + val size : ('r,'s) format t Value.sort -> 's Bitv.t Value.sort + end + + module Rmode : sig + type t + val t : t Value.sort + val refine : unit Value.sort -> t Value.sort option + end + + type 'a pure = 'a value knowledge + type 'a eff = 'a effect knowledge + + type ('r,'s) format = ('r,'s) Float.format + + module Var : sig + type 'a t + type ident [@@deriving bin_io, compare, sexp] + type ord + + val define : 'a Value.sort -> string -> 'a t + val create : 'a Value.sort -> ident -> 'a t + val forget : 'a t -> unit t + val resort : 'a t -> 'b Value.sort -> 'b t + + val versioned: 'a t -> int -> 'a t + val version : 'a t 
-> int + + val ident : 'a t -> ident + val name : 'a t -> string + val sort : 'a t -> 'a Value.sort + val is_virtual : 'a t -> bool + val is_mutable : 'a t -> bool + val fresh : 'a Value.sort -> 'a t knowledge + val scoped : 'a Value.sort -> ('a t -> 'b pure) -> 'b pure + + module Ident : sig + type t = ident [@@deriving bin_io, compare, sexp] + include Stringable.S with type t := t + include Base.Comparable.S with type t := t + and type comparator_witness = ord + end + + module Top : sig + type nonrec t = unit t [@@deriving bin_io, compare, sexp] + include Base.Comparable.S with type t := t + end + end + + type data = Effect.Sort.data + type ctrl = Effect.Sort.ctrl + + type word = Bitvec.t + type 'a var = 'a Var.t + + type program + type label = program KB.Object.t + + module Program : sig + type t = (program,unit) KB.cls KB.value + val cls : (program,unit) KB.cls + module Semantics : sig + type cls = Effect.cls + type t = unit Effect.t + val cls : (cls, unit Effect.sort) Knowledge.cls + val slot : (program, t) Knowledge.slot + include Knowledge.Value.S with type t := t + end + include Knowledge.Value.S with type t := t + end + + module Label : sig + type t = label + + val addr : (program, Bitvec.t option) KB.slot + val name : (program, string option) KB.slot + val ivec : (program, int option) KB.slot + val aliases : (program, Set.M(String).t) KB.slot + + val is_valid : (program, bool option) KB.slot + val is_subroutine : (program, bool option) KB.slot + + val for_addr : Bitvec.t -> t knowledge + val for_name : string -> t knowledge + val for_ivec : int -> t knowledge + + include Knowledge.Object.S with type t := t + end + + + type bool = Bool.t pure + type 'a bitv = 'a Bitv.t pure + type ('a,'b) mem = ('a,'b) Mem.t pure + type 'f float = 'f Float.t pure + type rmode = Rmode.t pure + + module type Init = sig + val var : 'a var -> 'a pure + val unk : 'a Value.sort -> 'a pure + val let_ : 'a var -> 'a pure -> 'b pure -> 'b pure + end + + module type Bool = sig + val 
b0 : bool + val b1 : bool + val inv : bool -> bool + val and_ : bool -> bool -> bool + val or_ : bool -> bool -> bool + end + + module type Bitv = sig + val int : 'a Bitv.t Value.sort -> word -> 'a bitv + val msb : 'a bitv -> bool + val lsb : 'a bitv -> bool + val neg : 'a bitv -> 'a bitv + val not : 'a bitv -> 'a bitv + val add : 'a bitv -> 'a bitv -> 'a bitv + val sub : 'a bitv -> 'a bitv -> 'a bitv + val mul : 'a bitv -> 'a bitv -> 'a bitv + val div : 'a bitv -> 'a bitv -> 'a bitv + val sdiv : 'a bitv -> 'a bitv -> 'a bitv + val modulo : 'a bitv -> 'a bitv -> 'a bitv + val smodulo : 'a bitv -> 'a bitv -> 'a bitv + val logand : 'a bitv -> 'a bitv -> 'a bitv + val logor : 'a bitv -> 'a bitv -> 'a bitv + val logxor : 'a bitv -> 'a bitv -> 'a bitv + val shiftr : bool -> 'a bitv -> 'b bitv -> 'a bitv + val shiftl : bool -> 'a bitv -> 'b bitv -> 'a bitv + val ite : bool -> 'a pure -> 'a pure -> 'a pure + val sle : 'a bitv -> 'a bitv -> bool + val ule : 'a bitv -> 'a bitv -> bool + val cast : 'a Bitv.t Value.sort -> bool -> 'b bitv -> 'a bitv + val concat : 'a Bitv.t Value.sort -> 'b bitv list -> 'a bitv + val append : 'a Bitv.t Value.sort -> 'b bitv -> 'c bitv -> 'a bitv + end + + module type Memory = sig + val load : ('a,'b) mem -> 'a bitv -> 'b bitv + val store : ('a,'b) mem -> 'a bitv -> 'b bitv -> ('a,'b) mem + end + + module type Effect = sig + val perform : 'a Effect.sort -> 'a eff + val set : 'a var -> 'a pure -> data eff + val jmp : _ bitv -> ctrl eff + val goto : label -> ctrl eff + val seq : 'a eff -> 'a eff -> 'a eff + val blk : label -> data eff -> ctrl eff -> unit eff + val repeat : bool -> data eff -> data eff + val branch : bool -> 'a eff -> 'a eff -> 'a eff + end + + + module type Minimal = sig + include Init + include Bool + include Bitv + include Memory + include Effect + end + + module type Basic = sig + include Minimal + val zero : 'a Bitv.t Value.sort -> 'a bitv + val is_zero : 'a bitv -> bool + val non_zero : 'a bitv -> bool + val succ : 'a bitv 
-> 'a bitv + val pred : 'a bitv -> 'a bitv + val nsucc : 'a bitv -> int -> 'a bitv + val npred : 'a bitv -> int -> 'a bitv + val high : 'a Bitv.t Value.sort -> 'b bitv -> 'a bitv + val low : 'a Bitv.t Value.sort -> 'b bitv -> 'a bitv + val signed : 'a Bitv.t Value.sort -> 'b bitv -> 'a bitv + val unsigned : 'a Bitv.t Value.sort -> 'b bitv -> 'a bitv + val extract : 'a Bitv.t Value.sort -> 'b bitv -> 'b bitv -> _ bitv -> 'a bitv + val loadw : 'c Bitv.t Value.sort -> bool -> ('a, _) mem -> 'a bitv -> 'c bitv + val storew : bool -> ('a, 'b) mem -> 'a bitv -> 'c bitv -> ('a, 'b) mem + val arshift : 'a bitv -> 'b bitv -> 'a bitv + val rshift : 'a bitv -> 'b bitv -> 'a bitv + val lshift : 'a bitv -> 'b bitv -> 'a bitv + val eq : 'a bitv -> 'a bitv -> bool + val neq : 'a bitv -> 'a bitv -> bool + val slt : 'a bitv -> 'a bitv -> bool + val ult : 'a bitv -> 'a bitv -> bool + val sgt : 'a bitv -> 'a bitv -> bool + val ugt : 'a bitv -> 'a bitv -> bool + val sge : 'a bitv -> 'a bitv -> bool + val uge : 'a bitv -> 'a bitv -> bool + end + + module type Fbasic = sig + val float : ('r,'s) format Float.t Value.sort -> 's bitv -> ('r,'s) format float + val fbits : ('r,'s) format float -> 's bitv + + + val is_finite : 'f float -> bool + val is_nan : 'f float -> bool + val is_inf : 'f float -> bool + val is_fzero : 'f float -> bool + val is_fpos : 'f float -> bool + val is_fneg : 'f float -> bool + + val rne : rmode + val rna : rmode + val rtp : rmode + val rtn : rmode + val rtz : rmode + val requal : rmode -> rmode -> bool + + val cast_float : 'f Float.t Value.sort -> rmode -> 'a bitv -> 'f float + val cast_sfloat : 'f Float.t Value.sort -> rmode -> 'a bitv -> 'f float + val cast_int : 'a Bitv.t Value.sort -> rmode -> 'f float -> 'a bitv + val cast_sint : 'a Bitv.t Value.sort -> rmode -> 'f float -> 'a bitv + + val fneg : 'f float -> 'f float + val fabs : 'f float -> 'f float + + val fadd : rmode -> 'f float -> 'f float -> 'f float + val fsub : rmode -> 'f float -> 'f float -> 'f 
float + val fmul : rmode -> 'f float -> 'f float -> 'f float + val fdiv : rmode -> 'f float -> 'f float -> 'f float + val fsqrt : rmode -> 'f float -> 'f float + val fmodulo : rmode -> 'f float -> 'f float -> 'f float + val fmad : rmode -> 'f float -> 'f float -> 'f float -> 'f float + + val fround : rmode -> 'f float -> 'f float + val fconvert : 'f Float.t Value.sort -> rmode -> _ float -> 'f float + + val fsucc : 'f float -> 'f float + val fpred : 'f float -> 'f float + val forder : 'f float -> 'f float -> bool + end + + module type Float = sig + include Fbasic + val pow : rmode -> 'f float -> 'f float -> 'f float + val powr : rmode -> 'f float -> 'f float -> 'f float + val compound : rmode -> 'f float -> 'a bitv -> 'f float + val rootn : rmode -> 'f float -> 'a bitv -> 'f float + val pownn : rmode -> 'f float -> 'a bitv -> 'f float + val rsqrt : rmode -> 'f float -> 'f float + val hypot : rmode -> 'f float -> 'f float -> 'f float + end + + module type Trans = sig + val exp : rmode -> 'f float -> 'f float + val expm1 : rmode -> 'f float -> 'f float + val exp2 : rmode -> 'f float -> 'f float + val exp2m1 : rmode -> 'f float -> 'f float + val exp10 : rmode -> 'f float -> 'f float + val exp10m1 : rmode -> 'f float -> 'f float + val log : rmode -> 'f float -> 'f float + val log2 : rmode -> 'f float -> 'f float + val log10 : rmode -> 'f float -> 'f float + val logp1 : rmode -> 'f float -> 'f float + val log2p1 : rmode -> 'f float -> 'f float + val log10p1 : rmode -> 'f float -> 'f float + val sin : rmode -> 'f float -> 'f float + val cos : rmode -> 'f float -> 'f float + val tan : rmode -> 'f float -> 'f float + val sinpi : rmode -> 'f float -> 'f float + val cospi : rmode -> 'f float -> 'f float + val atanpi : rmode -> 'f float -> 'f float + val atan2pi : rmode -> 'f float -> 'f float -> 'f float + val asin : rmode -> 'f float -> 'f float + val acos : rmode -> 'f float -> 'f float + val atan : rmode -> 'f float -> 'f float + val atan2 : rmode -> 'f float -> 'f float 
-> 'f float + val sinh : rmode -> 'f float -> 'f float + val cosh : rmode -> 'f float -> 'f float + val tanh : rmode -> 'f float -> 'f float + val asinh : rmode -> 'f float -> 'f float + val acosh : rmode -> 'f float -> 'f float + val atanh : rmode -> 'f float -> 'f float + end + + + module type Core = sig + include Basic + include Float + include Trans + end + + module Basic : sig + module Make(S : Minimal) : Basic + module Empty : Basic + end + + module Core : sig + module Empty : Core + end + + module Manager : Core + + val register : ?desc:string -> name:string -> (module Core) -> unit + + + module IEEE754 : sig + type ('a,'e,'t) t + type ('a,'e,'t) ieee754 = ('a,'e,'t) t + (* see IEEE754 3.6 *) + type parameters = private { + base : int; + bias : int; + k : int; + p : int; + w : int; + t : int; + } + + + val binary16 : parameters + val binary32 : parameters + val binary64 : parameters + val binary80 : parameters + val binary128 : parameters + val decimal32 : parameters + val decimal64 : parameters + val decimal128 : parameters + + val binary : int -> parameters option + val decimal : int -> parameters option + + module Sort : sig + val define : parameters -> (('b,'e,'t) ieee754,'s) format Float.t Value.sort + val exps : (('b,'e,'t) ieee754,'s) format Float.t Value.sort -> 'e Bitv.t Value.sort + val sigs : (('b,'e,'t) ieee754,'s) format Float.t Value.sort -> 't Bitv.t Value.sort + val bits : (('b,'e,'t) ieee754,'s) format Float.t Value.sort -> 's Bitv.t Value.sort + val spec : (('b,'e,'t) ieee754,'s) format Float.t Value.sort -> parameters + end + end + + + module Grammar : sig + type ieee754 = IEEE754.parameters + module type Bitv = sig + type t + type exp + type rmode + + val error : t + + val unsigned : int -> exp -> t + val signed : int -> exp -> t + val high : int -> exp -> t + val low : int -> exp -> t + val cast : int -> exp -> exp -> t + val extract : int -> exp -> exp -> exp -> t + + val add : exp -> exp -> t + val sub : exp -> exp -> t + val mul : exp 
-> exp -> t + val div : exp -> exp -> t + val sdiv : exp -> exp -> t + val modulo : exp -> exp -> t + val smodulo : exp -> exp -> t + val lshift : exp -> exp -> t + val rshift : exp -> exp -> t + val arshift : exp -> exp -> t + val logand : exp -> exp -> t + val logor: exp -> exp -> t + val logxor : exp -> exp -> t + + val neg : exp -> t + val not : exp -> t + + val load_word : int -> exp -> exp -> exp -> t + val load : exp -> exp -> t + + + val var : string -> int -> t + val int : word -> int -> t + val unknown : int -> t + val ite : exp -> exp -> exp -> t + + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t + + val append : exp -> exp -> t + val concat : exp list -> t + + val cast_int : int -> rmode -> exp -> t + val cast_sint : int -> rmode -> exp -> t + val fbits : exp -> t + end + + module type Bool = sig + type t + type exp + + val error : t + + val eq : exp -> exp -> t + val neq : exp -> exp -> t + val lt : exp -> exp -> t + val le : exp -> exp -> t + val slt : exp -> exp -> t + val sle : exp -> exp -> t + val var : string -> t + val int : word -> t + val unknown : unit -> t + val ite : exp -> exp -> exp -> t + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t + + val high : exp -> t + val low : exp -> t + val extract : int -> exp -> t + + val not : exp -> t + val logand : exp -> exp -> t + val logor: exp -> exp -> t + val logxor : exp -> exp -> t + + val is_inf : exp -> t + val is_nan : exp -> t + val is_fzero : exp -> t + val is_fpos : exp -> t + val is_fneg : exp -> t + + val fle : exp -> exp -> t + val flt : exp -> exp -> t + val feq : exp -> exp -> t + end + + + module type Mem = sig + type t + type exp + + val error : t + + (** [store mem key data] *) + val store : exp -> exp -> exp -> t + + + (** [store_word dir mem 
key data ] *) + val store_word : exp -> exp -> exp -> exp -> t + val var : string -> int -> int -> t + val unknown : int -> int -> t + val ite : exp -> exp -> exp -> t + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t + end + + module type Stmt = sig + type t + type exp + type rmode + type stmt + + val error : t + + val set_mem : string -> int -> int -> exp -> t + val set_reg : string -> int -> exp -> t + val set_bit : string -> exp -> t + val set_ieee754 : string -> ieee754 -> exp -> t + val set_rmode : string -> rmode -> t + + val tmp_mem : string -> exp -> t + val tmp_reg : string -> exp -> t + val tmp_bit : string -> exp -> t + val tmp_float : string -> exp -> t + val tmp_rmode : string -> rmode -> t + + val let_mem : string -> exp -> stmt -> t + val let_reg : string -> exp -> stmt -> t + val let_bit : string -> exp -> stmt -> t + val let_float : string -> exp -> stmt -> t + val let_rmode : string -> rmode -> stmt -> t + + val jmp : exp -> t + val goto : word -> t + val call : string -> t + val special : string -> t + val cpuexn : int -> t + + val while_ : exp -> stmt list -> t + val if_ : exp -> stmt list -> stmt list -> t + + val seq : stmt list -> t + end + + module type Float = sig + type t + type exp + type rmode + + val error : t + + val ieee754 : ieee754 -> exp -> t + val ieee754_var : ieee754 -> string -> t + val ieee754_unk : ieee754 -> t + val ieee754_cast : ieee754 -> rmode -> exp -> t + val ieee754_cast_signed : ieee754 -> rmode -> exp -> t + val ieee754_convert : ieee754 -> rmode -> exp -> t + + val ite : exp -> exp -> exp -> t + + val fadd : rmode -> exp -> exp -> t + val fsub : rmode -> exp -> exp -> t + val fmul : rmode -> exp -> exp -> t + val fdiv : rmode -> exp -> exp -> t + val frem : rmode -> exp -> exp -> t + val fmin : exp -> exp -> t + val fmax : exp -> exp -> t + + val fabs : exp -> t + val fneg : exp -> t + val fsqrt 
: rmode -> exp -> t + val fround : rmode -> exp -> t + + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t + end + + module type Rmode = sig + type t + type exp + + val error : t + + val rne : t + val rtz : t + val rtp : t + val rtn : t + val rna : t + end + end + + module Parser : sig + type ('a,'e,'r) bitv_parser = + (module Grammar.Bitv with type t = 'a + and type exp = 'e + and type rmode = 'r) -> + 'e -> 'a + + type ('a,'e,'r) bool_parser = + (module Grammar.Bool with type t = 'a + and type exp = 'e) -> + 'e -> 'a + + type ('a,'e) mem_parser = + (module Grammar.Mem with type t = 'a + and type exp = 'e) -> + 'e -> 'a + + type ('a,'e,'r,'s) stmt_parser = + (module Grammar.Stmt with type t = 'a + and type exp = 'e + and type stmt = 's + and type rmode = 'r) -> + 's -> 'a + + type ('a,'e,'r) float_parser = + (module Grammar.Float with type t = 'a + and type exp = 'e + and type rmode = 'r) -> + 'e -> 'a + + type ('a,'e) rmode_parser = + (module Grammar.Rmode with type t = 'a + and type exp = 'e) -> + 'e -> 'a + + type ('e,'r,'s) t = { + bitv : 'a. ('a,'e,'r) bitv_parser; + bool : 'a. ('a,'e,'r) bool_parser; + mem : 'a. ('a,'e) mem_parser; + stmt : 'a. ('a,'e,'r,'s) stmt_parser; + float : 'a . ('a,'e,'r) float_parser; + rmode : 'a . 
('a,'r) rmode_parser; + } + + type ('e,'r,'s) parser = ('e,'r,'s) t + + module Make(S : Core) : sig + val run : ('e,'r,'s) parser -> 's list -> unit eff + end + end +end diff --git a/lib/bap_core_theory/bap_core_theory_IEEE754.ml b/lib/bap_core_theory/bap_core_theory_IEEE754.ml new file mode 100644 index 000000000..3e3c33d15 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_IEEE754.ml @@ -0,0 +1,135 @@ +open Core_kernel +open Bap_core_theory_value + +type 'a num = 'a Sort.num +type 'a sym = 'a Sort.sym +type witness + +type ('b,'e,'t) ieee754 = ('b num -> 'e num -> 't num -> witness sym) +type ('b,'e,'t) t = ('b,'e,'t) ieee754 + +(* see IEEE754 3.6 *) +type parameters = { + base : int; + bias : int; + k : int; + p : int; + w : int; + t : int; +} + +let (^) b e = Int.of_float (float b ** float e) + +let log2 n = log n /. log 2. +let round x = round ~dir:`Nearest x + +let prec k = + let k = float k in Int.of_float @@ + k -. round(4. *. log2 k) +. 13. + +let ebits k = + let k = float k in Int.of_float @@ + round(4. *. log2 k) -. 13. + +let bias k p = (2^(k-p-1))-1 + +let binary k = + let p = prec k and w = ebits k in { + base = 2; + k; w; p; + t = k - w - 1; + bias = bias k p; + } + +let decimal k = + let p = 9 * k/32 - 2 in + let exp = k / 16 + 3 in + let emax = 3 * Int.of_float (2. 
** float exp) in { + base = 10; + bias = emax + p - 2; + w = k / 16 + 9; + t = 15 * k / 16 - 10; + k; p; + } + +let binary16 = { + base = 2; + bias = 15; + k = 16; + p = 11; + w = 5; + t = 10; +} + +let binary32 = { + base = 2; + bias = 127; + k = 32; + p = 24; + w = 8; + t = 23; +} + +let binary80 = { + base = 2; + bias = 16383; + k = 80; + p = 64; + w = 15; + t = 64; +} + +let binary64 = binary 64 +let binary128 = binary 128 +let decimal32 = decimal 32 +let decimal64 = decimal 64 +let decimal128 = decimal 128 + +let binary = function + | 16 -> binary16 + | 32 -> binary32 + | 80 -> binary80 + | k -> binary k + +module Sort = struct + let ieee754 = Sort.Name.declare ~package:"core-theory" "IEEE754" + let format {base; w; t=x; k} = Float.Format.define + Sort.(int base @-> int w @-> int x @-> sym ieee754) + (Bitv.define k) + + let define p : (('b,'e,'t) ieee754,'s) Float.format Float.t sort = Float.define (format p) + + let spec e = + let fmt = Float.format e in + let k = Bitv.size (Float.Format.bits fmt) in + let base = Sort.(value @@ hd (Float.Format.exp fmt)) in + match base with + | 2 -> binary k + | 10 -> decimal k + | _ -> assert false + + let exps e = + let {w} = spec e in + Bitv.define w + + let sigs e = + let {p} = spec e in + Bitv.define p + + let bits e = + let {k} = spec e in + Bitv.define k +end + + +let binary = function (* public lookup: [None] unless the width is 16, 32, 80, or a positive multiple of 32 *) + | 0 -> None + | 16 -> Some binary16 + | 32 -> Some binary32 + | 80 -> Some binary80 + | n when n > 0 && n mod 32 = 0 -> Some (binary n) (* n > 0, as in [decimal] below: the generic formula takes log2 of the width *) + | _ -> None + +let decimal n = + if n > 0 && n mod 32 = 0 then Some (decimal n) + else None diff --git a/lib/bap_core_theory/bap_core_theory_IEEE754.mli b/lib/bap_core_theory/bap_core_theory_IEEE754.mli new file mode 100644 index 000000000..8f6af23bc --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_IEEE754.mli @@ -0,0 +1,34 @@ +open Bap_core_theory_value +type ('a,'e,'t) t +type ('a,'e,'t) ieee754 = ('a,'e,'t) t + +type parameters = private { + base : int; + bias : int; + k : int; + p : int; + w :
int; + t : int; +} + + +val binary16 : parameters +val binary32 : parameters +val binary64 : parameters +val binary80 : parameters +val binary128 : parameters +val decimal32 : parameters +val decimal64 : parameters +val decimal128 : parameters + +val binary : int -> parameters option +val decimal : int -> parameters option + +module Sort : sig + open Float + val define : parameters -> (('b,'e,'t) ieee754,'s) format Float.t sort + val exps : (('b,'e,'t) ieee754,'s) format Float.t sort -> 'e Bitv.t sort + val sigs : (('b,'e,'t) ieee754,'s) format Float.t sort -> 't Bitv.t sort + val bits : (('b,'e,'t) ieee754,'s) format Float.t sort -> 's Bitv.t sort + val spec : (('b,'e,'t) ieee754,'s) format Float.t sort -> parameters +end diff --git a/lib/bap_core_theory/bap_core_theory_basic.ml b/lib/bap_core_theory/bap_core_theory_basic.ml new file mode 100644 index 000000000..f3c235b46 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_basic.ml @@ -0,0 +1,128 @@ +open Core_kernel +open Bap_knowledge + +open Bap_core_theory_definition +open Bap_core_theory_value + +open Knowledge.Syntax + +module Value = Knowledge.Value + +let size = Bitv.size +let (>>->) + = fun x f -> x >>= fun x -> f (KB.Class.sort (Value.cls x)) x + + +module Make(L : Minimal) = struct + open L + module BitLogic = struct + let (&&) = and_ and (||) = or_ and not = inv + end + + + include struct + open BitLogic + let eq x y = + x >>= fun x -> + y >>= fun y -> + ule !!x !!y && ule !!y !!x + let neq x y = not (eq x y) + let slt x y = + x >>= fun x -> + y >>= fun y -> + sle !!x !!y && not (sle !!y !!x) + let ult x y = + x >>= fun x -> + y >>= fun y -> + ule !!x !!y && not (ule !!y !!x) + let sgt x y = slt y x + let ugt x y = ult y x + let sge x y = sle y x + let uge x y = ule y x + end + + let small s x = int s Bitvec.(int x mod modulus (size s)) + + let zero s = int s Bitvec.zero + + let is_zero x = + x >>-> fun s x -> + eq !!x (zero s) + + let non_zero x = inv (is_zero x) + + let nsucc x n = + x >>-> fun s 
x -> + add !!x (small s n) + + let npred x n = + x >>-> fun s x -> + sub !!x (small s n) + + let succ x = nsucc x 1 + let pred x = npred x 1 + + let high s x = + x >>-> fun t x -> + let n = min (size t) (max 0 (size t - size s)) in + cast s b0 (shiftr b0 !!x (small t n)) + + let low s x = cast s b0 x + let signed s x = cast s (msb x) x + let unsigned s x = cast s b0 x + + + let bind exp body = + exp >>-> fun s exp -> + Var.scoped s @@ fun v -> + let_ v !!exp (body v) + + let loadw out dir mem key = + dir >>= fun dir -> + mem >>-> fun ms mem -> + key >>= fun key -> + let vs = Mem.vals ms in + let chunk_size = size vs in + let needed = size out in + let rec loop chunks loaded = + if loaded < needed then + let key = nsucc !!key (loaded / chunk_size) in + bind (load !!mem key) @@ fun chunk -> + loop (var chunk :: chunks) (loaded + chunk_size) + else + ite !!dir + (concat out (List.rev chunks)) + (concat out chunks) in + loop [] 0 + + let storew dir mem key data = (* writes [data] into [mem] one value-sized chunk per key, starting at [key] *) + data >>-> fun data_t data -> + mem >>-> fun mem_t mem -> + let chunks = Mem.vals mem_t in + let needed = size data_t and chunk_len = size chunks in + let nth stored = (* the chunk to write once [stored] bits have been written *) + let shift_amount = ite dir + (small data_t (needed - stored - chunk_len)) (* [dir] = b1: highest chunk goes to the first key, the order [loadw] concatenates in for b1; NOTE(review): assumes [concat] is most-significant-first, confirm *) + (small data_t stored) in (* otherwise the lowest chunk goes to the first key *) + cast chunks b0 (shiftr b0 !!data shift_amount) in + let rec loop key stored mem = + if stored < needed then + loop + (succ key) + (stored + chunk_len) + (store mem key (nth stored)) + else mem in + loop key 0 !!mem + + let arshift x y = shiftr (msb x) x y + let rshift x y = shiftr b0 x y + let lshift x y = shiftl b0 x y + + let extract s hi lo x = + let n = succ (sub hi lo) in + x >>= fun x -> + let t = KB.Class.sort (Value.cls x) in + let mask = lshift (not (zero t)) n in + cast s b0 (logand (not mask) (shiftr b0 !!x lo)) + include L +end diff --git a/lib/bap_core_theory/bap_core_theory_basic.mli b/lib/bap_core_theory/bap_core_theory_basic.mli new file mode 100644 index 000000000..5547929e7 --- /dev/null +++
b/lib/bap_core_theory/bap_core_theory_basic.mli @@ -0,0 +1,2 @@ +open Bap_core_theory_definition +module Make(S : Minimal) : Basic diff --git a/lib/bap_core_theory/bap_core_theory_definition.ml b/lib/bap_core_theory/bap_core_theory_definition.ml new file mode 100644 index 000000000..88a975797 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_definition.ml @@ -0,0 +1,219 @@ +open Bap_knowledge + +open Bap_core_theory_value + +module Var = Bap_core_theory_var +module Value = Bap_core_theory_value +module Effect = Bap_core_theory_effect +module Program = Bap_core_theory_program +module Label = Program.Label + +type 'a value = 'a Value.t +type 'a effect = 'a Effect.t +type program = Program.cls + + +type 'a pure = 'a value knowledge +type 'a eff = 'a effect knowledge + +type bool = Bool.t pure +type 'a bitv = 'a Bitv.t pure +type ('a,'b) mem = ('a,'b) Mem.t pure +type 'f float = 'f Float.t pure +type rmode = Rmode.t pure + +type data = Effect.Sort.data +type ctrl = Effect.Sort.ctrl + +type ('r,'s) format = ('r,'s) Float.format + +type word = Bitvec.t +type 'a var = 'a Var.t +type label = program Knowledge.Object.t + + + +module type Init = sig + val var : 'a var -> 'a pure + val unk : 'a sort -> 'a pure + val let_ : 'a var -> 'a pure -> 'b pure -> 'b pure +end + +module type Bool = sig + val b0 : bool + val b1 : bool + val inv : bool -> bool + val and_ : bool -> bool -> bool + val or_ : bool -> bool -> bool +end + +module type Bitv = sig + val int : 'a Bitv.t sort -> word -> 'a bitv + val msb : 'a bitv -> bool + val lsb : 'a bitv -> bool + val neg : 'a bitv -> 'a bitv + val not : 'a bitv -> 'a bitv + val add : 'a bitv -> 'a bitv -> 'a bitv + val sub : 'a bitv -> 'a bitv -> 'a bitv + val mul : 'a bitv -> 'a bitv -> 'a bitv + val div : 'a bitv -> 'a bitv -> 'a bitv + val sdiv : 'a bitv -> 'a bitv -> 'a bitv + val modulo : 'a bitv -> 'a bitv -> 'a bitv + val smodulo : 'a bitv -> 'a bitv -> 'a bitv + val logand : 'a bitv -> 'a bitv -> 'a bitv + val logor : 'a bitv -> 
'a bitv -> 'a bitv + val logxor : 'a bitv -> 'a bitv -> 'a bitv + val shiftr : bool -> 'a bitv -> 'b bitv -> 'a bitv + val shiftl : bool -> 'a bitv -> 'b bitv -> 'a bitv + val ite : bool -> 'a pure -> 'a pure -> 'a pure + val sle : 'a bitv -> 'a bitv -> bool + val ule : 'a bitv -> 'a bitv -> bool + val cast : 'a Bitv.t sort -> bool -> 'b bitv -> 'a bitv + val concat : 'a Bitv.t sort -> 'b bitv list -> 'a bitv + val append : 'a Bitv.t sort -> 'b bitv -> 'c bitv -> 'a bitv +end + +module type Memory = sig + val load : ('a,'b) mem -> 'a bitv -> 'b bitv + val store : ('a,'b) mem -> 'a bitv -> 'b bitv -> ('a,'b) mem +end + +module type Effect = sig + val perform : 'a Effect.Sort.t -> 'a eff + val set : 'a var -> 'a pure -> data eff + val jmp : _ bitv -> ctrl eff + val goto : label -> ctrl eff + val seq : 'a eff -> 'a eff -> 'a eff + val blk : label -> data eff -> ctrl eff -> unit eff + val repeat : bool -> data eff -> data eff + val branch : bool -> 'a eff -> 'a eff -> 'a eff +end + +module type Minimal = sig + include Init + include Bool + include Bitv + include Memory + include Effect +end + +module type Basic = sig + include Minimal + val zero : 'a Bitv.t sort -> 'a bitv + val is_zero : 'a bitv -> bool + val non_zero : 'a bitv -> bool + val succ : 'a bitv -> 'a bitv + val pred : 'a bitv -> 'a bitv + val nsucc : 'a bitv -> int -> 'a bitv + val npred : 'a bitv -> int -> 'a bitv + val high : 'a Bitv.t sort -> 'b bitv -> 'a bitv + val low : 'a Bitv.t sort -> 'b bitv -> 'a bitv + val signed : 'a Bitv.t sort -> 'b bitv -> 'a bitv + val unsigned : 'a Bitv.t sort -> 'b bitv -> 'a bitv + val extract : 'a Bitv.t sort -> 'b bitv -> 'b bitv -> _ bitv -> 'a bitv + val loadw : 'c Bitv.t sort -> bool -> ('a, _) mem -> 'a bitv -> 'c bitv + val storew : bool -> ('a, 'b) mem -> 'a bitv -> 'c bitv -> ('a, 'b) mem + val arshift : 'a bitv -> 'b bitv -> 'a bitv + val rshift : 'a bitv -> 'b bitv -> 'a bitv + val lshift : 'a bitv -> 'b bitv -> 'a bitv + val eq : 'a bitv -> 'a bitv -> bool + 
val neq : 'a bitv -> 'a bitv -> bool + val slt : 'a bitv -> 'a bitv -> bool + val ult : 'a bitv -> 'a bitv -> bool + val sgt : 'a bitv -> 'a bitv -> bool + val ugt : 'a bitv -> 'a bitv -> bool + val sge : 'a bitv -> 'a bitv -> bool + val uge : 'a bitv -> 'a bitv -> bool +end + +module type Fbasic = sig + val float : ('r,'s) format Float.t sort -> 's bitv -> ('r,'s) format float + val fbits : ('r,'s) format float -> 's bitv + + + val is_finite : 'f float -> bool + val is_nan : 'f float -> bool + val is_inf : 'f float -> bool + val is_fzero : 'f float -> bool + val is_fpos : 'f float -> bool + val is_fneg : 'f float -> bool + + val rne : rmode + val rna : rmode + val rtp : rmode + val rtn : rmode + val rtz : rmode + val requal : rmode -> rmode -> bool + + val cast_float : 'f Float.t sort -> rmode -> 'a bitv -> 'f float + val cast_sfloat : 'f Float.t sort -> rmode -> 'a bitv -> 'f float + val cast_int : 'a Bitv.t sort -> rmode -> 'f float -> 'a bitv + val cast_sint : 'a Bitv.t sort -> rmode -> 'f float -> 'a bitv + + val fneg : 'f float -> 'f float + val fabs : 'f float -> 'f float + + val fadd : rmode -> 'f float -> 'f float -> 'f float + val fsub : rmode -> 'f float -> 'f float -> 'f float + val fmul : rmode -> 'f float -> 'f float -> 'f float + val fdiv : rmode -> 'f float -> 'f float -> 'f float + val fsqrt : rmode -> 'f float -> 'f float + val fmodulo : rmode -> 'f float -> 'f float -> 'f float + val fmad : rmode -> 'f float -> 'f float -> 'f float -> 'f float + + val fround : rmode -> 'f float -> 'f float + val fconvert : 'f Float.t sort -> rmode -> _ float -> 'f float + + val fsucc : 'f float -> 'f float + val fpred : 'f float -> 'f float + val forder : 'f float -> 'f float -> bool +end + +module type Float = sig + include Fbasic + val pow : rmode -> 'f float -> 'f float -> 'f float + val powr : rmode -> 'f float -> 'f float -> 'f float + val compound : rmode -> 'f float -> 'a bitv -> 'f float + val rootn : rmode -> 'f float -> 'a bitv -> 'f float + val pownn : 
rmode -> 'f float -> 'a bitv -> 'f float + val rsqrt : rmode -> 'f float -> 'f float + val hypot : rmode -> 'f float -> 'f float -> 'f float +end + +module type Trans = sig + val exp : rmode -> 'f float -> 'f float + val expm1 : rmode -> 'f float -> 'f float + val exp2 : rmode -> 'f float -> 'f float + val exp2m1 : rmode -> 'f float -> 'f float + val exp10 : rmode -> 'f float -> 'f float + val exp10m1 : rmode -> 'f float -> 'f float + val log : rmode -> 'f float -> 'f float + val log2 : rmode -> 'f float -> 'f float + val log10 : rmode -> 'f float -> 'f float + val logp1 : rmode -> 'f float -> 'f float + val log2p1 : rmode -> 'f float -> 'f float + val log10p1 : rmode -> 'f float -> 'f float + val sin : rmode -> 'f float -> 'f float + val cos : rmode -> 'f float -> 'f float + val tan : rmode -> 'f float -> 'f float + val sinpi : rmode -> 'f float -> 'f float + val cospi : rmode -> 'f float -> 'f float + val atanpi : rmode -> 'f float -> 'f float + val atan2pi : rmode -> 'f float -> 'f float -> 'f float + val asin : rmode -> 'f float -> 'f float + val acos : rmode -> 'f float -> 'f float + val atan : rmode -> 'f float -> 'f float + val atan2 : rmode -> 'f float -> 'f float -> 'f float + val sinh : rmode -> 'f float -> 'f float + val cosh : rmode -> 'f float -> 'f float + val tanh : rmode -> 'f float -> 'f float + val asinh : rmode -> 'f float -> 'f float + val acosh : rmode -> 'f float -> 'f float + val atanh : rmode -> 'f float -> 'f float +end + +module type Core = sig + include Basic + include Float + include Trans +end diff --git a/lib/bap_core_theory/bap_core_theory_effect.ml b/lib/bap_core_theory/bap_core_theory_effect.ml new file mode 100644 index 000000000..938b18d43 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_effect.ml @@ -0,0 +1,68 @@ +open Core_kernel +open Bap_knowledge + +module KB = Knowledge + +type cls = Effects + +let package = "core-theory" + +module Sort = struct + type effects = Top | Set of Set.M(String).t + type +'a t = effects + 
type data = Data + type ctrl = Ctrl + let single eff = Set (Set.singleton (module String) eff) + let make name = single name + let define name = make name + + let both x y = match x, y with + | Top,_ | _,Top -> Top + | Set x, Set y -> Set (Set.union x y) + + let (&&) = both + + let refine name other : 'a t = make name && other + + let top = Top + let bot = Set (Set.empty (module String)) + + let union xs = List.reduce xs ~f:both |> function + | Some x -> x + | None -> bot + + let join xs ys = union xs && union ys + + let order x y : Knowledge.Order.partial = + match x, y with + | Top,Top -> EQ + | Top,_ -> GT + | _,Top -> LT + | Set x, Set y -> + if + Set.equal x y then EQ else if + Set.is_subset x ~of_:y then LT else if + Set.is_subset y ~of_:x then GT else NC + + let rreg = define "rreg" + let wreg = define "wreg" + let rmem = define "rmem" + let wmem = define "wmem" + let barr = define "barr" + let fall = define "fall" + let jump = define "jump" + let cjmp = define "cjmp" + + let data = make + let ctrl = make +end + +type +'a sort = 'a Sort.t +let cls : (cls, unit) Knowledge.Class.t = + Knowledge.Class.declare ~package "effect" + ~desc:"denotation of a result of effectful computation" + () + +type 'a t = (cls,'a sort) KB.cls KB.value +let empty s = KB.Value.empty (KB.Class.refine cls s) +let sort v = KB.Class.sort (KB.Value.cls v) diff --git a/lib/bap_core_theory/bap_core_theory_effect.mli b/lib/bap_core_theory/bap_core_theory_effect.mli new file mode 100644 index 000000000..8878fd685 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_effect.mli @@ -0,0 +1,44 @@ +open Bap_knowledge + +module KB = Knowledge + +type cls +type +'a sort +type 'a t = (cls,'a sort) KB.cls KB.value + +val cls : (cls,unit) Knowledge.cls + +val empty : 'a sort -> 'a t +val sort : 'a t -> 'a sort + + +module Sort : sig + type data = private Data + type ctrl = private Ctrl + type +'a t = 'a sort + + + val data : string -> data t + val ctrl : string -> ctrl t + val top : unit t + val 
bot : 'a t + + val both : 'a t -> 'a t -> 'a t + val (&&) : 'a t -> 'a t -> 'a t + val union : 'a t list -> 'a t + val join : 'a t list -> 'b t list -> unit t + + val order : 'a t -> 'b t -> Knowledge.Order.partial + + + val rreg : data t + val wreg : data t + val rmem : data t + val wmem : data t + val barr : data t + + + val fall : ctrl t + val jump : ctrl t + val cjmp : ctrl t +end diff --git a/lib/bap_core_theory/bap_core_theory_empty.ml b/lib/bap_core_theory/bap_core_theory_empty.ml new file mode 100644 index 000000000..996af4ab6 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_empty.ml @@ -0,0 +1,175 @@ +open Bap_knowledge +open Bap_core_theory_definition +open Bap_core_theory_value + +open Knowledge.Syntax + +module Value = Knowledge.Value + +let bool = Bool.t +let sort x = x >>| Value.cls >>| KB.Class.sort + +module Core : Core = struct + type 'a t = 'a Knowledge.t + + let empty x = + Knowledge.return @@ + Value.empty (KB.Class.refine cls x) + + let neweff eff = + Knowledge.return @@ + Value.empty (KB.Class.refine Effect.cls eff) + + let data = neweff Effect.Sort.bot + let ctrl = neweff Effect.Sort.bot + let unit = neweff Effect.Sort.bot + + let var v = empty (Var.sort v) + let int s _ = empty s + let unk s = empty s + let b0 = empty bool + let b1 = empty bool + let inv _ = empty bool + let and_ _ _ = empty bool + let or_ _ _ = empty bool + let msb _ = empty bool + let lsb _ = empty bool + let neg x = sort x >>= empty + let not x = sort x >>= empty + let add x _ = sort x >>= empty + let sub x _ = sort x >>= empty + let mul x _ = sort x >>= empty + let div x _ = sort x >>= empty + let sdiv x _ = sort x >>= empty + let modulo x _ = sort x >>= empty + let smodulo x _ = sort x >>= empty + let logand x _ = sort x >>= empty + let logor x _ = sort x >>= empty + let logxor x _ = sort x >>= empty + let shiftr _ x _ = sort x >>= empty + let shiftl _ x _ = sort x >>= empty + let ite _ x _ = sort x >>= empty + let sle _ _ = empty bool + let ule _ _ = empty bool 
+ let cast s _ _ = empty s + let concat s _ = empty s + let append s _ _ = empty s + let load m _ = sort m >>| Mem.vals >>= empty + let store m _ _ = sort m >>= empty + let pass = data + let skip = ctrl + let perform eff = neweff eff + + let set _ _ = data + let let_ _ _ x = sort x >>= empty + let jmp _ = ctrl + let goto _ = ctrl + let seq x _ = x >>| Value.cls >>| Value.empty + let blk _ _ _ = unit + let repeat _ _ = data + let branch _ x _ = x >>| Value.cls >>| Value.empty + let atomic _ = data + let mfence = data + let lfence = data + let sfence = data + + let zero = empty + let is_zero _ = empty bool + let non_zero _ = empty bool + let succ x = sort x >>= empty + let pred x = sort x >>= empty + let nsucc x _ = sort x >>= empty + let npred x _ = sort x >>= empty + let high s _ = empty s + let low s _ = empty s + let signed s _ = empty s + let unsigned s _ = empty s + let extract s _ _ _ = empty s + let loadw s _ _ _ = empty s + let storew _ x _ _ = sort x >>= empty + let arshift x _ = sort x >>= empty + let rshift x _ = sort x >>= empty + let lshift x _ = sort x >>= empty + + let eq _ _ = empty bool + let neq _ _ = empty bool + let slt _ _ = empty bool + let ult _ _ = empty bool + let sgt _ _ = empty bool + let ugt _ _ = empty bool + let sge _ _ = empty bool + let uge _ _ = empty bool + + let rne = empty Rmode.t + let rna = empty Rmode.t + let rtp = empty Rmode.t + let rtn = empty Rmode.t + let rtz = empty Rmode.t + let requal _ _ = empty bool + + let float s _ = empty s + let fbits x = sort x >>| Float.size >>= empty + + let is_finite _ = empty bool + let is_fzero _ = empty bool + let is_fneg _ = empty bool + let is_fpos _ = empty bool + let is_nan _ = empty bool + let is_inf _ = empty bool + let cast_float s _ _ = empty s + let cast_sfloat s _ _ = empty s + let cast_int s _ _ = empty s + let cast_sint s _ _ = empty s + let fneg x = sort x >>= empty + let fabs x = sort x >>= empty + let fadd _ x _ = sort x >>= empty + let fsub _ x _ = sort x >>= empty + let 
fmul _ x _ = sort x >>= empty + let fdiv _ x _ = sort x >>= empty + let fsqrt _ x = sort x >>= empty + let fmodulo _ x _ = sort x >>= empty + let fmad _ x _ _ = sort x >>= empty + let fround _ x = sort x >>= empty + let fconvert s _ _ = empty s + let fsucc x = sort x >>= empty + let fpred x = sort x >>= empty + let forder _ _ = empty bool + + let pow _ x _ = sort x >>= empty + let powr _ x _ = sort x >>= empty + let compound _ x _ = sort x >>= empty + let rootn _ x _ = sort x >>= empty + let pownn _ x _ = sort x >>= empty + let rsqrt _ x = sort x >>= empty + let hypot _ x _ = sort x >>= empty + + let exp _ x = sort x >>= empty + let expm1 _ x = sort x >>= empty + let exp2 _ x = sort x >>= empty + let exp2m1 _ x = sort x >>= empty + let exp10 _ x = sort x >>= empty + let exp10m1 _ x = sort x >>= empty + let log _ x = sort x >>= empty + let log2 _ x = sort x >>= empty + let log10 _ x = sort x >>= empty + let logp1 _ x = sort x >>= empty + let log2p1 _ x = sort x >>= empty + let log10p1 _ x = sort x >>= empty + let sin _ x = sort x >>= empty + let cos _ x = sort x >>= empty + let tan _ x = sort x >>= empty + let sinpi _ x = sort x >>= empty + let cospi _ x = sort x >>= empty + let atanpi _ x = sort x >>= empty + let atan2pi _ x _ = sort x >>= empty + let asin _ x = sort x >>= empty + let acos _ x = sort x >>= empty + let atan _ x = sort x >>= empty + let atan2 _ x _ = sort x >>= empty + let sinh _ x = sort x >>= empty + let cosh _ x = sort x >>= empty + let tanh _ x = sort x >>= empty + let asinh _ x = sort x >>= empty + let acosh _ x = sort x >>= empty + let atanh _ x = sort x >>= empty +end diff --git a/lib/bap_core_theory/bap_core_theory_empty.mli b/lib/bap_core_theory/bap_core_theory_empty.mli new file mode 100644 index 000000000..ee68cacb3 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_empty.mli @@ -0,0 +1,3 @@ +open Bap_core_theory_definition + +module Core : Core diff --git a/lib/bap_core_theory/bap_core_theory_grammar_definition.ml 
b/lib/bap_core_theory/bap_core_theory_grammar_definition.ml new file mode 100644 index 000000000..4a129811a --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_grammar_definition.ml @@ -0,0 +1,207 @@ +open Bap_core_theory_value +type word = Bitvec.t + +module IEEE754 = Bap_core_theory_IEEE754 +type ieee754 = IEEE754.parameters + +module type Bitv = sig + type t + type exp + type rmode + + val error : t + + val unsigned : int -> exp -> t + val signed : int -> exp -> t + val high : int -> exp -> t + val low : int -> exp -> t + val cast : int -> exp -> exp -> t + val extract : int -> exp -> exp -> exp -> t + + val add : exp -> exp -> t + val sub : exp -> exp -> t + val mul : exp -> exp -> t + val div : exp -> exp -> t + val sdiv : exp -> exp -> t + val modulo : exp -> exp -> t + val smodulo : exp -> exp -> t + val lshift : exp -> exp -> t + val rshift : exp -> exp -> t + val arshift : exp -> exp -> t + val logand : exp -> exp -> t + val logor: exp -> exp -> t + val logxor : exp -> exp -> t + + val neg : exp -> t + val not : exp -> t + + val load_word : int -> exp -> exp -> exp -> t + val load : exp -> exp -> t + + + val var : string -> int -> t + val int : word -> int -> t + val unknown : int -> t + val ite : exp -> exp -> exp -> t + + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t + + val append : exp -> exp -> t + val concat : exp list -> t + + val cast_int : int -> rmode -> exp -> t + val cast_sint : int -> rmode -> exp -> t + val fbits : exp -> t +end + +module type Bool = sig + type t + type exp + + val error : t + + val eq : exp -> exp -> t + val neq : exp -> exp -> t + val lt : exp -> exp -> t + val le : exp -> exp -> t + val slt : exp -> exp -> t + val sle : exp -> exp -> t + val var : string -> t + val int : word -> t + val unknown : unit -> t + val ite : exp -> exp -> exp -> t + val let_bit : string -> exp -> exp -> t + val let_reg : 
string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t + + val high : exp -> t + val low : exp -> t + val extract : int -> exp -> t + + val not : exp -> t + val logand : exp -> exp -> t + val logor: exp -> exp -> t + val logxor : exp -> exp -> t + + val is_inf : exp -> t + val is_nan : exp -> t + val is_fzero : exp -> t + val is_fpos : exp -> t + val is_fneg : exp -> t + + val fle : exp -> exp -> t + val flt : exp -> exp -> t + val feq : exp -> exp -> t +end + + +module type Mem = sig + type t + type exp + + val error : t + + (** [store mem key data] *) + val store : exp -> exp -> exp -> t + + + (** [store_word dir mem key data ] *) + val store_word : exp -> exp -> exp -> exp -> t + val var : string -> int -> int -> t + val unknown : int -> int -> t + val ite : exp -> exp -> exp -> t + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t +end + +module type Stmt = sig + type t + type exp + type rmode + type stmt + + val error : t + + val set_mem : string -> int -> int -> exp -> t + val set_reg : string -> int -> exp -> t + val set_bit : string -> exp -> t + val set_ieee754 : string -> ieee754 -> exp -> t + val set_rmode : string -> rmode -> t + + val tmp_mem : string -> exp -> t + val tmp_reg : string -> exp -> t + val tmp_bit : string -> exp -> t + val tmp_float : string -> exp -> t + val tmp_rmode : string -> rmode -> t + + val let_mem : string -> exp -> stmt -> t + val let_reg : string -> exp -> stmt -> t + val let_bit : string -> exp -> stmt -> t + val let_float : string -> exp -> stmt -> t + val let_rmode : string -> rmode -> stmt -> t + + val jmp : exp -> t + val goto : word -> t + val call : string -> t + val special : string -> t + val cpuexn : int -> t + + val while_ : exp -> stmt list -> t + val if_ : exp -> stmt list -> stmt list -> t + + val seq : stmt list -> t +end + +module type 
Float = sig + type t + type exp + type rmode + + val error : t + + val ieee754 : ieee754 -> exp -> t + val ieee754_var : ieee754 -> string -> t + val ieee754_unk : ieee754 -> t + val ieee754_cast : ieee754 -> rmode -> exp -> t + val ieee754_cast_signed : ieee754 -> rmode -> exp -> t + val ieee754_convert : ieee754 -> rmode -> exp -> t + + val ite : exp -> exp -> exp -> t + + val fadd : rmode -> exp -> exp -> t + val fsub : rmode -> exp -> exp -> t + val fmul : rmode -> exp -> exp -> t + val fdiv : rmode -> exp -> exp -> t + val frem : rmode -> exp -> exp -> t + val fmin : exp -> exp -> t + val fmax : exp -> exp -> t + + val fabs : exp -> t + val fneg : exp -> t + val fsqrt : rmode -> exp -> t + val fround : rmode -> exp -> t + + val let_bit : string -> exp -> exp -> t + val let_reg : string -> exp -> exp -> t + val let_mem : string -> exp -> exp -> t + val let_float : string -> exp -> exp -> t +end + +module type Rmode = sig + type t + type exp + + val error : t + + val rne : t + val rtz : t + val rtp : t + val rtn : t + val rna : t +end diff --git a/lib/bap_core_theory/bap_core_theory_manager.ml b/lib/bap_core_theory/bap_core_theory_manager.ml new file mode 100644 index 000000000..2a95a2ae5 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_manager.ml @@ -0,0 +1,332 @@ +open Core_kernel +open Bap_knowledge + +open Bap_core_theory_definition +open Bap_core_theory_value + +open Knowledge.Syntax + +module Value = Knowledge.Value + +let size = Bitv.size +let sort x = x >>| fun v -> KB.Class.sort (Value.cls v) +let effect x = x >>| fun v -> KB.Class.sort (Value.cls v) + +type 'a t = { + name : string; + desc : string; + proc : 'a; +} + +type provider = (module Core) + +let providers : provider t list ref = ref [] +let register ?(desc="") ~name x = + let provider = {name; desc; proc = x} in + providers := !providers @ [provider] + +let bool = Bool.t + + +let ret = Knowledge.return + +let newval s = + Knowledge.return @@ + Value.empty @@ + KB.Class.refine cls s 
+[@@inline] + +let neweff s = + Knowledge.return @@ + Value.empty @@ + KB.Class.refine Effect.cls s +[@@inline] + + +let foreach f init = Knowledge.List.fold !providers ~init ~f + +let lift0 gen join sort f = + gen sort >>= + foreach @@begin fun r {proc} -> + f proc >>| fun r' -> + join r r' + end + +let lift1 gen join x sort f = + x >>= fun x -> + sort !!x >>= gen >>= + foreach @@begin fun r {proc} -> + f proc !!x >>| fun r' -> + join r r' + end + +let lift2 gen join x y sort f = + x >>= fun x -> + y >>= fun y -> + sort !!x !!y >>= gen >>= + foreach @@begin fun r {proc} -> + f proc !!x !!y >>| fun r' -> + join r r' + end + +let lift3 gen join x y z sort f = + x >>= fun x -> + y >>= fun y -> + z >>= fun z -> + sort !!x !!y !!z >>= gen >>= + foreach @@begin fun r {proc} -> + f proc !!x !!y !!z >>| fun r' -> + join r r' + end + +let lift4 gen join x y z a sort f = + x >>= fun x -> + y >>= fun y -> + z >>= fun z -> + a >>= fun a -> + sort !!x !!y !!z !!a >>= gen >>= + foreach @@begin fun r {proc} -> + f proc !!x !!y !!z !!a >>| fun r' -> + join r r' + end + +let val0 sort f = lift0 newval Value.merge sort f +let val1 x sort f = lift1 newval Value.merge x sort f +let val2 x y sort f = lift2 newval Value.merge x y sort f +let val3 x y z sort f = lift3 newval Value.merge x y z sort f +let val4 x y z a sort f = lift4 newval Value.merge x y z a sort f +let eff0 sort f = lift0 neweff Value.merge sort f +let eff1 x sort f = lift1 neweff Value.merge x sort f +let eff2 x y sort f = lift2 neweff Value.merge x y sort f +let eff3 x y z sort f = lift3 neweff Value.merge x y z sort f + +module Theory : Core = struct + type 'a t = 'a Knowledge.t + + let var v = val0 (Var.sort v) @@ fun (module P) -> P.var v + let int s x = val0 s @@ fun (module P) -> P.int s x + let unk s = val0 s @@ fun (module P) -> P.unk s + let b0 = val0 bool @@ fun (module P) -> P.b0 + let b1 = val0 bool @@ fun (module P) -> P.b1 + let inv x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.inv + let and_ x y = 
val2 x y (fun _ _ -> !!bool) @@ fun (module P) -> P.and_ + let or_ x y = val2 x y (fun _ _ -> !!bool) @@ fun (module P) -> P.or_ + let msb x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.msb + let lsb x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.lsb + + let neg x = val1 x sort @@ fun (module P) -> P.neg + let not x = val1 x sort @@ fun (module P) -> P.not + + let uop x f = val1 x sort f + let aop x y f = val2 x y (fun x _ -> sort x) f + let add x y = aop x y @@ fun (module P) -> P.add + let sub x y = aop x y @@ fun (module P) -> P.sub + let mul x y = aop x y @@ fun (module P) -> P.mul + let div x y = aop x y @@ fun (module P) -> P.div + let sdiv x y = aop x y @@ fun (module P) -> P.sdiv + let modulo x y = aop x y @@ fun (module P) -> P.modulo + let smodulo x y = aop x y @@ fun (module P) -> P.smodulo + let logand x y = aop x y @@ fun (module P) -> P.logand + let logor x y = aop x y @@ fun (module P) -> P.logor + let logxor x y = aop x y @@ fun (module P) -> P.logxor + + let shiftr b x y = val3 b x y (fun _ x _ -> sort x) @@ + fun (module P) -> P.shiftr + let shiftl b x y = val3 b x y (fun _ x _ -> sort x) @@ + fun (module P) -> P.shiftl + let ite b x y = val3 b x y (fun _ x _ -> sort x) @@ + fun (module P) -> P.ite + + let lop x y f = val2 x y (fun _ _ -> !!bool) f + let sle x y = lop x y @@ fun (module P) -> P.sle + let ule x y = lop x y @@ fun (module P) -> P.ule + + let cast s x z = val2 x z (fun _ _ -> !!s) @@ + fun (module P) -> P.cast s + + let concat s xs = + Knowledge.List.all xs >>= fun xs -> + let xs = List.map ~f:(!!) 
xs in + newval s >>= + foreach @@begin fun r {proc=(module P)} -> + P.concat s xs >>| fun r' -> + Value.merge r r' + end + + let append s x y = val2 x y (fun _ _ -> !!s) @@ + fun (module P) -> P.append s + + let load m k = val2 m k (fun m _ -> sort m >>| Mem.vals) @@ + fun (module P) -> P.load + + let store m k v = val3 m k v (fun m _ _ -> sort m) @@ + fun (module P) -> P.store + + let perform s = eff0 s @@ fun (module P) -> P.perform s + + let set v x = eff1 x (fun _ -> !!Effect.Sort.bot) @@ fun (module P) -> + P.set v + + let let_ v x b = val2 x b (fun _ x -> sort x) @@ fun (module P) -> + P.let_ v + + let jmp d = eff1 d (fun _ -> !!Effect.Sort.bot) @@ fun (module P) -> + P.jmp + + let goto d = eff0 Effect.Sort.bot @@ fun (module P) -> P.goto d + + let seq x y = eff2 x y (fun x _ -> effect x) @@ fun (module P) -> + P.seq + let blk l x y = eff2 x y (fun _ _ -> !!Effect.Sort.bot) @@ fun (module P) -> + P.blk l + + let repeat b x = eff2 b x (fun _ _ -> !!Effect.Sort.bot) @@ fun (module P) -> + P.repeat + + let branch b x y = eff3 b x y (fun _ x _ -> effect x) @@ fun (module P) -> + P.branch + + + (* Provider *) + let zero s = val0 s @@ fun (module P) -> P.zero s + let is_zero x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_zero + let non_zero x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.non_zero + let succ x = val1 x sort @@ fun (module P) -> P.succ + let pred x = val1 x sort @@ fun (module P) -> P.pred + let nsucc x n = val1 x sort @@ fun (module P) x -> + P.nsucc x n + let npred x n = val1 x sort @@ fun (module P) x -> + P.npred x n + + let high s x = val1 x (fun _ -> !!s) @@ fun (module P) -> + P.high s + let low s x = val1 x (fun _ -> !!s) @@ fun (module P) -> + P.low s + let signed s x = val1 x (fun _ -> !!s) @@ fun (module P) -> + P.signed s + let unsigned s x = val1 x (fun _ -> !!s) @@ fun (module P) -> + P.unsigned s + + let extract s x y z = val3 x y z (fun _ _ _ -> !!s) @@ fun (module P) -> + P.extract s + + let loadw s d m k = val3 d m k 
(fun _ _ _ -> !!s) @@ fun (module P) -> + P.loadw s + + let storew d m k x = val4 d m k x (fun _ m _ _ -> sort m) @@ fun (module P) -> + P.storew + + let arshift x y = val2 x y (fun x _ -> sort x) @@ fun (module P) -> + P.arshift + let rshift x y = val2 x y (fun x _ -> sort x) @@ fun (module P) -> + P.rshift + let lshift x y = val2 x y (fun x _ -> sort x) @@ fun (module P) -> + P.lshift + + let eq x y = lop x y @@ fun (module P) -> P.eq + let neq x y = lop x y @@ fun (module P) -> P.neq + let slt x y = lop x y @@ fun (module P) -> P.slt + let ult x y = lop x y @@ fun (module P) -> P.ult + let sgt x y = lop x y @@ fun (module P) -> P.sgt + let ugt x y = lop x y @@ fun (module P) -> P.ugt + let sge x y = lop x y @@ fun (module P) -> P.sge + let uge x y = lop x y @@ fun (module P) -> P.uge + + (* Rounding modes: each constant is merged over all providers and must be forwarded to the provider's operation of the same name. *) + let rne = val0 Rmode.t @@ fun (module P) -> P.rne + let rna = val0 Rmode.t @@ fun (module P) -> P.rna + let rtp = val0 Rmode.t @@ fun (module P) -> P.rtp + let rtn = val0 Rmode.t @@ fun (module P) -> P.rtn + let rtz = val0 Rmode.t @@ fun (module P) -> P.rtz + let requal x y = val2 x y (fun _ _ -> !!bool) @@ fun (module P) -> + P.requal + + let float s x = val1 x (fun _ -> !!s) @@ fun (module P) -> P.float s + let fbits x = val1 x (fun x -> sort x >>| Float.size) @@ fun (module P) -> + P.fbits + + let is_finite x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_finite + let is_fzero x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_fzero + let is_fneg x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_fneg + let is_fpos x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_fpos + let is_nan x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_nan + let is_inf x = val1 x (fun _ -> !!bool) @@ fun (module P) -> P.is_inf + + let cast_float s m x = val2 m x (fun _ _ -> !!s) @@ fun (module P) -> + P.cast_float s + + let cast_sfloat s m x = val2 m x (fun _ _ -> !!s) @@ fun (module P) -> + P.cast_sfloat s + + let cast_int s m x = val2 m x (fun _ _ -> !!s) @@ fun (module
P) -> + P.cast_int s + + let cast_sint s m x = val2 m x (fun _ _ -> !!s) @@ fun (module P) -> + P.cast_sint s + + let fneg x = uop x @@ fun (module P) -> P.fneg + let fabs x = uop x @@ fun (module P) -> P.fabs + + let faop m x y f = val3 m x y (fun _ x _ -> sort x) f + let fadd m x y = faop m x y @@ fun (module P) -> P.fadd + let fsub m x y = faop m x y @@ fun (module P) -> P.fsub + let fmul m x y = faop m x y @@ fun (module P) -> P.fmul + let fdiv m x y = faop m x y @@ fun (module P) -> P.fdiv + let fmodulo m x y = faop m x y @@ fun (module P) -> P.fmodulo + + let fmad m x y z = val4 m x y z (fun _ x _ _ -> sort x) @@ fun (module P) -> + P.fmad + + let fround m x = val2 m x (fun _ x -> sort x) @@ fun (module P) -> + P.fround + let fconvert s x y = val2 x y (fun _ _ -> !!s) @@ fun (module P) -> + P.fconvert s + + let fsucc x = uop x @@ fun (module P) -> P.fsucc + let fpred x = uop x @@ fun (module P) -> P.fpred + + let forder x y = val2 x y (fun _ _ -> !!bool) @@ fun (module P) -> P.forder + + let pow m x y = faop m x y @@ fun (module P) -> P.pow + let powr m x y = faop m x y @@ fun (module P) -> P.powr + + + let compound m x y = faop m x y @@ fun (module P) -> P.compound + let rootn m x y = faop m x y @@ fun (module P) -> P.rootn + let pownn m x y = faop m x y @@ fun (module P) -> P.pownn + + let fuop m x f = val2 m x (fun _ x -> sort x) f + let fsqrt m x = fuop m x @@ fun (module P) -> P.fsqrt + let rsqrt m x = fuop m x @@ fun (module P) -> P.rsqrt + let hypot m x y = faop m x y @@ fun (module P) -> P.hypot + + let exp m x = fuop m x @@ fun (module P) -> P.exp + let expm1 m x = fuop m x @@ fun (module P) -> P.expm1 + let exp2 m x = fuop m x @@ fun (module P) -> P.exp2 + let exp2m1 m x = fuop m x @@ fun (module P) -> P.exp2m1 + let exp10 m x = fuop m x @@ fun (module P) -> P.exp10 + let exp10m1 m x = fuop m x @@ fun (module P) -> P.exp10m1 + let log m x = fuop m x @@ fun (module P) -> P.log + let log2 m x = fuop m x @@ fun (module P) -> P.log2 + let log10 m x = 
fuop m x @@ fun (module P) -> P.log10 + let logp1 m x = fuop m x @@ fun (module P) -> P.logp1 + let log2p1 m x = fuop m x @@ fun (module P) -> P.log2p1 + let log10p1 m x = fuop m x @@ fun (module P) -> P.log10p1 + let sin m x = fuop m x @@ fun (module P) -> P.sin + let cos m x = fuop m x @@ fun (module P) -> P.cos + let tan m x = fuop m x @@ fun (module P) -> P.tan + let sinpi m x = fuop m x @@ fun (module P) -> P.sinpi + let cospi m x = fuop m x @@ fun (module P) -> P.cospi + let atanpi m x = fuop m x @@ fun (module P) -> P.atanpi + let atan2pi m x y = faop m x y @@ fun (module P) -> P.atan2pi + let asin m x = fuop m x @@ fun (module P) -> P.asin + let acos m x = fuop m x @@ fun (module P) -> P.acos + let atan m x = fuop m x @@ fun (module P) -> P.atan + let atan2 m x y = faop m x y @@ fun (module P) -> P.atan2 + let sinh m x = fuop m x @@ fun (module P) -> P.sinh + let cosh m x = fuop m x @@ fun (module P) -> P.cosh + let tanh m x = fuop m x @@ fun (module P) -> P.tanh + let asinh m x = fuop m x @@ fun (module P) -> P.asinh + let acosh m x = fuop m x @@ fun (module P) -> P.acosh + let atanh m x = fuop m x @@ fun (module P) -> P.atanh +end diff --git a/lib/bap_core_theory/bap_core_theory_manager.mli b/lib/bap_core_theory/bap_core_theory_manager.mli new file mode 100644 index 000000000..4fbe80eeb --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_manager.mli @@ -0,0 +1,5 @@ +open Bap_core_theory_definition + +module Theory : Core + +val register : ?desc:string -> name:string -> (module Core) -> unit diff --git a/lib/bap_core_theory/bap_core_theory_parser.ml b/lib/bap_core_theory/bap_core_theory_parser.ml new file mode 100644 index 000000000..876198ede --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_parser.ml @@ -0,0 +1,576 @@ +open Core_kernel +open Bap_knowledge +open Bap_core_theory_definition +open Bap_core_theory_value + +module Value = Knowledge.Value +module Grammar = Bap_core_theory_grammar_definition +module Program = Bap_core_theory_program 
+module Label = Program.Label +module IEEE754 = Bap_core_theory_IEEE754 + +open Knowledge.Syntax + +type ('a,'e,'r) bitv_parser = + (module Grammar.Bitv with type t = 'a + and type exp = 'e + and type rmode = 'r) -> + 'e -> 'a + +type ('a,'e,'r) bool_parser = + (module Grammar.Bool with type t = 'a + and type exp = 'e) -> + 'e -> 'a + +type ('a,'e) mem_parser = + (module Grammar.Mem with type t = 'a + and type exp = 'e) -> + 'e -> 'a + +type ('a,'e,'r,'s) stmt_parser = + (module Grammar.Stmt with type t = 'a + and type exp = 'e + and type stmt = 's + and type rmode = 'r) -> + 's -> 'a + +type ('a,'e,'r) float_parser = + (module Grammar.Float with type t = 'a + and type exp = 'e + and type rmode = 'r) -> + 'e -> 'a + +type ('a,'e) rmode_parser = + (module Grammar.Rmode with type t = 'a + and type exp = 'e) -> + 'e -> 'a + +type ('e,'r,'s) t = { + bitv : 'a. ('a,'e,'r) bitv_parser; + bool : 'a. ('a,'e,'r) bool_parser; + mem : 'a. ('a,'e) mem_parser; + stmt : 'a. ('a,'e,'r,'s) stmt_parser; + float : 'a . ('a,'e,'r) float_parser; + rmode : 'a . ('a,'r) rmode_parser; +} + +type ('e,'r,'s) parser = ('e,'r,'s) t + +let bits = Bitv.define +let bool = Bool.t + +type Knowledge.conflict += Error + +let (>>->) x f = + x >>= fun x -> f (KB.Class.sort (Value.cls x)) x + +module Make(S : Core) = struct + open S + open Knowledge.Syntax + + type 'a t = 'a knowledge + + let of_word w s = int (bits s) w + let of_int s x = + let m = Bitvec.modulus (Bitv.size s) in + int s Bitvec.(int x mod m) + let join s1 s2 = bits (Bitv.size s1 + Bitv.size s2) + + let mkvar sort name = + Var.create sort (Var.Ident.of_string name) + + type context = (string * Var.ident) list + let rename (ctxt : context) v = + match List.Assoc.find ~equal:String.equal ctxt v with + | None -> Var.Ident.of_string v + | Some r -> r + + let pass = perform Effect.Sort.bot + let skip = perform Effect.Sort.bot + let newlabel = Knowledge.Object.create Program.cls + + let rec expw : type s b e r. 
+ context -> + (e,r,b) parser -> e -> s bitv = + fun ctxt self -> self.bitv (module struct + type nonrec t = s bitv + type exp = e + type rmode = r + + let run = expw + let expw s = run ctxt self s + let expm s = expm ctxt self s + let expb s = expb ctxt self s + let expr s = expr ctxt self s + let expf s = expf ctxt self s + + let error = Knowledge.fail Error + + let load_word sz dir mem key = + loadw (bits sz) (expb dir) (expm mem) (expw key) + + let load mem key = load (expm mem) (expw key) + + let add x y = add (expw x) (expw y) + let sub x y = sub (expw x) (expw y) + let mul x y = mul (expw x) (expw y) + let div x y = div (expw x) (expw y) + let sdiv x y = sdiv (expw x) (expw y) + let modulo x y = modulo (expw x) (expw y) + let smodulo x y = smodulo (expw x) (expw y) + let lshift x y = lshift (expw x) (expw y) + let rshift x y = rshift (expw x) (expw y) + let arshift x y = arshift (expw x) (expw y) + let logand x y = logand (expw x) (expw y) + let logor x y = logor (expw x) (expw y) + let logxor x y = logxor (expw x) (expw y) + let var n sz = var (Var.create (bits sz) (rename ctxt n)) + + let int x s = of_word x s + let ite c x y = ite (expb c) (expw x) (expw y) + let signed w x = signed (bits w) (expw x) + let unsigned w x = unsigned (bits w) (expw x) + let high w x = high (bits w) (expw x) + let low w x = low (bits w) (expw x) + let append x y = + let x = expw x and y = expw y in + x >>-> fun sx x -> + y >>-> fun sy y -> + append (join sx sy) !!x !!y + + let cast rs bit x = + cast (bits rs) (expb bit) (expw x) + + let concat xs = + Knowledge.List.fold ~init:([],0) xs ~f:(fun (xs,s) x -> + expw x >>| fun x -> + !!x::xs, s + Bitv.size (KB.Class.sort (Value.cls x))) + >>= fun (xs,sz) -> + concat (bits sz) (List.rev xs) + + let let_bit v x y = + expb x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_reg v x y = + expw x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) 
self y) + + let let_mem v x y = + expm x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_float v x y = + expf x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let unknown w = unk (bits w) + let extract sz hi lo x = + extract (bits sz) (expw hi) (expw lo) (expw x) + let neg x = neg (expw x) + let not x = not (expw x) + + let cast_int s m w = cast_int (bits s) (expr m) (expf w) + let cast_sint s m w = cast_sint (bits s) (expr m) (expf w) + + let fbits x = fbits (expf x) + end) + and expm : type k x b e r. + context -> + (e,r,b) parser -> e -> (k,x) mem = + fun ctxt self -> self.mem (module struct + open Knowledge.Syntax + type nonrec t = (k, x) mem + type exp = e + + let run = expm + let expw s = expw ctxt self s + let expm s = expm ctxt self s + let expb s = expb ctxt self s + let expf s = expf ctxt self s + + let error = Knowledge.fail Error + + let store m k x = store (expm m) (expw k) (expw x) + let store_word d m k x = + storew (expb d) (expm m) (expw k) (expw x) + let var v ks vs = + let s = Mem.define (bits ks) (bits vs) in + var (Var.create s (rename ctxt v)) + let ite c x y = ite (expb c) (expm x) (expm y) + + let let_bit v x y = + expb x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_reg v x y = + expw x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_mem v x y = + expm x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_float v x y = (* the bound expression [x] is a float, so it is parsed with [expf]; only the body [y] is a memory expression *) + expf x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let unknown ks vs = unk (Mem.define (bits ks) (bits vs)) + end) + + and expb : type s b e r.
+ context -> + (e,r,b) parser -> e -> bool = + fun ctxt self -> self.bool (module struct + open Knowledge.Syntax + type nonrec t = bool + type exp = e + + let run = expb + let expw s = expw ctxt self s + let expm s = expm ctxt self s + let expf s = expf ctxt self s + let expb s = run ctxt self s + + let error = Knowledge.fail Error + + let var v = var (Var.create bool (rename ctxt v)) + let ite c x y = ite (expb c) (expb x) (expb y) + let le x y = ule (expw x) (expw y) + let sle x y = sle (expw x) (expw y) + let lt x y = ult (expw x) (expw y) + let slt x y = slt (expw x) (expw y) + let logor x y = or_ (expb x) (expb y) + let logand x y = and_ (expb x) (expb y) + let logxor x y = + let x = expb x and y = expb y in + and_ (or_ x y) (inv (and_ x y)) + let eq x y = eq (expw x) (expw y) + let neq x y = neq (expw x) (expw y) + + let let_bit v x y = + expb x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_reg v x y = + expw x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_mem v x y = + expm x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_float v x y = + expf x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let not x = inv (expb x) + let unknown _ = (unk bool) + + let int x = if Bitvec.(x = zero) then b0 else b1 + + let high x = lsb (high (bits 1) (expw x)) + let low x = lsb (low (bits 1) (expw x)) + + let extract n x = + expw x >>-> fun xs x -> + lsb (extract (bits 1) (of_int xs n) (of_int xs n) !!x) + + let fless = forder + let feq x y = and_ (inv (fless x y)) (inv (fless y x)) + + let fle x y = + let x = expf x and y = expf y in + or_ (fless x y) (feq x y) + + let flt x y = fless (expf x) (expf y) + let feq x y = feq (expf x) (expf y) + + let is_fneg x = is_fneg (expf x) + let is_fpos x = is_fpos (expf x) + let is_fzero x = is_fzero (expf x) + let 
is_nan x = is_nan (expf x) + let is_inf x = is_inf (expf x) + end) + and expf : type s b e r k n i g a. + context -> + (e,r,b) parser -> e -> ((i, g, a) IEEE754.t, s) format float = + fun ctxt self -> self.float (module struct + type nonrec t = ((i, g, a) IEEE754.t, s) format float + type exp = e + type rmode = r + + let run = expf + let expw s = expw ctxt self s + let expm s = expm ctxt self s + let expr s = expr ctxt self s + let expb s = expb ctxt self s + let expf s = run ctxt self s + + let error = Knowledge.fail Error + + let floats s = IEEE754.Sort.define s + let ieee754 s x : t = float (floats s) (expw x) + let ieee754_var s name : t = var (Var.create (floats s) (rename ctxt name)) (* resolve the name through [ctxt], as the bitv, bool and mem parsers do, so that let-bound floats refer to their scoped variable *) + let ieee754_unk s = unk (floats s) + + let fadd m x y = fadd (expr m) (expf x) (expf y) + let fsub m x y = fsub (expr m) (expf x) (expf y) + let fmul m x y = fmul (expr m) (expf x) (expf y) + let fdiv m x y = fdiv (expr m) (expf x) (expf y) + let frem m x y = fmodulo (expr m) (expf x) (expf y) + let fmin x y = + let x = expf x and y = expf y in + ite (forder x y) x y + let fmax x y = + let x = expf x and y = expf y in + ite (forder x y) y x + + let ite c x y = ite (expb c) (expf x) (expf y) + let fabs x = fabs (expf x) + let fneg x = fneg (expf x) + let fsqrt m x = fsqrt (expr m) (expf x) + let fround m x = fround (expr m) (expf x) + + let ieee754_cast s m x = + cast_float (floats s) (expr m) (expw x) + + let ieee754_cast_signed s m x = + cast_sfloat (floats s) (expr m) (expw x) + + let ieee754_convert s m x = + fconvert (floats s) (expr m) (expf x) + + let let_bit v x y = + expb x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_reg v x y = + expw x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_mem v x y = + expm x >>-> fun s x -> + Var.scoped s @@ fun r -> + let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + + let let_float v x y = + expf x >>-> fun s x -> + Var.scoped s @@ fun r ->
+ let_ r !!x (run ((v,Var.ident r)::ctxt) self y) + end) + and expr : type b e r. + context -> + (e,r,b) parser -> r -> rmode = + fun _ctxt self -> self.rmode (module struct + type nonrec t = rmode + type exp = r + let error = Knowledge.fail Error + let rne = rne + let rtz = rtz + let rtp = rtp + let rtn = rtn + let rtz = rtz + let rna = rna + end) + + let rec run : type e s r. (e,r,s) parser -> s list -> unit eff = + fun parser code -> bil [] parser code + + and bil : type e s r. context -> (e,r,s) parser -> s list -> unit eff = + fun ctxt parser xs -> stmts ctxt parser xs + + and stmts : type e s r. + context -> + (e,r,s) parser -> s list -> unit eff = fun ctxt self -> function + | [] -> newlabel >>= fun lbl -> blk lbl pass skip + | x :: xs -> + self.stmt (module struct + type nonrec t = unit eff + type exp = e + type stmt = s + type rmode = r + + let next = stmts ctxt self + + let bind exp body = + exp >>-> fun s exp -> + Var.fresh s >>= fun v -> + newlabel >>= fun lbl -> + let b1 = blk lbl (set v !!exp) skip in + seq b1 (body v) + + let error = Knowledge.fail Error + + let special _ = + newlabel >>= fun lbl -> + seq (blk lbl pass skip) (next xs) + + let cpuexn n = + Label.for_ivec n >>= fun dst -> + newlabel >>= fun lbl -> + seq (blk lbl pass (goto dst)) (next xs) + + let while_ cnd ys = + newlabel >>= fun lbl -> + seq + (blk lbl (repeat (expb ctxt self cnd) (stmtd ctxt self ys)) skip) + (next xs) + + let if_ cnd yes nay = + seq + (branch (expb ctxt self cnd) + (bil ctxt self yes) + (bil ctxt self nay)) + (next xs) + + let jmp exp = + newlabel >>= fun lbl -> + seq (blk lbl pass (jmp (expw ctxt self exp))) (next xs) + + let call name = + newlabel >>= fun lbl -> + Label.for_name name >>= fun dst -> + seq (blk lbl pass (goto dst)) (next xs) + + let goto addr = + newlabel >>= fun lbl -> + Label.for_addr addr >>= fun dst -> + seq (blk lbl pass (goto dst)) (next xs) + + let move eff = + newlabel >>= fun lbl -> + seq (blk lbl eff skip) (next xs) + let set_bit var exp 
= move (set_bit ctxt self var exp) + let set_reg var sz exp = move (set_reg ctxt self var sz exp) + let set_mem var ks vs exp = move (set_mem ctxt self var ks vs exp) + let set_ieee754 var s exp = move (set_ieee754 ctxt self var s exp) + let set_rmode var exp = move (set_rmode ctxt self var exp) + let push var r = stmts ((var, Var.ident r) :: ctxt) self xs + let tmp_bit var exp = bind (expb ctxt self exp) (push var) + let tmp_reg var exp = bind (expw ctxt self exp) (push var) + let tmp_mem var exp = bind (expm ctxt self exp) (push var) + let tmp_float var exp = bind (expf ctxt self exp) (push var) + let tmp_rmode var exp = bind (expr ctxt self exp) (push var) + let let_gen t var exp body = + seq (bind (t ctxt self exp) + (fun r -> stmts ((var, Var.ident r) :: ctxt) self [body])) + (next xs) + + let let_bit = let_gen expb + let let_reg = let_gen expw + let let_mem = let_gen expm + let let_float = let_gen expf + let let_rmode = let_gen expr + + let seq ys = seq (next ys) (next xs) + end) x + + and set_bit : type e s r. + context -> + (e,r,s) parser -> string -> e -> data eff = + fun ctxt self v x -> set (mkvar bool v) (expb ctxt self x) + + and set_reg : type e s r. + context -> + (e,r,s) parser -> string -> int -> e -> data eff = + fun ctxt self v s x -> + set (mkvar (bits s) v) (expw ctxt self x) + + and set_mem : type e s r. + context -> + (e,r,s) parser -> string -> int -> int -> e -> data eff = + fun ctxt self v ks vs x -> + set (mkvar (Mem.define (bits ks) (bits vs)) v) (expm ctxt self x) + + and set_ieee754 : type e s r. + context -> + (e,r,s) parser -> string -> IEEE754.parameters -> e -> data eff = + fun ctxt self v fs x -> set (mkvar (IEEE754.Sort.define fs) v) (expf ctxt self x) + + and set_rmode : type e s r. + context -> + (e,r,s) parser -> string -> r -> data eff = + fun ctxt self v x -> set (mkvar Rmode.t v) (expr ctxt self x) + + and stmtd : type e s r. 
+ context -> + (e,r,s) parser -> s list -> data eff = fun ctxt self -> function + | [] -> pass + | x :: xs -> + self.stmt (module struct + type nonrec t = data eff + type exp = e + type stmt = s + type rmode = r + + let next = stmtd ctxt self + + let bind exp body = + exp >>-> fun s exp -> + Var.fresh s >>= fun v -> + seq (set v !!exp) (body v) + + let error = Knowledge.fail Error + + let special _ = seq pass (next xs) + let cpuexn _ = assert false + let while_ cnd ys = + seq + (repeat (expb ctxt self cnd) (next ys)) + (next xs) + + let if_ cnd yes nay = + seq + (branch (expb ctxt self cnd) + (stmtd ctxt self yes) + (stmtd ctxt self nay)) + (next xs) + + + let jmp _ = assert false + let goto _ = assert false + let call _ = assert false + + let move eff = seq eff (next xs) + let set_bit var exp = move (set_bit ctxt self var exp) + let set_reg var sz exp = move (set_reg ctxt self var sz exp) + let set_mem var ks vs exp = move (set_mem ctxt self var ks vs exp) + let set_ieee754 var s exp = move (set_ieee754 ctxt self var s exp) + let set_rmode var exp = move (set_rmode ctxt self var exp) + + let push var r = stmtd ((var, Var.ident r) :: ctxt) self xs + let tmp_bit var exp = bind (expb ctxt self exp) (push var) + let tmp_reg var exp = bind (expw ctxt self exp) (push var) + let tmp_mem var exp = bind (expm ctxt self exp) (push var) + let tmp_float var exp = bind (expf ctxt self exp) (push var) + let tmp_rmode var exp = bind (expr ctxt self exp) (push var) + + let let_gen t var exp body = + seq (bind (t ctxt self exp) + (fun r -> stmtd ((var, Var.ident r) :: ctxt) self [body])) + (next xs) + + let let_bit = let_gen expb + let let_reg = let_gen expw + let let_mem = let_gen expm + let let_float = let_gen expf + let let_rmode = let_gen expr + + let seq ys = seq (next ys) (next xs) + end) x +end diff --git a/lib/bap_core_theory/bap_core_theory_parser.mli b/lib/bap_core_theory/bap_core_theory_parser.mli new file mode 100644 index 000000000..934d8ccc1 --- /dev/null +++ 
b/lib/bap_core_theory/bap_core_theory_parser.mli @@ -0,0 +1,52 @@ +open Bap_knowledge +open Bap_core_theory_definition +module Grammar = Bap_core_theory_grammar_definition + +type ('a,'e,'r) bitv_parser = + (module Grammar.Bitv with type t = 'a + and type exp = 'e + and type rmode = 'r) -> + 'e -> 'a + +type ('a,'e,'r) bool_parser = + (module Grammar.Bool with type t = 'a + and type exp = 'e) -> + 'e -> 'a + +type ('a,'e) mem_parser = + (module Grammar.Mem with type t = 'a + and type exp = 'e) -> + 'e -> 'a + +type ('a,'e,'r,'s) stmt_parser = + (module Grammar.Stmt with type t = 'a + and type exp = 'e + and type stmt = 's + and type rmode = 'r) -> + 's -> 'a + +type ('a,'e,'r) float_parser = + (module Grammar.Float with type t = 'a + and type exp = 'e + and type rmode = 'r) -> + 'e -> 'a + +type ('a,'e) rmode_parser = + (module Grammar.Rmode with type t = 'a + and type exp = 'e) -> + 'e -> 'a + +type ('e,'r,'s) t = { + bitv : 'a. ('a,'e,'r) bitv_parser; + bool : 'a. ('a,'e,'r) bool_parser; + mem : 'a. ('a,'e) mem_parser; + stmt : 'a. ('a,'e,'r,'s) stmt_parser; + float : 'a . ('a,'e,'r) float_parser; + rmode : 'a . 
('a,'r) rmode_parser; +} + +type ('e,'r,'s) parser = ('e,'r,'s) t + +module Make(S : Core) : sig + val run : ('e,'r,'s) parser -> 's list -> unit eff +end diff --git a/lib/bap_core_theory/bap_core_theory_program.ml b/lib/bap_core_theory/bap_core_theory_program.ml new file mode 100644 index 000000000..434f2645c --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_program.ml @@ -0,0 +1,96 @@ +open Core_kernel +open Bitvec_order.Comparators +open Bap_knowledge + +let package = "core-theory" +module Effect = Bap_core_theory_effect + +type cls = Program +type program = cls +let (cls : (cls,unit) Knowledge.cls) = Knowledge.Class.declare ~package "program" () +let program = cls + +module Label = struct + let word = Knowledge.Domain.optional "word" + ~equal:Bitvec.equal + ~inspect:Bitvec_sexp.sexp_of_t + + let name = Knowledge.Domain.optional "name" + ~equal:String.equal + ~inspect:sexp_of_string + + let names = Knowledge.Domain.powerset (module String) "names" + ~inspect:sexp_of_string + + + let int = Knowledge.Domain.optional "ivec" + ~equal:Int.equal + ~inspect:sexp_of_int + + let attr name = + let bool_t = Knowledge.Domain.optional + ~inspect:sexp_of_bool ~equal:Bool.equal "bool" in + Knowledge.Class.property ~package cls name bool_t + ~persistent:(Knowledge.Persistent.of_binable (module struct + type t = bool option [@@deriving bin_io] + end)) + + + let is_valid = attr "is-valid" + let is_subroutine = attr "is-subroutine" + + + let addr = Knowledge.Class.property ~package cls "label-addr" word + ~persistent:(Knowledge.Persistent.of_binable (module struct + type t = Bitvec_binprot.t option + [@@deriving bin_io] + end)) + + let name = + Knowledge.Class.property ~package cls "label-name" name + ~persistent:(Knowledge.Persistent.of_binable (module struct + type t = string option [@@deriving bin_io] + end)) + + let ivec = + Knowledge.Class.property ~package cls "label-ivec" int + ~persistent:(Knowledge.Persistent.of_binable (module struct + type t = int option 
[@@deriving bin_io] + end)) + + let aliases = + Knowledge.Class.property ~package cls "label-aliases" names + ~persistent:(Knowledge.Persistent.of_binable (module struct + type t = String.Set.t [@@deriving bin_io] + end)) + + + open Knowledge.Syntax + + let for_name s = + Knowledge.Symbol.intern ~package s cls >>= fun obj -> + Knowledge.provide name obj (Some s) >>| fun () -> obj + + let for_addr x = + let s = Bitvec.to_string x in + Knowledge.Symbol.intern ~package s cls >>= fun obj -> + Knowledge.provide addr obj (Some x) >>| fun () -> obj + + let for_ivec x = + let s = sprintf "int-%d" x in + Knowledge.Symbol.intern ~package:"label" s cls >>= fun obj -> + Knowledge.provide ivec obj (Some x) >>| fun () -> obj + + include (val Knowledge.Object.derive cls) +end + +module Semantics = struct + type cls = Effect.cls + let cls = Knowledge.Class.refine Effect.cls Effect.Sort.top + module Self = (val Knowledge.Value.derive cls) + let slot = Knowledge.Class.property ~package program "semantics" Self.domain + ~persistent:(Knowledge.Persistent.of_binable (module Self)) + include Self +end + +include (val Knowledge.Value.derive cls) diff --git a/lib/bap_core_theory/bap_core_theory_program.mli b/lib/bap_core_theory/bap_core_theory_program.mli new file mode 100644 index 000000000..2772f5396 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_program.mli @@ -0,0 +1,33 @@ +open Core_kernel +open Bap_knowledge + +module Effect = Bap_core_theory_effect + +type cls +type program = cls +type t = (program,unit) Knowledge.cls Knowledge.value +val cls : (program,unit) Knowledge.cls +module Semantics : sig + type cls = Effect.cls + type t = unit Effect.t + val cls : (cls, unit Effect.sort) Knowledge.cls + val slot : (program, t) Knowledge.slot + include Knowledge.Value.S with type t := t +end + +include Knowledge.Value.S with type t := t + +module Label : sig + type t = program Knowledge.obj + val addr : (program, Bitvec.t option) Knowledge.slot + val name : (program, string option) 
Knowledge.slot + val ivec : (program, Int.t option) Knowledge.slot + val aliases : (program, Set.M(String).t) Knowledge.slot + val is_valid : (program, bool option) Knowledge.slot + val is_subroutine : (program, bool option) Knowledge.slot + val for_addr : Bitvec.t -> t knowledge + val for_name : string -> t knowledge + val for_ivec : int -> t knowledge + + include Knowledge.Object.S with type t := t +end diff --git a/lib/bap_core_theory/bap_core_theory_value.ml b/lib/bap_core_theory/bap_core_theory_value.ml new file mode 100644 index 000000000..4b5bfc1e7 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_value.ml @@ -0,0 +1,275 @@ +open Core_kernel +open Caml.Format + +open Bap_knowledge +module KB = Knowledge + +let package = "core-theory" + + +module Sort : sig + type +'a exp + type +'a sym + type +'a num + type +'a t = 'a exp + type cls + + type top = unit t + type name + + val cls : (cls,unit) KB.cls + + val sym : name -> 'a sym exp + val int : int -> 'a num exp + val app : 'a exp -> 'b exp -> ('a -> 'b) exp + val (@->) : 'a exp -> 'b exp -> ('a -> 'b) exp + + val value : 'a num exp -> int + val name : 'a sym exp -> name + + val hd : ('a -> 'b) exp -> 'a exp + val tl : ('a -> 'b) exp -> 'b exp + + val pp : formatter -> 'a t -> unit + val forget : 'a t -> top + val refine : name -> top -> 'a t option + val same : 'a t -> 'b t -> bool + + module Top : sig + type t = top [@@deriving bin_io, compare, sexp] + include Base.Comparable.S with type t := t + end + + module Name : sig + type t + val declare : ?package:string -> string -> name + include Base.Comparable.S with type t := t + end +end += struct + type +'a sym + type +'a num + type cls = Values + + type name = { + package : string; + name : string; + } [@@deriving bin_io, compare, sexp] + + type names = { + unique : Hash_set.M(String).t; + packages : Set.M(String).t Hashtbl.M(String).t; + } + + let registry = { + unique = Hash_set.create (module String) (); + packages = Hashtbl.create (module String); + 
} + + let cls = KB.Class.declare ~package:"core-theory" + "value" () + + module Name = struct + type t = name [@@deriving bin_io, compare, sexp] + let declare ?(package="user") name = (* register [name] under [package]; declaring the same (package, name) pair twice fails *) + Hashtbl.update registry.packages package ~f:(function + | None -> Set.singleton (module String) name + | Some names -> + if Set.mem names name + then failwithf "Type name `%s' is already defined \ + for package `%s'. Please, pick a unique \ + name or a different package." name package (); + Set.add names name); (* a short name belongs to [unique] only while exactly one package declares it; recounting instead of toggling keeps a name declared by three or more packages out of the set *) + if List.count (Hashtbl.data registry.packages) ~f:(fun names -> Set.mem names name) = 1 + then Hash_set.add registry.unique name + else Hash_set.remove registry.unique name; + {package; name} + + let to_string {package; name} = (* prints the bare name when it is unambiguous, otherwise package:name *) + if Hash_set.mem registry.unique name then name + else sprintf "%s:%s" package name + + include Base.Comparable.Make(struct + type t = name [@@deriving bin_io, compare, sexp] + end) + end + + module Exp = struct + type t = + | Sym of name + | Int of int + | App of {args : t list; name : name option} + [@@deriving bin_io, compare, sexp] + end + open Exp + + type +'a exp = Exp.t + type +'a t = 'a exp + type top = unit t + + let app s p = match p with + | App {args=xs; name} -> App {args=s::xs; name} + | Int _ -> App {args=[s;p]; name=None} + | Sym name -> App {args=[s;p]; name = Some name} + + let sym s = Sym s + let int s = Int s + + let (@->) = app + + let name = function Sym s -> s + | _ -> assert false + let value = function Int s -> s + | _ -> assert false + + let hd = function App {args=x::_} -> x + | _ -> assert false + let tl = function App {args=_::xs; name} -> App {args=xs;name} + | _ -> assert false + + let is_digit s = String.length s > 0 && Char.is_digit s.[0] + + open Format + + let pp_sep ppf () = fprintf ppf ",@ " + + + let rec pp ppf = function + | Sym s -> fprintf ppf "%s" (Name.to_string s) + | Int n -> fprintf ppf "%d" n + | App {args=xs} -> + let f,args = + let sx = List.rev xs in + List.hd_exn sx, List.(rev @@ tl_exn sx) in + fprintf ppf "%a(%a)" pp f + (pp_print_list ~pp_sep pp)
args + + + + let forget = ident + let refine witness t = match t with + | App {name=Some name} + | Sym name when name = witness -> Some t + | _ -> None + + let forget : 'a t -> unit t = ident + + let same x y = Exp.compare x y = 0 + + module Top = struct + type t = top + include Sexpable.Of_sexpable(Exp)(struct + type t = top + let to_sexpable x = x + let of_sexpable x = x + end) + include Binable.Of_binable(Exp)(struct + type t = top + let to_binable x = x + let of_binable x = x + end) + include Base.Comparable.Inherit(Exp)(struct + type t = top + let sexp_of_t x = Exp.sexp_of_t x + let component x = x + end) + end +end + +type 'a sort = 'a Sort.t +type 'a sym = 'a Sort.sym +type 'a num = 'a Sort.num +type cls = Sort.cls +type 'a t = (cls,'a sort) KB.cls KB.value +let cls = Sort.cls +let empty s : 'a t = KB.Value.empty (KB.Class.refine cls s) +let sort v : 'a sort = KB.Class.sort (KB.Value.cls v) + + +module Bool : sig + type t + val t : t sort + val refine : Sort.top -> t sort option +end = struct + type bool and t = bool sym + let bool = Sort.Name.declare ~package "Bool" + let t = Sort.sym bool + let refine x = Sort.refine bool x +end + +module Bitv : sig + type 'a t + val define : int -> 'a t sort + val refine : Sort.top -> 'a t sort option + val size : 'a t sort -> int +end = struct + type bitv + type 'a t = 'a num -> bitv sym + let bitvec = Sort.Name.declare ~package "BitVec" + let define m : 'a t sort = Sort.(int m @-> sym bitvec) + let refine s = Sort.refine bitvec s + let size x = Sort.(value @@ hd x) +end + +module Mem : sig + type ('a,'b) t + val define : 'a Bitv.t sort -> 'b Bitv.t sort -> ('a,'b) t sort + val refine : Sort.top -> ('a,'b) t sort option + val keys : ('a,'b) t sort -> 'a Bitv.t sort + val vals : ('a,'b) t sort -> 'b Bitv.t sort +end = struct + type mem + type ('a,'b) t = 'a Bitv.t -> 'b Bitv.t -> mem sym + let mem = Sort.Name.declare ~package "Mem" + let define (ks : 'a Bitv.t sort) (vs : 'b Bitv.t sort) : ('a,'b) t sort = + Sort.(ks @-> 
vs @-> sym mem) + let refine x = Sort.refine mem x + let keys x = Sort.(hd x) + let vals x = Sort.(hd (tl x)) +end + + +module Float : sig + module Format : sig + type ('r,'s) t + val define : 'r Sort.exp -> 's Bitv.t sort -> ('r,'s) t Sort.exp + val bits : ('r,'s) t Sort.exp -> 's Bitv.t sort + val exp : ('r,'s) t Sort.exp -> 'r Sort.exp + end + + type ('r,'s) format = ('r,'s) Format.t + type 'f t + + val define : ('r,'s) format Sort.exp -> ('r,'s) format t sort + val refine : Sort.top -> ('r,'s) format t sort option + val format : ('r,'s) format t sort -> ('r,'s) format Sort.exp + val size : ('r,'s) format t sort -> 's Bitv.t sort +end = struct + module Format = struct + type ('r,'s) t = ('r -> 's Bitv.t) + let define repr bits = Sort.(repr @-> bits) + let bits x = Sort.(tl x) + let exp x = Sort.(hd x) + end + type float + type ('r,'s) format = ('r,'s) Format.t + type 'f t = 'f -> float sym + let float = Sort.Name.declare ~package "Float" + let define fmt = Sort.(fmt @-> sym float) + let refine x = Sort.refine float x + let format x = Sort.(hd x) + let size x = Format.bits (format x) +end + +module Rmode : sig + type t + val t : t sort + val refine : unit sort -> t sort option +end += struct + type rmode + type t = rmode sym + let rmode = Sort.Name.declare ~package "Rmode" (* declared in the core-theory package, like the Bool, BitVec, Mem and Float sort names, instead of the default user package *) + let t = Sort.(sym rmode) + let refine x = Sort.refine rmode x +end diff --git a/lib/bap_core_theory/bap_core_theory_value.mli b/lib/bap_core_theory/bap_core_theory_value.mli new file mode 100644 index 000000000..6df1e2243 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_value.mli @@ -0,0 +1,96 @@ +open Core_kernel +open Caml.Format +open Bap_knowledge + +module KB = Knowledge + +type +'a sort +type cls + +type 'a t = (cls,'a sort) KB.cls KB.value +val cls : (cls,unit) KB.cls + +val empty : 'a sort -> 'a t +val sort : 'a t -> 'a sort + +module Sort : sig + type +'a t = 'a sort + type +'a sym + type +'a num + type name + type cls + + val sym : name -> 'a sym sort + val int : int -> 'a num sort +
val app : 'a sort -> 'b sort -> ('a -> 'b) sort + val (@->) : 'a sort -> 'b sort -> ('a -> 'b) sort + + val value : 'a num sort -> int + val name : 'a sym sort -> name + + val hd : ('a -> 'b) sort -> 'a sort + val tl : ('a -> 'b) sort -> 'b sort + + + val forget : 'a t -> unit t + val refine : name -> unit sort -> 'a t option + val same : 'a t -> 'b t -> bool + + val pp : formatter -> 'a t -> unit + + module Top : sig + type t = unit sort [@@deriving bin_io, compare, sexp] + include Base.Comparable.S with type t := t + end + + module Name : sig + type t + val declare : ?package:string -> string -> name + include Base.Comparable.S with type t := t + end +end + +module Bool : sig + type t + val t : t sort + val refine : unit sort -> t sort option +end + + +module Bitv : sig + type 'a t + val define : int -> 'a t sort + val refine : unit sort -> 'a t sort option + val size : 'a t sort -> int +end + +module Mem : sig + type ('a,'b) t + val define : 'a Bitv.t sort -> 'b Bitv.t sort -> ('a,'b) t sort + val refine : unit sort -> ('a,'b) t sort option + val keys : ('a,'b) t sort -> 'a Bitv.t sort + val vals : ('a,'b) t sort -> 'b Bitv.t sort +end + +module Float : sig + module Format : sig + type ('r,'s) t + val define : 'r Sort.t -> 's Bitv.t sort -> ('r,'s) t Sort.t + val bits : ('r,'s) t Sort.t -> 's Bitv.t sort + val exp : ('r,'s) t Sort.t -> 'r Sort.t + end + + type ('r,'s) format = ('r,'s) Format.t + type 'f t + + val define : ('r,'s) format Sort.t -> ('r,'s) format t sort + val refine : unit sort -> ('r,'s) format t sort option + val format : ('r,'s) format t sort -> ('r,'s) format Sort.t + val size : ('r,'s) format t sort -> 's Bitv.t sort +end + +module Rmode : sig + type t + val t : t sort + val refine : unit sort -> t sort option +end diff --git a/lib/bap_core_theory/bap_core_theory_var.ml b/lib/bap_core_theory/bap_core_theory_var.ml new file mode 100644 index 000000000..083ec1925 --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_var.ml @@ -0,0 +1,160 @@ 
+open Core_kernel + +open Caml.Format +open Bap_knowledge +open Bap_core_theory_value +open Knowledge.Syntax + +module Value = Knowledge.Value + +let package = "core-theory" + +type const = Const [@@deriving bin_io, compare, sexp] +type mut = Mut [@@deriving bin_io, compare, sexp] + +let const = Knowledge.Class.declare ~package "const-var" Const + ~desc:"local immutable variables" + +let mut = Knowledge.Class.declare ~package "mut-var" Mut + ~desc:"temporary mutable variables" + +type ident = + | Reg of {name : string; ver : int} + | Let of {num : Int63.t} + | Var of {num : Int63.t; ver : int} +[@@deriving bin_io, compare, hash, sexp] + +type 'a var = 'a sort * ident +type 'a t = 'a var + +let valid_first_char = function (* a name may not begin with a digit, # or $; the last two are the prefixes pp_ident prints for Var and Let identifiers *) + | '0'..'9' | '#' | '$' -> false + | _ -> true + +let valid_char c = (* accepts everything valid_first_char accepts plus digits, the quote character and the dot, so only # and $ end up rejected *) + valid_first_char c || match c with + | '0' .. '9' | '\'' | '.' -> true + | _ -> false + +let non_empty name = (* raises through invalid_arg when the name is the empty string *) + if String.length name = 0 + then invalid_arg "Invalid var literal: a variable can't be empty" + +let all_chars_valid name = (* reads name.[0], so the name must already be known to be non-empty; raises through invalid_argf on a bad first character or on the first invalid character found *) + if not (valid_first_char name.[0]) + then invalid_argf + "Invalid var literal: a variable can't start from %c" name.[0] (); + match String.find name ~f:(Fn.non valid_char) with + | None -> () + | Some c -> + invalid_argf + "Invalid var literal: a variable can't contain char %c" c () + +let validate_variable name = (* emptiness is checked first so that all_chars_valid can index the first character *) + non_empty name; + all_chars_valid name + +let define sort name : 'a var = (* validates the name and pairs the sort with a Reg identifier at version 0; NOTE(review): the non_empty call below repeats a check that validate_variable has already made *) + validate_variable name; + non_empty name; + sort, Reg {name; ver=0} + +let create sort ident = sort,ident + +let forget (s,v) = Sort.forget s,v +let resort (_,v) s = s,v + +let pp_ver ppf = function (* version 0 prints nothing, any other version prints as .N *) + | 0 -> () + | n -> fprintf ppf ".%d" n + +let pp_ident ppf ident = match ident with (* Reg prints its name, Let prints $num, Var prints #num; Reg and Var append the version through pp_ver *) + | Reg {name; ver} -> Format.fprintf ppf "%s%a" name pp_ver ver + | Let {num} -> + Format.fprintf ppf "$%a" Int63.pp num + | Var {num; ver} -> + Format.fprintf ppf "#%a%a" Int63.pp num pp_ver ver + +let name (_,v) = Format.asprintf "%a" pp_ident v +let ident (_,v) = v +let sort
(s,_) = s +let is_virtual v = match ident v with + | Let _ | Var _ -> true + | Reg _ -> false +let is_mutable v = match ident v with + | Let _ -> false + | Reg _ | Var _ -> true + +let nat1 = Knowledge.Domain.total "nat1" + ~empty:0 + ~inspect:sexp_of_int + ~order:Int.compare + +let versioned (s,v) ver = match v with + | Let _ -> (s,v) + | Reg {name} -> s,Reg {name; ver} + | Var {num} -> s,Var {num; ver} + +let version v = match ident v with + | Let _ -> 0 + | Reg {ver} | Var {ver} -> ver + +let fresh s = + Knowledge.Object.create mut >>| fun v -> + create s (Var {num = Knowledge.Object.id v; ver=0}) + +type 'a pure = 'a Bap_core_theory_value.t knowledge + +(* we're ensuring that a variable is immutable by constraining + the scope computation to be pure. *) +let scoped : 'a sort -> ('a t -> 'b pure) -> 'b pure = fun s f -> + Knowledge.Object.scoped const @@ fun v -> + f @@ create s (Let {num = Knowledge.Object.id v}) + +module Ident = struct + type t = ident [@@deriving bin_io, compare, hash, sexp] + + let num s = try Int63.of_string s with _ -> + failwithf "`%s' is not a valid temporary value" s () + + let split_version s = + match String.rfindi s ~f:(fun _ c -> c = '.') with + | None -> s,0 + | Some n -> + String.subo ~len:n s, + Int.of_string (String.subo ~pos:(n+1) s) + + let of_string x = + let n = String.length x in + if n = 0 + then invalid_arg "a variable identifier can't be empty"; + Scanf.sscanf x "%c%s" @@ function + | '$' -> fun s -> Let {num = num s} + | '#' -> fun s -> + let s,ver = split_version s in + Var {num = num s; ver} + | _ -> fun _ -> + validate_variable x; + let name,ver = split_version x in + Reg {name; ver} + + let to_string x = Format.asprintf "%a" pp_ident x + include Base.Comparable.Make(struct + type t = ident [@@deriving bin_io, compare, sexp] + end) + +end +type ord = Ident.comparator_witness + +module Top : sig + type t = unit var [@@deriving bin_io, compare, sexp] + include Base.Comparable.S with type t := t +end = struct + type t 
= Sort.Top.t * ident [@@deriving bin_io, sexp] + + include Base.Comparable.Inherit(Ident)(struct + type t = Sort.Top.t * ident + let sexp_of_t = sexp_of_t + let component = snd + end) +end diff --git a/lib/bap_core_theory/bap_core_theory_var.mli b/lib/bap_core_theory/bap_core_theory_var.mli new file mode 100644 index 000000000..fc2f471fd --- /dev/null +++ b/lib/bap_core_theory/bap_core_theory_var.mli @@ -0,0 +1,38 @@ +open Core_kernel +open Bap_knowledge +open Bap_core_theory_value + + +type 'a t +type ord +type ident [@@deriving bin_io, compare, sexp] +type 'a pure = 'a Bap_core_theory_value.t knowledge + + +val define : 'a sort -> string -> 'a t +val create : 'a sort -> ident -> 'a t +val forget : 'a t -> unit t +val resort : 'a t -> 'b sort -> 'b t + +val version : 'a t -> int +val versioned : 'a t -> int -> 'a t + +val ident : 'a t -> ident +val name : 'a t -> string +val sort : 'a t -> 'a sort +val is_virtual : 'a t -> bool +val is_mutable : 'a t -> bool +val fresh : 'a sort -> 'a t knowledge +val scoped : 'a sort -> ('a t -> 'b pure) -> 'b pure + +module Ident : sig + type t = ident [@@deriving bin_io, compare, sexp] + include Stringable.S with type t := t + include Base.Comparable.S with type t := t + and type comparator_witness = ord +end + +module Top : sig + type nonrec t = unit t [@@deriving bin_io, compare, sexp] + include Base.Comparable.S with type t := t +end diff --git a/lib/bap_disasm/.merlin b/lib/bap_disasm/.merlin index 2fb3f0bc1..2575fd4b8 100644 --- a/lib/bap_disasm/.merlin +++ b/lib/bap_disasm/.merlin @@ -1 +1,2 @@ -REC \ No newline at end of file +REC +PKG zarith diff --git a/lib/bap_disasm/bap_disasm.ml b/lib/bap_disasm/bap_disasm.ml index ec3111054..eb52f26d8 100644 --- a/lib/bap_disasm/bap_disasm.ml +++ b/lib/bap_disasm/bap_disasm.ml @@ -29,7 +29,7 @@ type error = [ type mem_state = | Failed of error (** failed to decode anything *) | Decoded of insn * error option (** decoded with optional errors *) - [@@deriving sexp_of] +[@@deriving 
sexp_of] type cfg = Rec.cfg [@@deriving compare] @@ -62,6 +62,8 @@ let of_rec d = { module Disasm = struct + module Driver = Bap_disasm_driver + type t = disasm type 'a disassembler = ?backend:string -> ?brancher:brancher -> ?rooter:rooter -> 'a @@ -92,7 +94,7 @@ module Disasm = struct else return dis) let of_file ?backend ?brancher ?rooter ?loader filename = - Image.create ?backend:loader filename >>= fun (img,errs) -> + Image.create ?backend:loader filename >>= fun (img,_) -> of_image ?backend ?brancher ?rooter img module With_exn = struct @@ -104,6 +106,7 @@ module Disasm = struct of_image ?backend ?brancher ?rooter image |> ok_exn end + let insn = Value.Tag.register (module Insn) ~name:"insn" ~uuid:"8e2a3998-bf07-4a52-a791-f74ea190630a" diff --git a/lib/bap_disasm/bap_disasm.mli b/lib/bap_disasm/bap_disasm.mli index 94f004429..096d1bb9b 100644 --- a/lib/bap_disasm/bap_disasm.mli +++ b/lib/bap_disasm/bap_disasm.mli @@ -13,6 +13,8 @@ type block = Bap_disasm_block.t [@@deriving compare, sexp_of] type cfg = Bap_disasm_rec.Cfg.t [@@deriving compare] module Disasm : sig + module Driver = Bap_disasm_driver + type t = disasm type 'a disassembler = ?backend:string -> ?brancher:brancher -> ?rooter:rooter -> 'a val create : cfg -> disasm diff --git a/lib/bap_disasm/bap_disasm_basic.ml b/lib/bap_disasm/bap_disasm_basic.ml index fc90d2501..dd4cf6a0a 100644 --- a/lib/bap_disasm/bap_disasm_basic.ml +++ b/lib/bap_disasm/bap_disasm_basic.ml @@ -1,6 +1,7 @@ open Core_kernel open Regular.Std open Bap_types.Std +open Bap_core_theory open Or_error module Kind = Bap_insn_kind @@ -71,19 +72,30 @@ module Table = struct Bytes.to_string dst) end -type dis = { +type disassembler = { dd : int; insn_table : Table.t; reg_table : Table.t; + mutable users : int; +} + +let last_id = ref 0 +let disassemblers = Hashtbl.create (module String) + + + +type dis = { + name : string; asm : bool; kinds : bool; - mutable closed : bool; } -let (!!) 
dis = - if dis.closed then - failwith "with_disasm: dis value leaked the scope"; - dis.dd +let get {name} = match Hashtbl.find disassemblers name with + | None -> + failwith "Trying to access a closed disassembler" + | Some d -> d + +let (!!) h = (get h).dd module Reg = struct @@ -94,7 +106,7 @@ module Reg = struct if reg_code = 0 then "Nil" else let off = C.insn_op_reg_name !!dis ~insn ~oper in - (Table.lookup dis.reg_table off) in + (Table.lookup (get dis).reg_table off) in {reg_code; reg_name} in {insn; oper; data} @@ -103,7 +115,7 @@ module Reg = struct module T = struct type t = reg - [@@deriving bin_io, sexp, compare] + [@@deriving bin_io, sexp, compare] let module_name = Some "Bap.Std.Reg" let version = "1.0.0" @@ -146,7 +158,7 @@ module Imm = struct module T = struct type t = imm - [@@deriving bin_io, sexp, compare] + [@@deriving bin_io, sexp, compare] let module_name = Some "Bap.Std.Imm" let version = "1.0.0" let pp fmt t = @@ -176,7 +188,7 @@ module Fmm = struct module T = struct type t = fmm - [@@deriving bin_io, sexp, compare] + [@@deriving bin_io, sexp, compare] let module_name = Some "Bap.Std.Fmm" let version = "1.0.0" @@ -194,7 +206,7 @@ module Op = struct | Reg of reg | Imm of imm | Fmm of fmm - [@@deriving bin_io, compare, sexp] + [@@deriving bin_io, compare, sexp] let pr fmt = Format.fprintf fmt let pp fmt = function @@ -245,7 +257,7 @@ module Op = struct end type op = Op.t - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] let cpred_of_pred : pred -> C.pred = function | `Valid -> C.Is_true @@ -292,7 +304,7 @@ module Insn = struct let code = C.insn_code !!dis ~insn in let name = let off = C.insn_name !!dis ~insn in - Table.lookup dis.insn_table off in + Table.lookup (get dis).insn_table off in let asm = if asm then let data = Bytes.create (C.insn_asm_size !!dis ~insn) in @@ -317,6 +329,11 @@ module Insn = struct {code; name; asm; kinds; opers } + let domain = + KB.Domain.optional ~inspect:sexp_of_t "insn" + ~equal:(fun x y -> 
Int.equal x.code y.code) + let slot = KB.Class.property ~package:"bap.std" + Theory.Program.cls "insn" domain end type ('a,'k) insn = ('a,'k) Insn.t @@ -341,7 +358,7 @@ type (+'a,+'k) insns = (mem * ('a,'k) insn option) list module Pred = Comparable.Make(struct type t = pred [@@deriving compare, sexp] -end) + end) module Preds = Pred.Set type preds = Preds.t [@@deriving compare, sexp] @@ -388,9 +405,6 @@ let insn_mem s ~insn : mem = let from = Addr.(Mem.min_addr s.current.mem ++ off) in ok_exn (Mem.view s.current.mem ~from ~words) - -let kinds s ~insn : kind list = [] - let set_memory dis p : unit = let open Bigsubstring in let buf = Mem.to_buffer p.mem in @@ -483,28 +497,41 @@ let back s data = | x :: xs -> x,xs in step { s with current; history} data -let create ?(debug_level=0) ?(cpu="") ~backend triple = - let dd = match C.create ~backend ~triple ~cpu ~debug_level with - | n when n >= 0 -> Ok n - | -2 -> errorf "Unknown backend: %s" backend - | -3 -> errorf "Unsupported target: %s %s" triple cpu - | n -> errorf "Disasm.Basic: Unknown error %d" n in - dd >>= fun dd -> return { - dd; - insn_table = Table.create (C.insn_table dd); - reg_table = Table.create (C.reg_table dd); - asm = false; - kinds = false; - closed = false; - } +let create ?(debug_level=0) ?(cpu="") ?(backend="llvm") triple = + let name = sprintf "%s:%s%s" backend triple cpu in + match Hashtbl.find disassemblers name with + | Some d -> + d.users <- d.users + 1; + Ok {name; asm=false; kinds=false} + | None -> + let dd = match C.create ~backend ~triple ~cpu ~debug_level with + | n when n >= 0 -> Ok n + | -2 -> errorf "Unknown backend: %s" backend + | -3 -> errorf "Unsupported target: %s %s" triple cpu + | n -> errorf "Disasm.Basic: Unknown error %d" n in + dd >>= fun dd -> + let disassembler = { + dd; + insn_table = Table.create (C.insn_table dd); + reg_table = Table.create (C.reg_table dd); + users = 1; + } in + Hashtbl.add_exn disassemblers name disassembler; + Ok {name; asm = false; kinds = 
false} let close dis = - C.delete dis.dd; - dis.closed <- true + let disassembler = get dis in + disassembler.users <- disassembler.users - 1; + if disassembler.users = 0 + then begin + Hashtbl.remove disassemblers dis.name; + C.delete disassembler.dd; + end + -let with_disasm ?debug_level ?cpu ~backend triple ~f = - create ?debug_level ?cpu ~backend triple >>= fun dis -> +let with_disasm ?debug_level ?cpu ?backend triple ~f = + create ?debug_level ?cpu ?backend triple >>= fun dis -> - f dis >>| fun res -> close dis; res + Exn.protect ~f:(fun () -> f dis) ~finally:(fun () -> close dis) (* close the handle on every exit path: previously an Error result or an exception from f skipped close, so the users count of the shared disassembler never dropped back *) type ('a,'k) t = dis @@ -514,15 +541,14 @@ let run ?backlog ?(stop_on=[]) ?invalid ?stopped ?hit dis ~return ~init mem = create_state ?backlog ?invalid ?stopped ?hit ~return dis mem in let state = with_preds state stop_on in + C.store_asm_string !!dis dis.asm; + C.store_predicates !!dis dis.kinds; jump state (memory state) init let store_kinds d = - C.store_predicates !!d true; {d with kinds = true} - let store_asm d = - C.store_asm_string !!d true; {d with asm = true} let insn_of_mem dis mem = @@ -555,8 +581,8 @@ module Trie = struct let length = fst let nth_token (_, State s) i = match s.insns.(i) with - | (mem, None) -> 0, [| |] - | (mem, Some insn) -> Insn.(insn.code, insn.opers) + | (_, None) -> 0, [| |] + | (_, Some insn) -> Insn.(insn.code, insn.opers) let token_hash = Hashtbl.hash end diff --git a/lib/bap_disasm/bap_disasm_basic.mli b/lib/bap_disasm/bap_disasm_basic.mli index 50e385e68..58f69c2ce 100644 --- a/lib/bap_disasm/bap_disasm_basic.mli +++ b/lib/bap_disasm/bap_disasm_basic.mli @@ -1,6 +1,7 @@ open Core_kernel open Regular.Std open Bap_types.Std +open Bap_core_theory type mem = Bap_memory.t [@@deriving sexp_of] type kind = Bap_insn_kind.t [@@deriving compare, sexp] @@ -24,10 +25,10 @@ type ('a,'k) t type (+'a,+'k,'s,'r) state val with_disasm : - ?debug_level:int -> ?cpu:string -> backend:string -> string -> + ?debug_level:int -> ?cpu:string -> ?backend:string -> string -> f:((empty, empty) t -> 'a Or_error.t) -> 'a Or_error.t -val
create : ?debug_level:int -> ?cpu:string -> backend:string -> string -> +val create : ?debug_level:int -> ?cpu:string -> ?backend:string -> string -> (empty, empty) t Or_error.t val close : (_,_) t -> unit @@ -44,6 +45,7 @@ val run : ('a,'k) t -> return:('s -> 'r) -> init:'s -> mem -> 'r + val insn_of_mem : (_,_) t -> mem -> (mem * (asm,kinds) insn option * [`left of mem | `finished]) Or_error.t @@ -87,6 +89,7 @@ module Insn : sig val is : ('a,kinds) t -> kind -> bool val asm : (asm,'k) t -> string val ops : ('a,'k) t -> op array + val slot : (Theory.program,full_insn option) KB.slot end module Reg : sig diff --git a/lib/bap_disasm/bap_disasm_brancher.ml b/lib/bap_disasm/bap_disasm_brancher.ml index 4738a58ff..cfdf3ba2e 100644 --- a/lib/bap_disasm/bap_disasm_brancher.ml +++ b/lib/bap_disasm/bap_disasm_brancher.ml @@ -1,8 +1,11 @@ +open Bap_core_theory open Core_kernel open Bap_types.Std open Bap_image_std open Monads.Std +open KB.Syntax + module Source = Bap_disasm_source module Targets = Bap_disasm_target_factory module Dis = Bap_disasm_basic @@ -151,3 +154,22 @@ let of_image img = create (dests_of_bil ~rel_info (Image.arch img)) module Factory = Source.Factory.Make(struct type nonrec t = t end) + +let (>>=?) x f = x >>= function + | None -> KB.return Insn.empty + | Some x -> f x + + +let provide brancher = + let init = Set.empty (module Theory.Label) in + KB.promise Theory.Program.Semantics.slot @@ fun label -> + KB.collect Memory.slot label >>=? fun mem -> + KB.collect Dis.Insn.slot label >>=? 
fun insn -> + resolve brancher mem insn |> + KB.List.fold ~init ~f:(fun dsts dst -> + match dst with + | Some addr,_ -> + Theory.Label.for_addr (Word.to_bitvec addr) >>| fun dst -> + Set.add dsts dst + | None,_ -> KB.return dsts) >>| fun dests -> + KB.Value.put Insn.Slot.dests Insn.empty (Some dests) diff --git a/lib/bap_disasm/bap_disasm_brancher.mli b/lib/bap_disasm/bap_disasm_brancher.mli index 315e486b9..3522c6a23 100644 --- a/lib/bap_disasm/bap_disasm_brancher.mli +++ b/lib/bap_disasm/bap_disasm_brancher.mli @@ -20,4 +20,6 @@ val resolve : t -> mem -> full_insn -> dests val empty : t +val provide : t -> unit + module Factory : Factory with type t = t diff --git a/lib/bap_disasm/bap_disasm_calls.ml b/lib/bap_disasm/bap_disasm_calls.ml new file mode 100644 index 000000000..0aed3e87f --- /dev/null +++ b/lib/bap_disasm/bap_disasm_calls.ml @@ -0,0 +1,167 @@ +open Bap_core_theory + +module OCamlGraph = Graph + +open Core_kernel +open Graphlib.Std + +open Bap_types.Std +open Bap_image_std + +open KB.Syntax + +module Driver = Bap_disasm_driver +module Insn = Bap_disasm_insn + + +module Parent = struct + let none = Word.b0 + let unknown = Word.b1 + let equal = Word.equal + let is_root p = equal p none + let is_known p = not (equal p unknown) + let merge x y = + if equal x unknown then y else + if equal y unknown then x else + if equal x y then x else none + + let transfer self parent = + if equal parent none then self else parent +end + +module Parents = struct + type t = (word,word) Solution.t + include Binable.Of_binable(struct + type t = (word * word) Seq.t [@@deriving bin_io] + end)(struct + type t = (word,word) Solution.t + let to_binable = Solution.enum + let of_binable xs = + let init = ok_exn @@ + Map.of_increasing_sequence + (module Word) xs in + Solution.create init Parent.unknown + end) +end + +type input = Driver.state +type output = { + parents : Parents.t; + entries : Addr.Set.t; +} [@@deriving bin_io] + +type t = output [@@deriving bin_io] + +module 
Callgraph = struct + let entry = Word.b0 + let exit = Word.b1 + let is_entry = Word.equal entry + include Graphlib.Make(Addr)(Unit) + let mark_as_root n g = + if Word.equal n entry then g + else + let e = Edge.create entry n () in + Edge.insert e g +end + +let string_of_node n = + sprintf "%S" @@ if Callgraph.is_entry n + then "entry" + else Addr.string_of_value n + +let pp_callgraph ppf graph = + Graphlib.to_dot (module Callgraph) graph + ~formatter:ppf + ~string_of_node + + +let pp_roots ppf graph = + Graphlib.to_dot (module Callgraph) graph + ~formatter:ppf + ~string_of_node:(fun s -> + sprintf "%S" (Addr.string_of_value s)) + +let of_disasm disasm = + Driver.explore disasm ~init:Callgraph.empty + ~block:(fun mem _ -> KB.return (Memory.min_addr mem)) + ~node:(fun n g -> + let g = Callgraph.Node.insert n g in + Theory.Label.for_addr (Word.to_bitvec n) >>= fun code -> + KB.collect Theory.Label.is_subroutine code >>| function + | Some true -> Callgraph.mark_as_root n g + | _ -> g) + ~edge:(fun src dst g -> + let e = Callgraph.Edge.create src dst () in + KB.return (Callgraph.Edge.insert e g)) + + +let empty = + let root = + Map.singleton (module Addr) Callgraph.entry Parent.none in { + parents = Solution.create root Parent.unknown; + entries = Set.empty (module Addr); + } + +let connect_inputs g = + Callgraph.nodes g |> + Seq.fold ~init:g ~f:(fun g n -> + if Callgraph.Node.degree ~dir:`In n g = 0 + then Callgraph.mark_as_root n g + else g) + +let connect_unreachable_scc g = + Graphlib.depth_first_search (module Callgraph) g + ~start:Callgraph.entry + ~init:g + ~start_tree:Callgraph.mark_as_root + +let callgraph disasm = + of_disasm disasm >>| + connect_inputs >>| + connect_unreachable_scc + +let parent parents addr = + let parent = Solution.get parents addr in + if Parent.equal parent Parent.none then addr else parent + +let entries graph parents = + let init = Set.empty (module Addr) in + Callgraph.nodes graph |> Seq.fold ~init ~f:(fun entries n -> + if not 
(Parent.is_root n) && Parent.equal (parent parents n) n + then Set.add entries n + else entries) + + +let pp_calls ppf (parents,graph) = + Graphlib.to_dot (module Callgraph) graph + ~formatter:ppf + ~string_of_node + ~node_attrs:(fun n -> + if Parent.equal (parent parents n) n (* a node that is its own parent is drawn as a filled diamond; uses the same Word-based equality as the entries computation rather than polymorphic = *) + then [`Shape `Diamond; `Style `Filled] + else []) + +let update {parents} disasm = + callgraph disasm >>| fun graph -> + Graphlib.fixpoint (module Callgraph) graph + ~init:parents + ~start:Callgraph.entry + ~equal:Parent.equal + ~merge:Parent.merge + ~f:Parent.transfer + |> fun parents -> + { + parents; + entries = entries graph parents; + } + +let entry {parents} addr = parent parents addr + +let entries {entries} = entries + +let equal s1 s2 = + Set.equal s1.entries s2.entries && + Solution.equal ~equal:Word.equal s1.parents s2.parents + + +let domain = KB.Domain.flat ~empty ~equal "callgraph" diff --git a/lib/bap_disasm/bap_disasm_calls.mli b/lib/bap_disasm/bap_disasm_calls.mli new file mode 100644 index 000000000..f2dfe12aa --- /dev/null +++ b/lib/bap_disasm/bap_disasm_calls.mli @@ -0,0 +1,14 @@ +open Core_kernel +open Bap_core_theory +open Graphlib.Std +open Bap_types.Std +module Driver = Bap_disasm_driver + +type t [@@deriving bin_io] + +val empty : t +val equal : t -> t -> bool +val update : t -> Driver.state -> t KB.t +val entry : t -> addr -> addr +val entries : t -> Set.M(Addr).t +val domain : t KB.domain diff --git a/lib/bap_disasm/bap_disasm_driver.ml b/lib/bap_disasm/bap_disasm_driver.ml new file mode 100644 index 000000000..9e1957db6 --- /dev/null +++ b/lib/bap_disasm/bap_disasm_driver.ml @@ -0,0 +1,505 @@ +open Core_kernel +open Bap_types.Std +open Bap_core_theory +open Bap_image_std + +open KB.Syntax + +module Dis = Bap_disasm_basic +module Insn = Bap_disasm_insn + +type full_insn = Dis.full_insn [@@deriving sexp_of] +type insn = Insn.t [@@deriving sexp_of] +type edge = [`Jump | `Cond | `Fall] [@@deriving compare] + + +type dsts = { + barrier : bool; + indirect : bool; + resolved :
Addr.Set.t; +} [@@deriving bin_io] + +module Machine : sig + type task = private + | Dest of {dst : addr; parent : task option} + | Fall of {dst : addr; parent : task; delay : slot} + | Jump of {src : addr; age: int; dsts : dsts; parent : task} + and slot = private + | Ready of task option + | Delay + + type state = private { + stop : bool; + work : task list; (* work list*) + curr : task; + addr : addr; (* current address *) + begs : Set.M(Addr).t; (* begins of basic blocks *) + jmps : dsts Map.M(Addr).t; (* jumps *) + code : Set.M(Addr).t; (* all valid instructions *) + data : Set.M(Addr).t; (* all non-instructions *) + usat : Set.M(Addr).t; (* unsatisfied constraints *) + } + + val start : + mem -> + code:Set.M(Addr).t -> + data:Set.M(Addr).t -> + init:Set.M(Addr).t -> + empty:(state -> 'a) -> + ready:(state -> mem -> 'a) -> 'a + + val view : state -> mem -> + empty:(state -> 'a) -> + ready:(state -> mem -> 'a) -> 'a + + val failed : state -> addr -> state + val jumped : state -> mem -> dsts -> int -> state + val stopped : state -> state + val moved : state -> mem -> state + val is_ready : state -> bool +end = struct + + type task = + | Dest of {dst : addr; parent : task option} + | Fall of {dst : addr; parent : task; delay : slot} + | Jump of {src : addr; age: int; dsts : dsts; parent : task} + and slot = + | Ready of task option + | Delay + + + let init_work roots = + Set.to_sequence ~order:`Decreasing roots |> + Seq.fold ~init:[] ~f:(fun work root -> + Dest {dst=root; parent=None} :: work) + + type state = { + stop : bool; + work : task list; (* work list*) + curr : task; + addr : addr; (* current address *) + begs : Set.M(Addr).t; (* begins of basic blocks *) + jmps : dsts Map.M(Addr).t; (* jumps *) + code : Set.M(Addr).t; (* all valid instructions *) + data : Set.M(Addr).t; (* all non-instructions *) + usat : Set.M(Addr).t; (* unsatisfied constraints *) + } + + let is_code s addr = Set.mem s.code addr + let is_data s addr = Set.mem s.data addr + let 
is_visited s addr = is_code s addr || is_data s addr + let is_ready s = s.stop + + + let mark_data s addr = { + s with + data = Set.add s.data addr; + begs = Set.remove s.begs addr; + usat = Set.remove s.usat addr; + code = Set.remove s.code addr; + jmps = Map.filter_map (Map.remove s.jmps addr) ~f:(fun dsts -> + let resolved = Set.remove dsts.resolved addr in + if Set.is_empty resolved && not dsts.indirect + then None + else Some {dsts with resolved}); + } + + let has_valid s dsts = + dsts.indirect || + Set.exists dsts.resolved ~f:(fun dst -> + not (Set.mem s.data dst)) + + let pp_task ppf = function + | Dest {dst; parent=None} -> + Format.fprintf ppf "Root %a" Addr.pp dst + | Dest {dst} -> + Format.fprintf ppf "Dest %a" Addr.pp dst + | Fall {dst} -> + Format.fprintf ppf "Fall %a" Addr.pp dst + | Jump {src; age} -> + Format.fprintf ppf "Delay%d %a" age Addr.pp src + + let rec cancel task s = match task with + | Dest {parent=None} -> s + | Dest {parent=Some parent} | Fall {parent} | Jump {parent} -> + match parent with + | Dest {dst} | Fall {dst} -> cancel parent (mark_data s dst) + | Jump {src; dsts} -> match task with + | Fall {delay=(Ready (Some _) | Delay)} -> + cancel parent (mark_data s src) + | Fall _ when has_valid s dsts -> s + | Fall _ | Dest _ -> cancel parent (mark_data s src) + | Jump _ -> assert false + + let rec step s = match s.work with + | [] -> + if Set.is_empty s.usat then {s with stop = true} + else step {s with work = [ + Dest {dst=Set.min_elt_exn s.usat; parent=None} + ]} + | Dest {dst=next} as curr :: work -> + let s = if is_data s next + then step @@ cancel curr {s with work} + else {s with begs = Set.add s.begs next} in + if is_visited s next then step {s with work} + else {s with work; addr=next; curr} + | Fall {dst=next} as curr :: work -> + if is_code s next + then step {s with begs = Set.add s.begs next; work} + else if is_data s next then step @@ cancel curr {s with work} + else {s with work; addr=next; curr} + | (Jump {src; dsts} as 
jump) :: ([] as work) + | (Jump {src; dsts; age=0} as jump) :: work -> + if Set.mem s.data src then step @@ cancel jump {s with work} + else + let resolved = Set.filter dsts.resolved ~f:(fun dst -> + not (Set.mem s.data dst)) in + let dsts = {dsts with resolved} in + let init = {s with jmps = Map.add_exn s.jmps src dsts; work} in + step @@ + Set.fold resolved ~init ~f:(fun s next -> + if not (is_visited s next) + then {s with work = Dest {dst=next; parent = Some jump} :: + s.work} + else s) + | Jump jmp as self :: Fall ({dst=next} as slot) :: work -> + let delay = if jmp.age = 1 then Ready (Some self) else Delay in + step { + s with + work = Fall {slot with delay} :: Jump { + jmp with age = jmp.age-1; src = next; + } :: work + } + | Jump jmp :: work -> step { + s with work = Jump {jmp with age=0} :: work + } + + let decoded s mem = + let addr = Memory.min_addr mem in { + s with code = Set.add s.code addr; + usat = Set.remove s.usat addr + } + + let jumped s mem dsts delay = + let s = decoded s mem in + let parent = s.curr in + let src = Memory.min_addr mem in + let jump = Jump {src; age=delay; dsts; parent} in + let next = Addr.succ (Memory.max_addr mem) in + let next = + if dsts.barrier && delay = 0 + then Dest {dst=next; parent=None} + else Fall {dst=next; parent=jump; delay = Ready None} in + step {s with work = jump :: next :: s.work } + + let insert_delayed t = function + | x :: xs -> x :: t :: xs + | [] -> [t] + + let moved s mem = + let parent = match s.curr with + | Fall {delay=Ready (Some parent)} -> parent + | _ -> s.curr in + let next = Addr.succ (Memory.max_addr mem) in + let next = match parent with + | Jump {dsts={barrier=true}} -> + Dest {dst=next; parent=None} + | parent -> Fall { + dst = next; + parent; + delay = Ready None; + } in + let work = match s.curr with + | Fall {delay = Delay} -> insert_delayed next s.work + | _ -> next :: s.work in + step @@ decoded {s with work} mem + + + let failed s addr = + step @@ cancel s.curr @@ mark_data s addr + 
+ let stopped s = + step @@ cancel s.curr @@ mark_data s s.addr + + let rec view s base ~empty ~ready = + match Memory.view ~from:s.addr base with + | Ok mem -> ready s mem + | Error _ -> + let s = match s.curr with + | Fall _ as task -> + cancel task s + | _ -> s in + match s.work with + | [] -> empty (step s) + | _ -> view (step s) base ~empty ~ready + + let start mem ~code ~data ~init = + let init = if Set.is_empty init + then Set.singleton (module Addr) (Memory.min_addr mem) + else init in + let work = init_work init in + let start = Set.min_elt_exn init in + view { + work; data; usat=code; + addr = start; + curr = Dest {dst = start; parent = None}; + stop = false; + begs = Set.empty (module Addr); + jmps = Map.empty (module Addr); + code = Set.empty (module Addr); + } mem +end + +let new_insn arch mem insn = + let addr = Addr.to_bitvec (Memory.min_addr mem) in + Theory.Label.for_addr addr >>= fun code -> + KB.provide Arch.slot code (Some arch) >>= fun () -> + KB.provide Memory.slot code (Some mem) >>= fun () -> + KB.provide Dis.Insn.slot code (Some insn) >>| fun () -> + code + +let collect_dests arch mem insn = + let width = Size.in_bits (Arch.addr_size arch) in + let fall = Addr.to_bitvec (Addr.succ (Memory.max_addr mem)) in + new_insn arch mem insn >>= fun code -> + KB.collect Theory.Program.Semantics.slot code >>= fun insn -> + let init = { + barrier = Insn.(is barrier insn); + indirect = false; + resolved = Set.empty (module Addr) + } in + KB.Value.get Insn.Slot.dests insn |> function + | None -> KB.return init + | Some dests -> + Set.to_sequence dests |> + KB.Seq.fold ~init ~f:(fun {barrier; indirect; resolved} label -> + KB.collect Theory.Label.addr label >>| function + | Some d -> + if Bitvec.(d <> fall) + then { + barrier; + indirect; + resolved = Set.add resolved (Word.create d width) + } else {barrier; indirect; resolved} + | None -> + {barrier; indirect=true; resolved}) >>= fun res -> + KB.return res + +let pp_addr_opt ppf = function + | None -> 
Format.fprintf ppf "Unk" + | Some addr -> Format.fprintf ppf "%a" Bitvec.pp addr + +(* pre: insn is call /\ is a member of a valid chain *) +let mark_call_destinations mem dests = + let next = Addr.to_bitvec @@ Addr.succ @@ Memory.max_addr mem in + Set.to_sequence dests |> + KB.Seq.iter ~f:(fun dest -> + KB.collect Theory.Label.addr dest >>= fun addr -> + if Option.is_none addr || + Bitvec.(Option.value_exn addr <> next) + then KB.provide Theory.Label.is_subroutine dest (Some true) + else KB.return ()) + +let update_calls mem curr = + KB.collect Theory.Program.Semantics.slot curr >>= fun insn -> + if Insn.(is call) insn + then match KB.Value.get Insn.Slot.dests insn with + | None -> KB.return () + | Some dests -> mark_call_destinations mem dests + else KB.return () + +let delay arch mem insn = + new_insn arch mem insn >>= fun code -> + KB.collect Theory.Program.Semantics.slot code >>| fun insn -> + KB.Value.get Insn.Slot.delay insn |> function + | None -> 0 + | Some x -> x + +let classify_mem mem = + let empty = Set.empty (module Addr) in + let base = Memory.min_addr mem in + Seq.range 0 (Memory.length mem) |> + KB.Seq.fold ~init:(empty,empty,empty) ~f:(fun (code,data,root) off -> + let addr = Addr.(nsucc base off) in + let slot = Some (Addr.to_bitvec addr) in + KB.Object.scoped Theory.Program.cls @@ fun label -> + KB.provide Theory.Label.addr label slot >>= fun () -> + KB.collect Theory.Label.is_valid label >>= function + | Some false -> KB.return (code,Set.add data addr,root) + | r -> + let code = if Option.is_none r then code + else Set.add code addr in + KB.collect Theory.Label.is_subroutine label >>| function + | Some true -> (code,data,Set.add root addr) + | _ -> (code,data,root)) + +let scan_mem arch disasm base : Machine.state KB.t = + classify_mem base >>= fun (code,data,init) -> + let step d s = + if Machine.is_ready s then KB.return s + else Machine.view s base ~ready:(fun s mem -> Dis.jump d mem s) + ~empty:KB.return in + Machine.start base ~code ~data 
~init + ~ready:(fun init mem -> + Dis.run disasm mem ~stop_on:[`Valid] + ~return:KB.return ~init + ~stopped:(fun d s -> step d (Machine.stopped s)) + ~hit:(fun d mem insn s -> + collect_dests arch mem insn >>= fun dests -> + if Set.is_empty dests.resolved && + not dests.indirect then + step d @@ Machine.moved s mem + else + delay arch mem insn >>= fun delay -> + step d @@ Machine.jumped s mem dests delay) + ~invalid:(fun d _ s -> step d (Machine.failed s s.addr))) + ~empty:KB.return + +type insns = Theory.Label.t list + +type state = { + begs : Addr.Set.t; + jmps : dsts Addr.Map.t; + data : Addr.Set.t; + mems : mem list; +} [@@deriving bin_io] + +let init = { + begs = Set.empty (module Addr); + jmps = Map.empty (module Addr); + data = Set.empty (module Addr); + mems = [] +} + +let query_arch addr = + KB.Object.scoped Theory.Program.cls @@ fun obj -> + KB.provide Theory.Label.addr obj (Some addr) >>= fun () -> + KB.collect Arch.slot obj + +let already_scanned addr s = + List.exists s.mems ~f:(fun mem -> + Memory.contains mem addr) + +let scan mem s = + let open KB.Syntax in + let start = Memory.min_addr mem in + if already_scanned start s + then KB.return s + else query_arch (Word.to_bitvec start) >>= function + | None -> KB.return s + | Some arch -> match Dis.create (Arch.to_string arch) with + | Error _ -> KB.return s + | Ok dis -> + scan_mem arch dis mem >>| fun {Machine.begs; jmps; data} -> + let jmps = Map.merge s.jmps jmps ~f:(fun ~key:_ -> function + | `Left dsts | `Right dsts | `Both (_,dsts) -> Some dsts) in + let begs = Set.union s.begs begs in + let data = Set.union s.data data in + {begs; data; jmps; mems = mem :: s.mems} + +let merge t1 t2 = { + begs = Set.union t1.begs t2.begs; + data = Set.union t1.data t2.data; + mems = List.rev_append t2.mems t1.mems; + jmps = Map.merge t1.jmps t2.jmps ~f:(fun ~key:_ -> function + | `Left dsts | `Right dsts -> Some dsts + | `Both (d1,d2) -> Some { + barrier = d1.barrier || d2.barrier; + indirect = d1.indirect || 
d2.indirect; + resolved = Set.union d1.resolved d2.resolved; + }) +} + +let list_insns ?(rev=false) insns = + if rev then insns else List.rev insns + +let rec insert pos x xs = + if pos = 0 then x::xs else match xs with + | x' :: xs -> x' :: insert (pos-1) x xs + | [] -> [x] + +let execution_order stack = + KB.List.fold stack ~init:[] ~f:(fun insns insn -> + KB.collect Theory.Program.Semantics.slot insn >>| fun s -> + match KB.Value.get Insn.Slot.delay s with + | None -> insn::insns + | Some d -> insert d insn insns) + +let always _ = KB.return true + +let with_disasm beg cfg f = + query_arch (Word.to_bitvec beg) >>= function + | None -> KB.return (cfg,None) + | Some arch -> + match Dis.create (Arch.to_string arch) with + | Error _ -> KB.return (cfg,None) + | Ok dis -> f arch dis + +let may_fall insn = + KB.collect Theory.Program.Semantics.slot insn >>| fun insn -> + not Insn.(is barrier insn) + +let explore + ?entry:start ?(follow=always) ~block ~node ~edge ~init + {begs; jmps; data; mems} = + let find_base addr = + if Set.mem data addr then None + else List.find mems ~f:(fun mem -> Memory.contains mem addr) in + let blocks = Hashtbl.create (module Addr) in + let edge_insert cfg src dst = match dst with + | None -> KB.return cfg + | Some dst -> edge src dst cfg in + let view ?len from mem = ok_exn (Memory.view ?words:len ~from mem) in + let rec build cfg beg = + if Set.mem data beg then KB.return (cfg,None) + else follow beg >>= function + | false -> KB.return (cfg,None) + | true -> with_disasm beg cfg @@ fun arch dis -> + match Hashtbl.find blocks beg with + | Some block -> KB.return (cfg, Some block) + | None -> match find_base beg with + | None -> KB.return (cfg,None) + | Some base -> + Dis.run dis (view beg base) ~stop_on:[`Valid] + ~init:(beg,0,[],true) ~return:KB.return + ~hit:(fun s mem insn (curr,len,insns,_) -> + new_insn arch mem insn >>= fun insn -> + update_calls mem insn >>= fun () -> + let len = Memory.length mem + len in + let last = Memory.max_addr 
mem in + let next = Addr.succ last in + if Set.mem data next && not (Map.mem jmps curr) + then assert false + else + if Set.mem begs next || Map.mem jmps curr + then + may_fall insn >>| fun may -> + (curr, len, insn::insns,may) + else Dis.jump s (view next base) + (next, len, insn::insns,true)) + >>= fun (fin,len,insns,may_fall) -> + let mem = view ~len beg base in + block mem insns >>= fun block -> + let fall = Addr.succ (Memory.max_addr mem) in + Hashtbl.add_exn blocks beg block; + node block cfg >>= fun cfg -> + match Map.find jmps fin with + | None when may_fall -> + build cfg fall >>= fun (cfg,dst) -> + edge_insert cfg block dst >>= fun cfg -> + KB.return (cfg, Some block) + | None -> KB.return (cfg, Some block) + | Some {resolved=dsts} -> + let dsts = if may_fall + then Set.add dsts fall else dsts in + Set.to_sequence dsts |> + KB.Seq.fold ~init:cfg ~f:(fun cfg dst -> + build cfg dst >>= fun (cfg,dst) -> + edge_insert cfg block dst) >>= fun cfg -> + KB.return (cfg,Some block) in + match start with + | None -> + Set.to_sequence begs |> + KB.Seq.fold ~init ~f:(fun cfg beg -> + build cfg beg >>| fst) + | Some start -> build init start >>| fst diff --git a/lib/bap_disasm/bap_disasm_driver.mli b/lib/bap_disasm/bap_disasm_driver.mli new file mode 100644 index 000000000..e0c30209d --- /dev/null +++ b/lib/bap_disasm/bap_disasm_driver.mli @@ -0,0 +1,26 @@ +open Core_kernel +open Bap_types.Std +open Bap_image_std +open Bap_knowledge +open Bap_core_theory +module Dis = Bap_disasm_basic + +type state [@@deriving bin_io] +type insns + +val init : state +val scan : mem -> state -> state knowledge +val merge : state -> state -> state + +val explore : + ?entry:addr -> + ?follow:(addr -> bool knowledge) -> + block:(mem -> insns -> 'n knowledge) -> + node:('n -> 'c -> 'c knowledge) -> + edge:('n -> 'n -> 'c -> 'c knowledge) -> + init:'c -> + state -> 'c knowledge + + +val list_insns : ?rev:bool -> insns -> Theory.Label.t list +val execution_order : insns -> Theory.Label.t list 
knowledge diff --git a/lib/bap_disasm/bap_disasm_insn.ml b/lib/bap_disasm/bap_disasm_insn.ml index 863b97428..222753bc9 100644 --- a/lib/bap_disasm/bap_disasm_insn.ml +++ b/lib/bap_disasm/bap_disasm_insn.ml @@ -1,9 +1,11 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_types.Std open Bap_disasm_types module Insn = Bap_disasm_basic.Insn +let package = "bap.std" type must = Must type may = May @@ -12,19 +14,23 @@ type 'a property = Z.t * string let known_properties = ref [] let new_property _ name : 'a property = + let name = sprintf ":%s" name in let bit = List.length !known_properties in let property = Z.shift_left Z.one bit, name in known_properties := !known_properties @ [property]; property let prop = new_property () +(* must be the first one *) +let invalid = prop "invalid" let jump = prop "jump" let conditional = prop "cond" let indirect = prop "indirect" let call = prop "call" let return = prop "return" -let affect_control_flow = prop "affect_control_flow" +let barrier = prop "barrier" +let affect_control_flow = prop "affect-control-flow" let load = prop "load" let store = prop "store" @@ -42,21 +48,109 @@ module Props = struct Z.logand flags flag = flag let set_if cond flag = if cond then fun flags -> flags + flag else ident - include Sexpable.Of_stringable(Bits) - include Binable.Of_stringable(Bits) + + module T = struct + type t = Z.t + include Sexpable.Of_stringable(Bits) + include Binable.Of_stringable(Bits) + end + + let name = snd + + let assoc_of_props props = + List.map !known_properties ~f:(fun p -> + name p, has props p) + + let domain = KB.Domain.flat "props" + ~empty:Z.one ~equal:Z.equal + ~inspect:(fun props -> + [%sexp_of: (string * bool) list] + (assoc_of_props props)) + + let persistent = KB.Persistent.of_binable (module T) + + let slot = KB.Class.property ~package:"bap.std" + ~persistent + Theory.Program.Semantics.cls "insn-properties" domain end -type t = { - code : int; - name : string; - asm : string; - bil : bil; - ops 
: Op.t array; - props : Props.t; -} [@@deriving bin_io, fields, compare, sexp] +type t = Theory.Program.Semantics.t type op = Op.t [@@deriving bin_io, compare, sexp] + +module Slot = struct + type 'a t = (Theory.Effect.cls, 'a) KB.slot + let empty = "#undefined" + let text = KB.Domain.flat "text" + ~inspect:sexp_of_string ~empty + ~equal:String.equal + + let delay_t = KB.Domain.optional "delay_t" + ~inspect:sexp_of_int + ~equal:Int.equal + + + let name = KB.Class.property ~package:"bap.std" + ~persistent:KB.Persistent.string + Theory.Program.Semantics.cls "insn-opcode" text + + let asm = KB.Class.property ~package:"bap.std" + ~persistent:KB.Persistent.string + Theory.Program.Semantics.cls "insn-asm" text + + let sexp_of_op = function + | Op.Reg r -> Sexp.Atom (Reg.name r) + | Op.Imm w -> sexp_of_int64 (Imm.to_int64 w) + | Op.Fmm w -> sexp_of_float (Fmm.to_float w) + + + let ops_domain = KB.Domain.optional "insn-ops" + ~equal:[%compare.equal: Op.t array] + ~inspect:[%sexp_of: op array] + + let ops_persistent = KB.Persistent.of_binable (module struct + type t = Op.t array option [@@deriving bin_io] + end) + + let ops = KB.Class.property ~package:"bap.std" + ~persistent:ops_persistent + Theory.Program.Semantics.cls "insn-ops" ops_domain + + let delay = KB.Class.property ~package:"bap.std" + Theory.Program.Semantics.cls "insn-delay" delay_t + ~persistent:(KB.Persistent.of_binable (module struct + type t = int option [@@deriving bin_io] + end)) + + type KB.conflict += Jump_vs_Move + + let dests = + let empty = Some (Set.empty (module Theory.Label)) in + let order x y : KB.Order.partial = match x,y with + | None,None -> EQ + | None,_ | _,None -> NC + | Some x, Some y -> + if Set.equal x y then EQ else + if Set.is_subset x y then LT else + if Set.is_subset y x then GT else NC in + let join x y = match x,y with + | None,None -> Ok None + | None,_ |Some _,None -> Error Jump_vs_Move + | Some x, Some y -> Ok (Some (Set.union x y)) in + let module IO = struct + module Set = 
Set.Make_binable_using_comparator(Theory.Label) + type t = Set.t option [@@deriving bin_io, sexp_of] + end in + let inspect = IO.sexp_of_t in + let data = KB.Domain.define ~empty ~order ~join ~inspect "dest-set" in + let persistent = KB.Persistent.of_binable (module IO) in + KB.Class.property ~package:"bap.std" Theory.Program.Semantics.cls + ~persistent + "insn-dests" data + +end + let normalize_asm asm = String.substr_replace_all asm ~pattern:"\t" ~with_:" " |> String.strip @@ -68,7 +162,7 @@ let lookup_jumps bil = (object | Bil.Int _ when under_condition -> [`Conditional_branch] | Bil.Int _ -> [`Unconditional_branch] | _ when under_condition -> [`Conditional_branch; `Indirect_branch] - | _ -> [`Indirect_branch] + | _ -> [`Unconditional_branch; `Indirect_branch] end)#run bil [] let lookup_side_effects bil = (object @@ -79,7 +173,12 @@ let lookup_side_effects bil = (object `May_load :: acc end)#run bil [] -let of_basic ?bil insn = +let (<--) slot value insn = KB.Value.put slot insn value + +let write init ops = + List.fold ~init ops ~f:(fun init f -> f init) + +let of_basic ?bil insn : t = let bil_kinds = match bil with | Some bil -> lookup_jumps bil @ lookup_side_effects bil | None -> [] in @@ -88,14 +187,23 @@ let of_basic ?bil insn = if bil <> None then List.mem ~equal:[%compare.equal : kind] bil_kinds kind else is kind in + (* those two are the only which we can't get from the BIL semantics *) + let is_return = is `Return in + let is_call = is `Call in + let is_conditional_jump = is_bil `Conditional_branch in let is_jump = is_conditional_jump || is_bil `Unconditional_branch in let is_indirect_jump = is_bil `Indirect_branch in - let is_return = is `Return in - let is_call = is `Call in - let may_affect_control_flow = is `May_affect_control_flow in + let may_affect_control_flow = + is_jump || + is `May_affect_control_flow in + let is_barrier = is_jump && not is_call && not is_conditional_jump in let may_load = is_bil `May_load in let may_store = is_bil `May_store 
in + let effect = + KB.Value.put Bil.slot + (KB.Value.empty Theory.Program.Semantics.cls) + (Option.value bil ~default:[]) in let props = Props.empty |> Props.set_if is_jump jump |> @@ -103,25 +211,35 @@ let of_basic ?bil insn = Props.set_if is_indirect_jump indirect |> Props.set_if is_call call |> Props.set_if is_return return |> + Props.set_if is_barrier barrier |> Props.set_if may_affect_control_flow affect_control_flow |> Props.set_if may_load load |> Props.set_if may_store store in - { - code = Insn.code insn; - name = Insn.name insn; - asm = normalize_asm (Insn.asm insn); - bil = Option.value bil ~default:[Bil.special "Unknown Semantics"]; - ops = Insn.ops insn; - props; - } - -let is flag t = Props.has t.props flag + write effect Slot.[ + Props.slot <-- props; + name <-- Insn.name insn; + asm <-- normalize_asm (Insn.asm insn); + ops <-- Some (Insn.ops insn); + ] + +let get = KB.Value.get Props.slot +let put = KB.Value.put Props.slot +let is flag t = Props.has (get t) flag let may = is -let must flag insn = {insn with props = Props.(insn.props + flag) } -let mustn't flag insn = {insn with props = Props.(insn.props - flag)} +let must flag insn = put insn Props.(get insn + flag) +let mustn't flag insn = put insn Props.(get insn - flag) let should = must let shouldn't = mustn't +let name = KB.Value.get Slot.name +let asm = KB.Value.get Slot.asm +let bil insn = KB.Value.get Bil.slot insn +let ops s = match KB.Value.get Slot.ops s with + | None -> [||] + | Some ops -> ops + +let empty = KB.Value.empty Theory.Program.Semantics.cls + module Adt = struct let pr fmt = Format.fprintf fmt @@ -135,17 +253,21 @@ module Adt = struct List.map ~f:snd |> String.concat ~sep:", " - let pp ch insn = pr ch "%s(%a, Props(%s))" - (String.capitalize insn.name) - pp_ops (Array.to_list insn.ops) - (props insn) + let pp ppf insn = + let name = name insn in + if name = Slot.empty + then pr ppf "Undefined()" + else pr ppf "%s(%a, Props(%s))" + (String.capitalize name) + pp_ops 
(Array.to_list (ops insn)) + (props insn) end let pp_adt = Adt.pp module Trie = struct module Key = struct - type token = int * Op.t array [@@deriving bin_io, compare, sexp] + type token = string * Op.t array [@@deriving bin_io, compare, sexp] type t = token array let length = Array.length @@ -156,7 +278,7 @@ module Trie = struct module Normalized = Trie.Make(struct include Key let compare_token (x,xs) (y,ys) = - let r = compare_int x y in + let r = compare_string x y in if r = 0 then Op.Normalized.compare_ops xs ys else r let hash_ops = Array.fold ~init:0 ~f:(fun h x -> h lxor Op.Normalized.hash x) @@ -164,28 +286,31 @@ module Trie = struct x lxor hash_ops xs end) - let token_of_insn insn = insn.code, insn.ops + let token_of_insn insn = name insn, ops insn let key_of_insns = Array.of_list_map ~f:token_of_insn include Trie.Make(Key) end include Regular.Make(struct - type nonrec t = t [@@deriving sexp, bin_io, compare] - let hash t = t.code + type t = Theory.Program.Semantics.t [@@deriving sexp, bin_io, compare] + let hash t = Hashtbl.hash t let module_name = Some "Bap.Std.Insn" - let version = "1.0.0" + let version = "2.0.0" let string_of_ops ops = Array.map ops ~f:Op.to_string |> Array.to_list |> String.concat ~sep:"," let pp fmt insn = - Format.fprintf fmt "%s(%s)" insn.name (string_of_ops insn.ops) + let name = name insn in + if name = Slot.empty + then Format.fprintf fmt "%s" name + else Format.fprintf fmt "%s(%s)" name (string_of_ops (ops insn)) end) let pp_asm ppf insn = - Format.fprintf ppf "%s" insn.asm + Format.fprintf ppf "%s" (normalize_asm (asm insn)) let () = Data.Write.create ~pp:Adt.pp () |> diff --git a/lib/bap_disasm/bap_disasm_insn.mli b/lib/bap_disasm/bap_disasm_insn.mli index b2cb767e2..1a496e0b8 100644 --- a/lib/bap_disasm/bap_disasm_insn.mli +++ b/lib/bap_disasm/bap_disasm_insn.mli @@ -1,11 +1,15 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_types.Std open Bap_disasm_types +open Bap_ir -type t [@@deriving bin_io, compare, 
sexp] +type t = Theory.Program.Semantics.t [@@deriving bin_io, compare, sexp] type op = Op.t [@@deriving bin_io, compare, sexp] +val empty : t + val of_basic : ?bil:bil -> Basic.full_insn -> t val name : t -> string @@ -24,6 +28,7 @@ val conditional : must property val indirect : must property val call : must property val return : must property +val barrier : must property val affect_control_flow : may property val load : may property val store : may property @@ -35,6 +40,17 @@ val mustn't : must property -> t -> t val should : may property -> t -> t val shouldn't : may property -> t -> t + +module Slot : sig + type 'a t = (Theory.Effect.cls, 'a) KB.slot + val name : string t + val asm : string t + val ops : op array option t + val delay : int option t + val dests : Set.M(Theory.Label).t option t +end + + val pp_adt : Format.formatter -> t -> unit module Trie : sig type key diff --git a/lib/bap_disasm/bap_disasm_rec.ml b/lib/bap_disasm/bap_disasm_rec.ml index 5904bba49..c98e3ea73 100644 --- a/lib/bap_disasm/bap_disasm_rec.ml +++ b/lib/bap_disasm/bap_disasm_rec.ml @@ -2,40 +2,32 @@ open Core_kernel open Regular.Std open Bap_types.Std open Graphlib.Std -open Bil.Types -open Or_error +open Bap_core_theory open Bap_image_std -module Targets = Bap_disasm_target_factory +open KB.Syntax -module Dis = Bap_disasm_basic +module Driver = Bap_disasm_driver module Brancher = Bap_disasm_brancher module Rooter = Bap_disasm_rooter -module Addrs = Addr.Table module Block = Bap_disasm_block module Insn = Bap_disasm_insn +module Basic = Bap_disasm_basic -type full_insn = Dis.full_insn [@@deriving sexp_of] +type full_insn = Basic.full_insn [@@deriving sexp_of] type insn = Insn.t [@@deriving sexp_of] type block = Block.t type edge = Block.edge [@@deriving compare, sexp] type jump = Block.jump [@@deriving compare, sexp] -type lifter = Targets.lifter - -type dis = (Dis.empty, Dis.empty) Dis.t - -type dst = [ - | `Jump of addr option - | `Cond of addr option - | `Fall of addr -] 
[@@deriving sexp] +type lifter = Bap_disasm_target_factory.lifter type error = [ | `Failed_to_disasm of mem | `Failed_to_lift of mem * full_insn * Error.t ] [@@deriving sexp_of] -type maybe_insn = full_insn option * bil option [@@deriving sexp_of] +type code = Theory.Program.t [@@deriving bin_io, compare, sexp] +type maybe_insn = full_insn option * code [@@deriving sexp_of] type decoded = mem * maybe_insn [@@deriving sexp_of] type dests = Brancher.dests @@ -63,388 +55,81 @@ module Cfg = Graphlib.Make(Node)(Edge) type cfg = Cfg.t [@@deriving compare] -module Visited = struct - - type entry = Unknown | Insn of {last : addr} - type t = entry Addr.Map.t - - let empty = Addr.Map.empty - - let add_insn t mem = - Map.set t (Memory.min_addr mem) (Insn {last = Memory.max_addr mem}) - - let touch t addr = - Map.update t addr ~f:(function - | None -> Unknown - | Some known -> known) - - let find_insn t a = match Map.find t a with - | Some (Insn {last}) -> Some last - | _ -> None - - let min = Map.min_elt - let mem = Map.mem - let forget = Map.remove - - let upper_bound t = Map.closest_key t `Greater_than - - let rec next_insn t a = - match upper_bound t a with - | None -> None - | Some (addr, Insn {last}) -> Some (addr,last) - | Some (addr,_) -> next_insn t addr - - let min_insn t = match min t with - | None -> None - | Some (addr, Insn {last}) -> Some (addr,last) - | Some (addr,_) -> next_insn t addr - - let has_insn t a = match Map.find t a with - | Some (Insn _) -> true - | _ -> false - -end - -type blk_dest = [ - | `Block of block * edge - | `Unresolved of jump -] - -type stage1 = { - base : mem; - addr : addr; - visited : Visited.t; - roots : addr list; - inits : Addr.Set.t; - dests : dests Addr.Table.t; - errors : (addr * error) list; - lift : lifter; -} - -type stage2 = { - stage1 : stage1; - addrs : mem Addrs.t; (* table of blocks *) - succs : dests Addrs.t; - preds : addr list Addrs.t; - disasm : mem -> decoded list; -} - -type stage3 = { - cfg : Cfg.t; - failures : 
(addr * error) list; -} - -type t = stage3 - -let errored s ty = {s with errors = (s.addr,ty) :: s.errors} - -let filter_dests base ds = - List.map ds ~f:(fun d -> match d with - | Some addr,_ when Memory.contains base addr -> d - | _,kind -> None, kind) - -let update_dests s mem ds = - let key = Memory.min_addr mem in - match filter_dests s.base ds with - | [] -> Addr.Table.add_exn s.dests ~key ~data:[] - | ds -> List.iter ds ~f:(fun data -> - Addr.Table.add_multi s.dests ~key ~data) - -let rec has_jump = function - | [] -> false - | Bil.Jmp _ :: _ | Bil.CpuExn _ :: _ -> true - | Bil.If (_,y,n) :: xs -> has_jump y || has_jump n || has_jump xs - | Bil.While (_,b) :: xs -> has_jump b || has_jump xs - | _ :: xs -> has_jump xs - -let ok_nil = function - | Ok xs -> xs - | Error _ -> [] - -let is_terminator s mem insn = - Dis.Insn.is insn `May_affect_control_flow || - has_jump (ok_nil (s.lift mem insn)) || - Set.mem s.inits (Addr.succ (Memory.max_addr mem)) - -let update s mem insn dests : stage1 = - let s = { s with visited = Visited.add_insn s.visited mem } in - if is_terminator s mem insn then - let () = update_dests s mem dests in - let roots = List.(filter_map ~f:fst dests |> rev_append s.roots) in - { s with roots } - else { - s with roots = Addr.succ (Memory.max_addr mem) :: s.roots - } - -(* switch to next root or finish if there're no roots *) -let next dis s = - let rec loop s = match s.roots with - | [] -> Dis.stop dis s - | r :: roots when not(Memory.contains s.base r) -> - loop {s with roots} - | r :: roots when Visited.mem s.visited r -> - loop {s with roots} - | addr :: roots -> - let mem = Memory.view ~from:addr s.base in - let mem = Result.map_error mem ~f:(fun err -> Error.tag err "next_root") in - mem >>= fun mem -> - Dis.jump dis mem {s with roots; addr} in - let s = {s with visited = Visited.touch s.visited s.addr } in - loop s - -let stop_on = [`Valid] - -let update_intersections disasm brancher s = - let dests = Addr.Table.create () in - let 
find_destinations a = - Option.value (Addrs.find dests a) ~default:[] in - let next dis = function - | [] -> Dis.stop dis [] - | addr :: roots -> - Memory.view ~from:addr s.base >>= fun mem -> - Dis.jump dis mem roots in - let equal_dest x y = match x,y with - | (Some a, _), (Some a',_) -> Addr.(a = a') - | _ -> false in - let is_intersected ds ds' = - List.exists ds ~f:(fun d -> List.exists ds' ~f:(equal_dest d)) in - let visit_intersections = function - | [] | [_] -> Ok () - | (from :: roots) as intersections -> - Memory.view ~from s.base >>= fun mem -> - Dis.run disasm mem ~stop_on ~return - ~init:roots - ~hit:(fun d mem insn roots -> - let ds = filter_dests s.base (brancher mem insn) in - Hashtbl.set dests (Memory.min_addr mem) ds; - next d roots) - ~stopped:next >>= fun _ -> - List.iter intersections ~f:(fun a -> - let ds = find_destinations a in - let has_common = - List.exists intersections ~f:(fun a' -> - Addr.(a <> a') && is_intersected ds (find_destinations a')) in - if has_common then - Addrs.set s.dests a ds); - Ok () in - let intersections = - Map.fold s.visited ~init:Addr.Map.empty - ~f:(fun ~key:addr ~data:tail inters -> - match tail with - | Visited.Unknown -> inters - | Visited.Insn {last} -> Map.add_multi inters last addr) in - Map.fold intersections ~init:(Ok ()) ~f:(fun ~key:_ ~data:addrs r -> - r >>= fun () -> visit_intersections addrs) >>= fun () -> - return s - -let stage1 ?(rooter=Rooter.empty) lift brancher disasm base = - let roots = - Rooter.roots rooter |> Seq.filter ~f:(Memory.contains base) in - let addr,roots = match Seq.to_list roots with - | r :: rs -> r,rs - | [] -> Memory.min_addr base, [] in - let init = {base; addr; visited = Visited.empty; - roots; inits = Addr.Set.of_list roots; - dests = Addrs.create (); errors = []; lift} in - Memory.view ~from:addr base >>= fun mem -> - Dis.run disasm mem ~stop_on ~return ~init - ~hit:(fun d mem insn s -> - next d (update s mem insn (brancher mem insn))) - ~invalid:(fun d mem s -> next d 
(errored s (`Failed_to_disasm mem))) - ~stopped:next >>= update_intersections disasm brancher - -(* performs the initial markup. - - Returns three tables: leads, terms, and kinds. Leads is a mapping - from leader to all predcessing terminators. Terms is a mapping from - terminators to all successive leaders. Kinds is a mapping from a - terminator addresses, to all outputs with each output including the - kind annotation, e.g, Cond, Jump, etc. This also includes - unresolved outputs. -*) - -let sexp_of_addr addr = - Sexp.Atom (Addr.string_of_value addr) - -let create_indexes (dests : dests Addr.Table.t) = - let leads = Addrs.create () in - let terms = Addrs.create () in - let succs = Addrs.create () in - Addrs.iteri dests ~f:(fun ~key:src ~data:dests -> - List.iter dests ~f:(fun dest -> - Addrs.add_multi succs ~key:src ~data:dest; - match dest with - | None,_ -> () - | Some dst,_ -> - Addrs.add_multi leads ~key:dst ~data:src; - Addrs.add_multi terms ~key:src ~data:dst)); - leads, terms, succs - -let join_destinations ?default dests = - let jmp x = [ Bil.(jmp (int x)) ] in - let undecided x = - Bil.if_ (Bil.unknown "destination" (Type.Imm 1)) (jmp x) [] in - let init = match default with - | None -> [] - | Some x -> jmp x in - Set.fold dests ~init ~f:(fun bil x -> undecided x :: bil) - -let make_switch x dests = - let case addr = Bil.(if_ (x = int addr) [jmp (int addr)] []) in - let default = Bil.jmp x in - Set.fold dests ~init:[default] ~f:(fun ds a -> case a :: ds) - -let dests_of_bil bil = - (object inherit [Addr.Set.t] Stmt.visitor - method! visit_jmp e dests = match e with - | Int w -> Set.add dests w - | _ -> dests - end)#run bil Addr.Set.empty - -let add_destinations bil = function - | [] -> bil - | dests -> - let d = dests_of_bil bil in - let d' = Addr.Set.of_list dests in - let n = Set.diff d' d in - if Set.is_empty n then bil - else - if has_jump bil then - (object inherit Stmt.mapper - method! 
map_jmp = function - | Int addr -> join_destinations ~default:addr n - | indirect -> make_switch indirect n - end)#run bil - else bil @ join_destinations n - -let disasm stage1 dis = - let dis = Dis.store_asm dis in - let dis = Dis.store_kinds dis in - fun mem -> - Dis.run dis mem - ~init:[] ~return:ident ~stopped:(fun s _ -> - Dis.stop s (Dis.insns s)) |> - List.map ~f:(function - | mem, None -> mem,(None,None) - | mem, (Some ins as insn) -> - let dests = - Addrs.find stage1.dests (Memory.min_addr mem) |> - Option.value ~default:[] |> - List.filter_map ~f:(function - | a, (`Cond | `Jump) -> a - | _ -> None) in - let bil = match stage1.lift mem ins with - | Ok bil -> bil - | _ -> [] in - mem, (insn, Some (add_destinations bil dests))) - -let stage2 dis stage1 = - let leads, terms, kinds = create_indexes stage1.dests in - let addrs = Addrs.create () in - let succs = Addrs.create () in - let preds = Addrs.create () in - let next = Addr.succ in - let is_edge addr max_addr = - Addrs.mem leads (next max_addr) || - Addrs.mem kinds addr || - Addrs.mem stage1.dests addr || - Set.mem stage1.inits (next max_addr) in - let is_insn = Visited.has_insn stage1.visited in - let next_visited = Visited.upper_bound stage1.visited in - let create_block start addr max_addr = - Memory.range stage1.base start max_addr >>= fun blk -> - Addrs.add_exn addrs ~key:start ~data:blk; - let () = match Addrs.find terms addr with - | None -> () - | Some leaders -> List.iter leaders ~f:(fun leader -> - Addrs.add_multi preds ~key:leader ~data:start) in - let dests = match Addrs.find kinds addr with - | Some dests -> dests - | None when Addrs.mem leads (next max_addr) && - not (Addrs.mem stage1.dests addr) -> - Addrs.add_multi preds ~key:(next max_addr) ~data:start; - [Some (next max_addr),`Fall] - | None -> [] in - Addrs.add_exn succs ~key:start ~data:dests; - return () in - let next_block leftovers start = - let rec loop leftovers last start curr = - match Visited.find_insn leftovers curr with - | 
Some max_addr -> - let insn = curr,max_addr in - let leftovers = Visited.forget leftovers curr in - if is_edge curr max_addr then Some (leftovers, start, insn) - else loop leftovers (Some insn) start (next max_addr) - | None -> match last with - | Some last when is_insn start -> Some (leftovers,start,last) - | _ -> match next_visited curr with - | Some (addr,_) -> loop leftovers None addr addr - | None -> None in - loop leftovers None start start in - let fetch_blocks () = - let rec loop leftovers start = - match next_block leftovers start with - | Some (leftovers,start,(addr, max_addr)) -> - if Hashtbl.mem addrs start then - loop leftovers (next max_addr) - else - create_block start addr max_addr >>= fun () -> - loop leftovers (next max_addr) - | None -> match Visited.min_insn leftovers with - | Some (addr,_) -> loop leftovers addr - | _ -> Ok () in - match Visited.min_insn stage1.visited with - | Some (addr,_) -> loop stage1.visited addr - | _ -> errorf "Provided memory doesn't contain a recognizable code" in - fetch_blocks () >>= fun () -> - return {stage1; addrs; succs; preds; disasm = disasm stage1 dis} - -let stage3 s2 = - let is_found addr = Addrs.mem s2.addrs addr in - let pred_is_found = is_found in - let succ_is_found = function - | None,_ -> true - | Some addr,_ -> is_found addr in - let filter bs ~f = Addrs.filter_map bs ~f:(fun ps -> - match List.filter ps ~f with - | [] -> None - | preds -> Some preds) in - let s2 = { - s2 with - succs = filter s2.succs ~f:succ_is_found; - preds = filter s2.preds ~f:pred_is_found; - } in - let nodes = Addrs.create () in - Addrs.iteri s2.addrs ~f:(fun ~key:addr ~data:mem -> - s2.disasm mem |> List.filter_map ~f:(function - | _,(None,_) -> None - | mem,(Some insn,bil) -> - Some (mem, Insn.of_basic ?bil insn)) |> function - | [] -> () - | insns -> - let node = Block.create mem insns in - Addrs.set nodes ~key:addr ~data:node); - let cfg = - Addrs.fold nodes ~init:Cfg.empty ~f:(fun ~key:addr ~data:x cfg -> - match 
Addrs.find s2.succs addr with - | None -> Cfg.Node.insert x cfg - | Some dests -> - List.fold dests ~init:cfg ~f:(fun cfg dest -> match dest with - | None,_ -> Cfg.Node.insert x cfg - | Some d,e -> match Addrs.find nodes d with - | None -> Cfg.Node.insert x cfg - | Some y -> - let edge = Cfg.Edge.create x y e in - Cfg.Edge.insert edge cfg)) in - return {cfg; failures = s2.stage1.errors} - -let run ?(backend="llvm") ?brancher ?rooter arch mem = - let b = Option.value brancher ~default:(Brancher.of_bil arch) in - let brancher = Brancher.resolve b in - let module Target = (val Targets.target_of_arch arch) in - let lifter = Target.lift in - Dis.with_disasm ~backend (Arch.to_string arch) ~f:(fun dis -> - stage1 ?rooter lifter brancher dis mem >>= stage2 dis >>= stage3) - -let cfg t = t.cfg -let errors s = List.map s.failures ~f:snd +type t = Driver.state KB.t + +let create_insn basic prog = + match basic with + | None -> prog + | Some insn -> + let bil = Insn.bil prog in + let prog' = Insn.of_basic ~bil insn in + KB.Value.merge ~on_conflict:`drop_right prog prog' + +let follows_after m1 m2 = Addr.equal + (Addr.succ (Memory.max_addr m1)) + (Memory.min_addr m2) + +let has_conditional_jump blk = + let insn = Block.terminator blk in + Insn.(is jump insn) && + Insn.(is conditional insn) + +let global_cfg disasm = + Driver.explore disasm + ~init:Cfg.empty + ~block:(fun mem insns -> + Driver.execution_order insns >>= + KB.List.filter_map ~f:(fun label -> + KB.collect Basic.Insn.slot label >>= fun basic -> + KB.collect Theory.Program.Semantics.slot label >>= fun s -> + KB.collect Memory.slot label >>| function + | None -> None + | Some mem -> Some (mem,create_insn basic s)) >>| + Block.create mem) + ~node:(fun node cfg -> + KB.return (Cfg.Node.insert node cfg)) + ~edge:(fun src dst g -> + let k = if follows_after (Block.memory src) (Block.memory dst) + then `Fall + else if has_conditional_jump src + then `Cond + else `Jump in + let edge = Cfg.Edge.create src dst k in + 
KB.return @@ Cfg.Edge.insert edge g) + + + +let result = Toplevel.var "cfg" + +let extract build disasm = + Toplevel.put result begin + disasm >>= build + end; + Toplevel.get result + + +let provide_arch arch mem = + let width = Size.in_bits (Arch.addr_size arch) in + KB.promise Arch.slot @@ fun label -> + KB.collect Theory.Label.addr label >>| function + | None -> None + | Some p -> + let p = Word.create p width in + if Memory.contains mem p then Some arch + else None + +let scan arch mem state = + provide_arch arch mem; + Driver.scan mem state + +let run ?backend ?(brancher=Brancher.empty) ?(rooter=Rooter.empty) arch mem = + Brancher.provide brancher; + Rooter.provide rooter; + provide_arch arch mem; + Ok (Driver.scan mem Driver.init) + +let cfg = extract global_cfg +let errors _ = [] + +let create = KB.return +let graph s = s >>= global_cfg diff --git a/lib/bap_disasm/bap_disasm_rec.mli b/lib/bap_disasm/bap_disasm_rec.mli index f784365f3..cdebe1918 100644 --- a/lib/bap_disasm/bap_disasm_rec.mli +++ b/lib/bap_disasm/bap_disasm_rec.mli @@ -3,12 +3,14 @@ open Core_kernel open Bap_types.Std open Graphlib.Std - +open Bap_knowledge open Image_internal_std open Bap_disasm_basic open Bap_disasm_brancher open Bap_disasm_rooter +module Driver = Bap_disasm_driver + type t type insn = Bap_disasm_insn.t @@ -33,3 +35,6 @@ val run : val cfg : t -> Cfg.t val errors : t -> error list + +val scan : arch -> mem -> Driver.state -> Driver.state knowledge +val global_cfg : Driver.state -> Cfg.t knowledge diff --git a/lib/bap_disasm/bap_disasm_reconstructor.ml b/lib/bap_disasm/bap_disasm_reconstructor.ml index 3294b7f2d..021e778ad 100644 --- a/lib/bap_disasm/bap_disasm_reconstructor.ml +++ b/lib/bap_disasm/bap_disasm_reconstructor.ml @@ -71,13 +71,11 @@ let is_unresolved blk cfg = deg = 0 || (deg = 1 && is_fall (Seq.hd_exn (Cfg.Node.outputs blk cfg))) -let add_call symtab blk name label = - Symtab.add_call symtab blk name label let add_unresolved syms name cfg blk = if is_unresolved 
blk cfg then let call_addr = terminator_addr blk in - add_call syms blk (name call_addr) `Fall + Symtab.insert_call syms blk (name call_addr) else syms let collect name cfg roots = @@ -119,7 +117,8 @@ let reconstruct name initial_roots prog = let name = name (Block.addr entry) in let syms = Symtab.add_symbol syms (name,entry,cfg) in Set.fold inputs ~init:syms ~f:(fun syms e -> - add_call syms (Cfg.Edge.src e) name (Cfg.Edge.label e)) in + let implicit = Cfg.Edge.label e = `Fall in + Symtab.insert_call ~implicit syms (Cfg.Edge.src e) name) in let remove_node cfg n = Cfg.Node.remove n cfg in let remove_reachable cfg from = let reachable = reachable cfg from in diff --git a/lib/bap_disasm/bap_disasm_rooter.ml b/lib/bap_disasm/bap_disasm_rooter.ml index d9efb03c2..1cb2523a0 100644 --- a/lib/bap_disasm/bap_disasm_rooter.ml +++ b/lib/bap_disasm/bap_disasm_rooter.ml @@ -1,8 +1,13 @@ +open Bap_core_theory + open Core_kernel open Bap_types.Std open Bap_image_std +open KB.Syntax + module Source = Bap_disasm_source +module Insn = Bap_disasm_insn type t = Rooter of addr seq type rooter = t @@ -30,3 +35,18 @@ let of_blocks blocks = | Some a when Addr.(a < sa) -> Some a | _ -> Some sa)); create (Hashtbl.data roots |> Seq.of_list) + +let provide rooter = + let init = Set.empty (module Bitvec_order) in + let roots = + roots rooter |> + Seq.map ~f:Word.to_bitvec |> + Seq.fold ~init ~f:Set.add in + let promise prop = + KB.promise prop @@ fun label -> + KB.collect Theory.Label.addr label >>| function + | None -> None + | Some addr -> + Option.some_if (Set.mem roots addr) true in + promise Theory.Label.is_valid; + promise Theory.Label.is_subroutine diff --git a/lib/bap_disasm/bap_disasm_rooter.mli b/lib/bap_disasm/bap_disasm_rooter.mli index e23feb522..9622b5a6b 100644 --- a/lib/bap_disasm/bap_disasm_rooter.mli +++ b/lib/bap_disasm/bap_disasm_rooter.mli @@ -17,5 +17,7 @@ val roots : t -> addr seq val union : t -> t -> t +val provide : t -> unit + module Factory : Factory with type t = 
t diff --git a/lib/bap_disasm/bap_disasm_std.ml b/lib/bap_disasm/bap_disasm_std.ml index 4d541968f..26ec32075 100644 --- a/lib/bap_disasm/bap_disasm_std.ml +++ b/lib/bap_disasm/bap_disasm_std.ml @@ -26,6 +26,7 @@ module Symbolizer = Bap_disasm_symbolizer module Brancher = Bap_disasm_brancher module Reconstructor = Bap_disasm_reconstructor + type 'a source = 'a Source.t type symtab = Symtab.t type rooter = Rooter.t diff --git a/lib/bap_disasm/bap_disasm_symbolizer.ml b/lib/bap_disasm/bap_disasm_symbolizer.ml index 644c96ffd..a59de1976 100644 --- a/lib/bap_disasm/bap_disasm_symbolizer.ml +++ b/lib/bap_disasm/bap_disasm_symbolizer.ml @@ -1,15 +1,39 @@ open Core_kernel +open Bap_core_theory open Bap_types.Std open Bap_image_std open Bap_disasm_source +open KB.Syntax + type t = Symbolizer of (addr -> string option) type symbolizer = t +let name_choices = KB.Domain.opinions ~empty:None + ~equal:(Option.equal String.equal) + ~inspect:(sexp_of_option sexp_of_string) + "name-choices" + +let common_name = + KB.Class.property ~package:"bap.std" + Theory.Program.cls "common-name" name_choices + let name_of_addr addr = sprintf "sub_%s" @@ Addr.string_of_value addr +module Name = struct + let is_empty name = + String.is_prefix name ~prefix:"sub_" + + let order x y : KB.Order.partial = + match is_empty x, is_empty y with + | true,true -> EQ + | false,false -> if String.equal x y then EQ else NC + | true,false -> LT + | false,true -> GT +end + let create fn = Symbolizer fn let run (Symbolizer f) a = f a @@ -25,14 +49,12 @@ let chain ss = let of_image img = let symtab = Image.symbols img in let names = Addr.Table.create () in - Table.iteri symtab ~f:(fun mem name -> - Hashtbl.set names - ~key:(Memory.min_addr mem) - ~data:name); - let find addr = match Hashtbl.find names addr with - | None -> None - | Some sym -> Some (Image.Symbol.name sym) in - create find + Table.iteri symtab ~f:(fun mem sym -> + let name = Image.Symbol.name sym + and addr = Memory.min_addr mem in + if not 
(Name.is_empty name) + then Hashtbl.set names ~key:addr ~data:name); + create (Hashtbl.find names) let of_blocks seq = let syms = Addr.Table.create () in @@ -42,4 +64,38 @@ let of_blocks seq = module Factory = Factory.Make(struct type nonrec t = t end) -let internal_image_symbolizer = (fun img -> Some (of_image img)) +let provide agent (Symbolizer name) = + let open KB.Syntax in + KB.propose agent common_name @@ fun label -> + KB.collect Arch.slot label >>= fun arch -> + KB.collect Theory.Label.addr label >>| fun addr -> + match arch, addr with + | Some arch, Some addr -> + let width = Size.in_bits (Arch.addr_size arch) in + name (Addr.create addr width) + | _ -> None + + +let update_name_slot label name = + KB.collect Theory.Label.name label >>= function + | Some _ -> KB.return () + | None -> + KB.provide Theory.Label.name label (Some name) + +let get_name addr = + let data = Some (Word.to_bitvec addr) in + KB.Object.scoped Theory.Program.cls @@ fun label -> + KB.provide Theory.Label.addr label data >>= fun () -> + KB.resolve common_name label >>= function + | Some name -> KB.return name + | None -> KB.collect Theory.Label.name label >>| function + | Some name -> name + | None -> name_of_addr addr + +module Toplevel = struct + let name = Toplevel.var "symbol-name" + + let get_name addr = + Toplevel.put name (get_name addr); + Toplevel.get name +end diff --git a/lib/bap_disasm/bap_disasm_symbolizer.mli b/lib/bap_disasm/bap_disasm_symbolizer.mli index 7a48950d0..298953dc7 100644 --- a/lib/bap_disasm/bap_disasm_symbolizer.mli +++ b/lib/bap_disasm/bap_disasm_symbolizer.mli @@ -1,3 +1,5 @@ +open Bap_knowledge +open Bap_core_theory open Bap_types.Std open Bap_image_std open Bap_disasm_source @@ -5,6 +7,14 @@ open Bap_disasm_source type t type symbolizer = t +val common_name : (Theory.program, string option KB.opinions) KB.slot +val provide : Knowledge.agent -> t -> unit +val get_name : addr -> string knowledge + +module Toplevel : sig + val get_name : addr -> string 
+end + val empty : t val create : (addr -> string option) -> t @@ -17,5 +27,9 @@ val resolve : t -> addr -> string val chain : t list -> t +module Name : sig + val is_empty : string -> bool + val order : string -> string -> Knowledge.Order.partial +end module Factory : Factory with type t = t diff --git a/lib/bap_disasm/bap_disasm_symtab.ml b/lib/bap_disasm/bap_disasm_symtab.ml index 34c97423c..7e46df8c1 100644 --- a/lib/bap_disasm/bap_disasm_symtab.ml +++ b/lib/bap_disasm/bap_disasm_symtab.ml @@ -1,14 +1,20 @@ +open Bap_core_theory + open Core_kernel open Regular.Std open Bap_types.Std open Image_internal_std open Or_error +open KB.Syntax open Format module Block = Bap_disasm_block module Cfg = Bap_disasm_rec.Cfg module Insn = Bap_disasm_insn +module Disasm = Bap_disasm_driver +module Callgraph = Bap_disasm_calls +module Symbolizer = Bap_disasm_symbolizer type block = Block.t [@@deriving compare, sexp_of] @@ -26,11 +32,13 @@ module Fn = Opaque.Make(struct let hash x = String.hash (fst3 x) end) + type t = { addrs : fn Addr.Map.t; names : fn String.Map.t; memory : fn Memmap.t; - callees : (string * edge) list Addr.Map.t; + ecalls : string Addr.Map.t; + icalls : string Addr.Map.t; } [@@deriving sexp_of] @@ -47,7 +55,8 @@ let empty = { addrs = Addr.Map.empty; names = String.Map.empty; memory = Memmap.empty; - callees = Addr.Map.empty; + ecalls = Map.empty (module Addr); + icalls = Map.empty (module Addr); } let merge m1 m2 = @@ -58,19 +67,20 @@ let filter_mem mem name entry = Memmap.filter mem ~f:(fun (n,e,_) -> not(String.(name = n) || Block.(entry = e))) -let filter_callees name callees = - Map.map callees - ~f:(List.filter ~f:(fun (name',_) -> String.(name <> name'))) - -let remove t (name,entry,_) : t = - if Map.mem t.addrs (Block.addr entry) then - { - names = Map.remove t.names name; - addrs = Map.remove t.addrs (Block.addr entry); - memory = filter_mem t.memory name entry; - callees = filter_callees name t.callees; - } - else t +let filter_calls name cfg calls 
= + let init = Map.filter calls ~f:(fun name' -> String.(name <> name')) in + Cfg.nodes cfg |> + Seq.fold ~init ~f:(fun calls node -> + Map.remove calls (Block.addr node)) + +let remove t (name,entry,cfg) : t = + if Map.mem t.addrs (Block.addr entry) then { + names = Map.remove t.names name; + addrs = Map.remove t.addrs (Block.addr entry); + memory = filter_mem t.memory name entry; + ecalls = filter_calls name cfg t.ecalls; + icalls = filter_calls name cfg t.icalls; + } else t let add_symbol t (name,entry,cfg) : t = let data = name,entry,cfg in @@ -97,10 +107,96 @@ let name_of_fn = fst let entry_of_fn = snd let span fn = span fn |> Memmap.map ~f:(fun _ -> ()) -let add_call t b name edge = - {t with callees = Map.add_multi t.callees (Block.addr b) (name,edge)} +let insert_call ?(implicit=false) symtab block data = + let key = Block.addr block in + if implicit then { + symtab with + icalls = Map.set symtab.icalls ~key ~data + } else { + symtab with + ecalls = Map.set symtab.ecalls ~key ~data + } + -let enum_calls t addr = - match Map.find t.callees addr with - | None -> [] - | Some callees -> callees +let explicit_callee {ecalls} = Map.find ecalls +let implicit_callee {icalls} = Map.find icalls + + + +let (<--) = fun g f -> match g with + | None -> None + | Some (e,g) -> Some (e, f g) + +let build_cfg disasm calls entry = + Disasm.explore disasm ~entry ~init:None + ~follow:(fun dst -> + let p = Callgraph.entry calls dst in + KB.return Addr.(p = entry)) + ~block:(fun mem insns -> + Disasm.execution_order insns >>= fun insns -> + KB.List.filter_map insns ~f:(fun label -> + KB.collect Theory.Program.Semantics.slot label >>= fun s -> + KB.collect Memory.slot label >>| function + | None -> None + | Some mem -> Some (mem, s)) >>| fun insns -> + Block.create mem insns) + ~node:(fun n g -> + KB.return @@ + if Addr.equal (Block.addr n) entry + then Some (n,Cfg.Node.insert n Cfg.empty) + else g <-- Cfg.Node.insert n) + ~edge:(fun src dst g -> + let msrc = Block.memory src + 
and mdst = Block.memory dst in + let next = Addr.succ (Memory.max_addr msrc) in + let kind = if Addr.equal next (Memory.min_addr mdst) + then `Fall else `Jump in + let edge = Cfg.Edge.create src dst kind in + KB.return (g <-- Cfg.Edge.insert edge)) + + +let build_symbol disasm calls start = + build_cfg disasm calls start >>= function + | None -> assert false + | Some (entry,graph) -> + Symbolizer.get_name start >>| fun name -> + name,entry,graph + +let create_intra disasm calls = + Callgraph.entries calls |> + Set.to_sequence |> + KB.Seq.fold ~init:empty ~f:(fun symtab entry -> + build_symbol disasm calls entry >>| fun fn -> + add_symbol symtab fn) + +let create_inter disasm calls init = + Disasm.explore disasm + ~init + ~block:(fun mem _ -> KB.return mem) + ~node:(fun _ s -> KB.return s) + ~edge:(fun src dst s -> + let src = Memory.min_addr src + and dst = Memory.min_addr dst + and next = Addr.succ (Memory.max_addr src) in + if Addr.equal + (Callgraph.entry calls src) + (Callgraph.entry calls dst) + then KB.return s + else + Symbolizer.get_name dst >>| fun name -> + if Addr.equal next dst + then {s with icalls = Map.set s.icalls src name} + else {s with ecalls = Map.set s.ecalls src name}) + + +let create disasm calls = + create_intra disasm calls >>= + create_inter disasm calls + +let result = Toplevel.var "symtab" + +module Toplevel = struct + let create disasm calls = + Toplevel.put result (create disasm calls); + Toplevel.get result +end diff --git a/lib/bap_disasm/bap_disasm_symtab.mli b/lib/bap_disasm/bap_disasm_symtab.mli index 9fc9b4d5b..b194ef648 100644 --- a/lib/bap_disasm/bap_disasm_symtab.mli +++ b/lib/bap_disasm/bap_disasm_symtab.mli @@ -1,7 +1,12 @@ +open Bap_core_theory + open Core_kernel open Bap_types.Std open Image_internal_std +module Disasm = Bap_disasm_driver +module Callgraph = Bap_disasm_calls + type block = Bap_disasm_block.t type edge = Bap_disasm_block.edge type cfg = Bap_disasm_rec.Cfg.t @@ -11,6 +16,12 @@ type symtab = t [@@deriving 
compare, sexp_of] type fn = string * block * cfg [@@deriving compare, sexp_of] +val create : Disasm.state -> Callgraph.t -> t KB.t + +module Toplevel : sig + val create : Disasm.state -> Callgraph.t -> t Toplevel.t +end + val empty : t val add_symbol : t -> fn -> t val remove : t -> fn -> t @@ -22,10 +33,31 @@ val intersecting : t -> mem -> fn list val to_sequence : t -> fn seq val span : fn -> unit memmap -(** [add_call symtab block name edge] remembers a call to a function - [name] from the given block with [edge] *) -val add_call : t -> block -> string -> edge -> t -(** [enum_calls t addr] returns a list of calls from a block with - the given [addr] *) -val enum_calls : t -> addr -> (string * edge) list +(** {2 Callgraph Interface} + + In parallel to a collection of control flow graphs, + Symtab contains a callgraph. +*) + +(** [insert_call ?implicit symtab callsite callee] remembers a call to a + function with name [callee] from a callsite, represented by the + [block]. + + If [implicit] is true (defaults to false) then the call is marked + as an implicit call. An implicit call is a call that is made via + a fallthrough edge. + + Note, the callee is represented with a string not with an address, + since it is possible that [find_by_name callee = None], for + example, when we have a call to an external function out of our + address space. *) +val insert_call : ?implicit:bool -> t -> block -> string -> t + +(** [explicit_callee symtab address] returns a callee which is + explicitly called from a block with the given [address]. *) +val explicit_callee : t -> addr -> string option + +(** [implicit_callee symtab address] returns a callee which is + implicitly called from a block with the given [address]. 
*) +val implicit_callee : t -> addr -> string option diff --git a/lib/bap_disasm/bap_disasm_target_factory.ml b/lib/bap_disasm/bap_disasm_target_factory.ml index dc0017beb..88196d16c 100644 --- a/lib/bap_disasm/bap_disasm_target_factory.ml +++ b/lib/bap_disasm/bap_disasm_target_factory.ml @@ -3,14 +3,15 @@ open Bap_types.Std open Bap_image_std include Bap_disasm_target_intf -let create_stub_target () = +let create_stub_target arch = let module Lifter = struct let lift _ _ = Or_error.error_string "not implemented" + let addr_size = Arch.addr_size arch module CPU = struct let gpr = Var.Set.empty - let nil = Var.create "nil" reg8_t - let mem = nil + let nil = Var.create "nil" (Type.imm (Size.in_bits addr_size)) + let mem = Var.create "mem" (Type.mem addr_size `r8) let pc = nil let sp = nil let sp = nil @@ -40,7 +41,7 @@ let target_of_arch = let get arch = match Hashtbl.find targets arch with | Some target -> target | None -> - let target = create_stub_target () in + let target = create_stub_target arch in Hashtbl.set targets ~key:arch ~data:target; target in get diff --git a/lib/bap_elementary/.merlin b/lib/bap_elementary/.merlin new file mode 100644 index 000000000..e31611c23 --- /dev/null +++ b/lib/bap_elementary/.merlin @@ -0,0 +1,4 @@ +REC +B ../../_build/lib/bap_core_theory +B ../../_build/lib/knowledge +S . 
\ No newline at end of file diff --git a/lib/bap_elementary/bap_elementary.ml b/lib/bap_elementary/bap_elementary.ml new file mode 100644 index 000000000..4d52743c0 --- /dev/null +++ b/lib/bap_elementary/bap_elementary.ml @@ -0,0 +1,135 @@ +open Core_kernel +open Bap.Std + +open Bap_knowledge +open Bap_core_theory +open Knowledge.Syntax + +module Elementary (Core : Theory.Core) = struct + open Core + + type 'a t = 'a knowledge + + exception Not_a_table + + let bits fsort = Theory.Float.(Format.bits (format fsort)) + + let name op sort rank = + let name = Format.asprintf "%a" Theory.Value.Sort.pp sort in + String.concat ~sep:"/" [op; name; string_of_int rank] + + let scheme ident = + match String.split ~on:'/' (Theory.Var.Ident.to_string ident) with + | [name; sort; rank] -> Some (name, sort, rank) + | _ -> None + + let operation ident = match scheme ident with + | Some (name,_,_) -> name + | None -> raise Not_a_table + + let table = name + let is_table ident = Option.is_some (scheme ident) + + let bind a body = + a >>= fun a -> + let sort = Theory.Value.sort a in + Theory.Var.scoped sort @@ fun v -> + let_ v !!a (body v) + + let (>>>=) = bind + + let (>>->) x f = + x >>= fun x -> + f (Theory.Value.sort x) x + + + include struct open Theory + let approximate + : rank : int -> + reduce : (('a,'s) format float -> 'r bitv) -> + extract : (int -> 'd bitv -> 's bitv) -> + coefs : ('r, 'd) Mem.t var -> + ('a,'s) format float -> + rmode -> + ('a,'s) format float + = fun ~rank ~reduce ~extract ~coefs x rmode -> + x >>-> fun fsort x -> + reduce !!x >>>= fun key -> + load (var coefs) (var key) >>>= fun value -> + let coef i = float fsort (extract i (var value)) in + let rec sum i y = + if i >= 0 then + fmul rmode !!x y >>>= fun y -> + coef i >>>= fun c -> + fadd rmode (var y) (var c) >>>= fun y -> + sum (i - 1) (var y) + else y in + coef rank >>>= fun cr -> + if rank = 0 then var cr + else sum (rank - 1) (var cr) + end + + let of_int sort x = + let m = Bitvec.modulus 
(Theory.Bitv.size sort)in + int sort Bitvec.(int x mod m) + + let nth fsort n bitv = + bitv >>-> fun sort bitv -> + let index = of_int sort (n * Theory.Bitv.size (bits fsort)) in + lshift !!bitv index >>>= fun bitv -> + high (bits fsort) (var bitv) + + let tabulate op ~rank ~size x rmode = + x >>-> fun fsort x -> + let keys = Theory.Bitv.define size in + let values = Theory.Bitv.define + ((rank + 1) * Theory.Bitv.size (bits fsort)) in + let mems = Theory.Mem.define keys values in + let name = name op fsort rank in + let coefs = Theory.Var.define mems name in + let reduce x = high keys (fbits x) in + let extract = nth fsort in + approximate ~rank ~coefs ~reduce ~extract !!x rmode + + module Scheme = struct + type 'a t = 'a Theory.Value.sort -> int -> string + + let pow s = name "pow" s + let powr s = name "powr" s + let compound s = name "compound" s + let rootn s = name "rootn" s + let pownn s = name "pownn" s + let rsqrt s = name "rsqrt" s + let hypot s = name "hypot" s + let exp s = name "exp" s + let expm1 s = name "expm1" s + let exp2 s = name "exp2" s + let exp2m1 s = name "exp2m1" s + let exp10 s = name "exp10" s + let exp10m1 s = name "exp10m1" s + let log s = name "log" s + let log2 s = name "log2" s + let log10 s = name "log10" s + let logp1 s = name "logp1" s + let log2p1 s = name "log2p1" s + let log10p1 s = name "log10p1" s + let sin s = name "sin" s + let cos s = name "cos" s + let tan s = name "tan" s + let sinpi s = name "sinpi" s + let cospi s = name "cospi" s + let atanpi s = name "atanpi" s + let atan2pi s = name "atan2pi" s + let asin s = name "asin" s + let acos s = name "acos" s + let atan s = name "atan" s + let atan2 s = name "atan2" s + let sinh s = name "sinh" s + let cosh s = name "cosh" s + let tanh s = name "tanh" s + let asinh s = name "asinh" s + let acosh s = name "acosh" s + let atanh s = name "atanh" s + end + +end diff --git a/lib/bap_elementary/bap_elementary.mli b/lib/bap_elementary/bap_elementary.mli new file mode 100644 index 
000000000..bf2a28e5b --- /dev/null +++ b/lib/bap_elementary/bap_elementary.mli @@ -0,0 +1,105 @@ +open Core_kernel +open Bap.Std +open Bap_knowledge +open Bap_core_theory +open Theory + +(** Elementary is a library that provides a few primitives for + approximations of floating point operations via table methods. *) +module Elementary (Theory : Theory.Core) : sig + + type 'a t = 'a knowledge + + exception Not_a_table + + (** [approximate ~rank ~reduce ~extract ~coefs x rmode] + returns a function f(x,rmode) that is defined by a polynomial + of rank [rank], whose coefficients are stored in a table [coefs]. + + @param coefs is a table, where keys are some points in integer + space, to which floating point values are mapped by the + [reduce] function, and values of [coefs] are some points in + integer space from which floating point coefficients + are restored by the [extract] function with respect to the + rank of each coefficient. *) + val approximate : + rank : int -> + reduce : (('a,'s) format float -> 'r bitv) -> + extract : (int -> 'd bitv -> 's bitv) -> + coefs : ('r, 'd) Mem.t var -> + ('a,'s) format float -> + rmode -> + ('a,'s) format float + + (** [tabulate op ~rank ~size x rmode] defines a subset of + functions that can be created by [approximate], s.t. + each value in a table is a concatenation of coefficients + (from one with the least rank to one with the most), + and each key is the first [size] bits of the floating point value. + + @param op is the name of the floating point operation to approximate. *) + val tabulate : + string -> + rank:int -> + size:int -> + ('a,'s) format float -> + rmode -> + ('a,'s) format float + + (** [table operation sort rank] defines a naming scheme for + approximation of [rank] of an [operation] for values of [sort].
*) + val table : string -> ('r, 's) format Float.t Value.sort -> int -> string + + (** [is_table ident] returns true if [ident] is a table *) + val is_table : Var.ident -> Base.bool + + (** [operation ident] returns the name of a function, + whose polynomial coefficients reside in a table + referenced by [ident]. + Raises Not_a_table if [ident] doesn't match the naming + scheme *) + val operation : Var.ident -> string + + (** module contains naming schemes for different math functions *) + module Scheme : sig + type 'a t = 'a Value.sort -> int -> string + + val pow : 'a t + val powr : 'a t + val compound : 'a t + val rootn : 'a t + val pownn : 'a t + val rsqrt : 'a t + val hypot : 'a t + val exp : 'a t + val expm1 : 'a t + val exp2 : 'a t + val exp2m1 : 'a t + val exp10 : 'a t + val exp10m1 : 'a t + val log : 'a t + val log2 : 'a t + val log10 : 'a t + val logp1 : 'a t + val log2p1 : 'a t + val log10p1 : 'a t + val sin : 'a t + val cos : 'a t + val tan : 'a t + val sinpi : 'a t + val cospi : 'a t + val atanpi : 'a t + val atan2pi : 'a t + val asin : 'a t + val acos : 'a t + val atan : 'a t + val atan2 : 'a t + val sinh : 'a t + val cosh : 'a t + val tanh : 'a t + val asinh : 'a t + val acosh : 'a t + val atanh : 'a t + end + +end diff --git a/lib/bap_image/bap_image.ml b/lib/bap_image/bap_image.ml index e5a934339..6903d221f 100644 --- a/lib/bap_image/bap_image.ml +++ b/lib/bap_image/bap_image.ml @@ -176,9 +176,9 @@ let map_region data {locn={addr}; info={off; len; endian}} = Memory.create ~pos:off ~len endian addr data let static_view segments = function {addr} as locn -> - match Table.find_addr segments addr with - | None -> Result.failf "region is not mapped to memory" () - | Some (segmem,_) -> mem_of_locn segmem locn +match Table.find_addr segments addr with +| None -> Result.failf "region is not mapped to memory" () +| Some (segmem,_) -> mem_of_locn segmem locn let add_sym segments memory (symtab : symtab) ({name; locn=entry; info={extra_locns=locns}} as sym) =
@@ -193,17 +193,17 @@ let add_sym segments memory (symtab : symtab) | _intersects_ -> Ok (memory,symtab)) let add_segment base memory segments seg = - map_region base seg >>= fun mem -> - Table.add segments mem seg >>= fun segments -> - let memory = tag mem segment seg memory in - Result.return (memory,segments) + map_region base seg >>= fun mem -> + Table.add segments mem seg >>= fun segments -> + let memory = tag mem segment seg memory in + Result.return (memory,segments) let add_sections_view segments sections memmap = List.fold sections ~init:(memmap,[]) ~f:(fun (memmap,ers) {name; locn} -> - match static_view segments locn with - | Ok mem -> tag mem section name memmap, ers - | Error er -> memmap,er::ers) + match static_view segments locn with + | Ok mem -> tag mem section name memmap, ers + | Error er -> memmap,er::ers) let make_table add base memory = List.fold ~init:(memory,Table.empty,[]) @@ -227,10 +227,7 @@ let words_of_table word_size tab = Memory.foldi ~word_size mem ~init:tab ~f:(fun addr word tab -> match Memory.view ~word_size ~from:addr ~words:1 mem with - | Error err -> - eprintf "\nSkipping with error: %s\n" - (Error.to_string_hum err); - tab + | Error _ -> tab | Ok mem -> ok_exn (Table.add tab mem word)) in Table.foldi tab ~init:Table.empty ~f:(fun mem _ -> words_of_memory mem) @@ -350,7 +347,7 @@ module Scheme = struct let relocation () = declare "relocation" (scheme fixup $ addr) Tuple.T2.create let external_reference () = - declare "external_reference" (scheme addr $ name) Tuple.T2.create + declare "external-reference" (scheme addr $ name) Tuple.T2.create let base_address () = declare "base-address" (scheme addr) ident end @@ -368,7 +365,7 @@ module Derive = struct Fact.Seq.reduce ~f:(fun a1 a2 -> if Arch.equal a1 a2 then Fact.return a1 else Fact.failf "arch is ambiguous" ()) - (Seq.filter_map ~f:Arch.of_string s) >>= fun a -> + (Seq.filter_map ~f:Arch.of_string s) >>= fun a -> match a with | Some a -> Fact.return a | None -> Fact.failf 
"unknown/unsupported architecture" () @@ -399,10 +396,10 @@ module Derive = struct {addr; size; info=(r,w,x)} {size=len; info=off} {info=name} -> - location ~addr ~size >>= fun locn -> - int_of_int64 off >>= fun off -> - int_of_int64 len >>| fun len -> - {name; locn; info={off; len; endian; r; w; x}}) >>= + location ~addr ~size >>= fun locn -> + int_of_int64 off >>= fun off -> + int_of_int64 len >>| fun len -> + {name; locn; info={off; len; endian; r; w; x}}) >>= Fact.Seq.all let sections = diff --git a/lib/bap_image/bap_memory.ml b/lib/bap_image/bap_memory.ml index 2360e7f3a..3f537962d 100644 --- a/lib/bap_image/bap_memory.ml +++ b/lib/bap_image/bap_memory.ml @@ -1,9 +1,11 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_types.Std open Or_error open Image_common + type 'a or_error = 'a Or_error.t type 'a m = 'a @@ -19,7 +21,7 @@ type t = { addr : addr; off : int; size : int; -} +} [@@deriving bin_io] module Repr = struct type t = { @@ -37,7 +39,11 @@ let to_repr mem = { Repr.size = mem.size; } -let sexp_of_t mem = Repr.sexp_of_t (to_repr mem) +let sexp_of_t mem = Sexp.List [ + Atom (Addr.string_of_value mem.addr); + sexp_of_int mem.size; + sexp_of_endian mem.endian; + ] let endian t = t.endian @@ -430,3 +436,12 @@ include Printable.Make(struct end) let hexdump t = Format.asprintf "%a" pp_hex t + + +let domain = KB.Domain.optional ~inspect:sexp_of_t "mem" + ~equal:(fun x y -> + Addr.equal x.addr y.addr && + Int.equal x.size y.size) + +let slot = KB.Class.property ~package:"bap.std" + Theory.Program.cls "mem" domain diff --git a/lib/bap_image/bap_memory.mli b/lib/bap_image/bap_memory.mli index b717fdf50..cad84be20 100644 --- a/lib/bap_image/bap_memory.mli +++ b/lib/bap_image/bap_memory.mli @@ -1,9 +1,11 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_types.Std open Image_common -type t [@@deriving sexp_of] + +type t [@@deriving bin_io, sexp_of] val create : ?pos:int (** defaults to [0] *) @@ -12,6 +14,7 @@ val create -> addr -> 
Bigstring.t -> t Or_error.t +val slot : (Theory.program, t option) KB.slot val of_file : endian -> addr -> string -> t Or_error.t val view : ?word_size:size -> ?from:addr -> ?words:int -> t -> t Or_error.t val range : t -> addr -> addr -> t Or_error.t diff --git a/lib/bap_image/image_internal_std.ml b/lib/bap_image/image_internal_std.ml index cb3fa7d0b..6da28eeee 100644 --- a/lib/bap_image/image_internal_std.ml +++ b/lib/bap_image/image_internal_std.ml @@ -10,7 +10,7 @@ module Backend = Image_backend type backend = Backend.t module Memory = Bap_memory -type mem = Memory.t [@@deriving sexp_of] +type mem = Memory.t [@@deriving bin_io, sexp_of] module Memmap = Bap_memmap type 'a memmap = 'a Memmap.t [@@deriving sexp_of] diff --git a/lib/bap_lisp/.merlin b/lib/bap_lisp/.merlin new file mode 100644 index 000000000..922a33084 --- /dev/null +++ b/lib/bap_lisp/.merlin @@ -0,0 +1,3 @@ +REC +PKG bap-knowledge, bap-core-theory +B ../../_build/lib/bap_lisp diff --git a/lib/bap_lisp/bap_lisp.ml b/lib/bap_lisp/bap_lisp.ml new file mode 100644 index 000000000..aa8561041 --- /dev/null +++ b/lib/bap_lisp/bap_lisp.ml @@ -0,0 +1,12 @@ +module Lisp = struct + module Attribute = Bap_lisp__attribute + module Def = Bap_lisp__def + module Var = Bap_lisp__var + module Load = Bap_lisp__parse + module Resolve = Bap_lisp__resolve + module Check = Bap_lisp__type.Check + module Context = Bap_lisp__context + module Program = Bap_lisp__program + module Type = Bap_lisp__type + module Doc = Bap_lisp__doc +end diff --git a/lib/bap_lisp/bap_lisp.mli b/lib/bap_lisp/bap_lisp.mli new file mode 100644 index 000000000..08a68b0be --- /dev/null +++ b/lib/bap_lisp/bap_lisp.mli @@ -0,0 +1,77 @@ +(* open Core_kernel + * open Format + * open Bap_knowledge + * open Bap_core_theory + * + * module Lisp : sig + * type program + * type message + * type context + * + * val program : program knowledge + * + * val link_program : program -> unit knowledge + * + * module Load : sig + * type error + * val program : 
?paths:string list -> context -> (program,error) result + * val pp_program : formatter -> program -> unit + * val pp_error : formatter -> error -> unit + * end + * + * module Context : sig + * type t = context + * val empty : t + * end + * + * module Type : sig + * type t + * type signature + * type error + * + * type parameters = [ + * | `All of t + * | `Gen of t list * t + * | `Tuple of t list + * ] + * + * module Spec : sig + * val any : t + * val var : string -> t + * val sym : t + * val int : t + * val bool : t + * val byte : t + * val word : int -> t + * val a : t + * val b : t + * val c : t + * val d : t + * + * val tuple : t list -> [`Tuple of t list] + * val all : t -> [`All of t] + * val one : t -> [`Tuple of t list] + * val unit : [`Tuple of t list] + * val (//) : [`Tuple of t list] -> [`All of t] -> parameters + * val (@->) : [< parameters] -> t -> signature + * end + * + * val check : Sort.exp Var.Ident.Map.t -> program -> error list + * val pp_error : Format.formatter -> error -> unit + * end + * + * + * module Doc : sig + * module type Element = sig + * type t + * val pp : formatter -> t -> unit + * end + * + * module Category : Element + * module Name : Element + * module Descr : Element + * type index = (Category.t * (Name.t * Descr.t) list) list + * + * val generate_index : index knowledge + * end + * end *) diff --git a/lib/bap_lisp/bap_lisp__attribute.ml b/lib/bap_lisp/bap_lisp__attribute.ml new file mode 100644 index 000000000..47978760d --- /dev/null +++ b/lib/bap_lisp/bap_lisp__attribute.ml @@ -0,0 +1,58 @@ +open Core_kernel + +open Bap_lisp__types + +type attrs = Univ_map.t +type set = attrs + +type error = .. 
+ +type error += Expect_list + +exception Unknown_attr of string * tree +exception Bad_syntax of error * tree list + + +type 'a attr = { + key : 'a Univ_map.Key.t; + add : 'a -> 'a -> 'a; + parse : tree list -> 'a; +} + +type 'a t = 'a Univ_map.Key.t + +type parser = Parser of (set -> tree list -> set) + +let parsers : parser String.Table.t = String.Table.create () + +let make_parser attr attrs sexp = + let value = attr.parse sexp in + Univ_map.update attrs attr.key ~f:(function + | None -> value + | Some value' -> attr.add value value') + +let register ~name ~add ~parse = + let attr = { + key = Univ_map.Key.create ~name sexp_of_opaque; + add; + parse = parse; + } in + let parser = Parser (make_parser attr) in + Hashtbl.add_exn parsers ~key:name ~data:parser; + attr.key + +let expected_parsers () = + String.Table.keys parsers |> String.concat ~sep:" | " + +let parse s attrs name values = match Hashtbl.find parsers name with + | None -> raise (Unknown_attr (name,s)) + | Some (Parser run) -> run attrs values + +let parse attrs = function + | {data=List ({data=Atom name} as s :: values)} -> parse s attrs name values + | s -> raise (Bad_syntax (Expect_list,[s])) + +module Set = struct + let get = Univ_map.find + let empty = Univ_map.empty +end diff --git a/lib/bap_lisp/bap_lisp__attribute.mli b/lib/bap_lisp/bap_lisp__attribute.mli new file mode 100644 index 000000000..aeeec48ea --- /dev/null +++ b/lib/bap_lisp/bap_lisp__attribute.mli @@ -0,0 +1,33 @@ +(** Primus attributes. + + Attributes are declared with the [declare] statement. Each + attribute has its own syntax. A parser can be registered using + this module. + + So far, we keep this module internal. +*) + +open Bap_lisp__types + +type 'a t + +type set + +type error = .. +exception Unknown_attr of string * tree +exception Bad_syntax of error * tree list + +(** registers a new attribute. 
An attribute is a monoid that is + parsed from a sexp and added to the existing attribute of the same + kind.*) +val register : + name:string -> + add:('a -> 'a -> 'a) -> + parse:(tree list -> 'a) -> 'a t + +val parse : set -> tree -> set + +module Set : sig + val get : set -> 'a t -> 'a option + val empty : set +end diff --git a/lib/bap_lisp/bap_lisp__attributes.ml b/lib/bap_lisp/bap_lisp__attributes.ml new file mode 100644 index 000000000..09eb0c509 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__attributes.ml @@ -0,0 +1,117 @@ +open Core_kernel + + +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute +module Var = Bap_lisp__var + +let fail err t = raise (Attribute.Bad_syntax (err,t)) + +module Variables = struct + + type t = var list + + type Attribute.error += Expect_atom + type Attribute.error += Var_error of Var.read_error + + let var t = match t with + | {data=List _} -> fail Expect_atom [t] + | {data=Atom v; id; eq} -> match Var.read id eq v with + | Ok v -> v + | Error err -> fail (Var_error err) [t] + + let parse = List.map ~f:var + + let global = Attribute.register + ~name:"global" + ~add:List.append + ~parse + + let static = Attribute.register + ~name:"static" + ~add:List.append + ~parse +end + +type Attribute.error += + | Expect_atom + | Unterminated_quote + +let parse_name = function + | {data=Atom x} as s -> + let n = String.length x in + if n < 2 then x + else if x.[0] = '"' + then if x.[n-1] = '"' + then String.sub ~pos:1 ~len:(n-2) x + else fail Unterminated_quote [s] + else x + | s -> fail Expect_atom [s] + + +module External = struct + type t = string list + + let parse = List.map ~f:parse_name + + let t = Attribute.register + ~name:"external" + ~add:List.append + ~parse +end + + +module Advice = struct + type cmethod = Before | After [@@deriving compare, sexp] + + module Methods = Map.Make_plain(struct + type t = cmethod [@@deriving compare, sexp] + end) + + type Attribute.error += + | Unknown_method of string + | Bad_syntax + | Empty + |
No_targets + + type t = {methods : String.Set.t Methods.t} + + let methods = String.Map.of_alist_exn [ + ":before", Before; + ":after", After; + ] + + let targets {methods} m = match Map.find methods m with + | None -> String.Set.empty + | Some targets -> targets + + let parse_targets met ss = match ss with + | [] -> fail No_targets ss + | ss -> + List.fold ss ~init:{methods=Methods.empty} ~f:(fun {methods} t -> { + methods = Map.update methods met ~f:(function + | None -> String.Set.singleton (parse_name t) + | Some ts -> Set.add ts (parse_name t)) + }) + + let parse trees = match trees with + | [] -> fail Empty trees + | {data=List _} as s :: _ -> fail Bad_syntax [s] + | {data=Atom s} as lit :: ss -> + if String.is_empty s then fail (Unknown_method s) [lit]; + match s with + | ":before" -> parse_targets Before ss + | ":after" -> parse_targets After ss + | _ when s.[0] = ':' -> fail (Unknown_method s) [lit] + | _ -> parse_targets Before trees + + let add d1 d2 = { + methods = Map.merge d1.methods d2.methods ~f:(fun ~key -> function + | `Both (xs,ys) -> Some (Set.union xs ys) + | `Left xs | `Right xs -> Some xs) + } + + let t = Attribute.register ~name:"advice" + ~parse ~add +end diff --git a/lib/bap_lisp/bap_lisp__attributes.mli b/lib/bap_lisp/bap_lisp__attributes.mli new file mode 100644 index 000000000..fc1ac1753 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__attributes.mli @@ -0,0 +1,20 @@ +open Core_kernel +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute + +module External : sig + val t : string list Attribute.t +end + +module Variables : sig + val global : var list Attribute.t + val static : var list Attribute.t +end + +module Advice : sig + type cmethod = Before | After + type t + val t : t Attribute.t + val targets : t -> cmethod -> String.Set.t +end diff --git a/lib/bap_lisp/bap_lisp__context.ml b/lib/bap_lisp/bap_lisp__context.ml new file mode 100644 index 000000000..22f0fa9fa --- /dev/null +++ b/lib/bap_lisp/bap_lisp__context.ml @@ -0,0 +1,108 
@@ +open Core_kernel + +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute + +module Feature = String +module Name = String + +type t = Feature.Set.t Name.Map.t +let empty = Name.Map.empty + +type Attribute.error += Expect_atom | Expect_list | Unterminated_quote + + +let fail what got = raise (Attribute.Bad_syntax (what,[got])) +let expect_atom = fail Expect_atom +let expect_list = fail Expect_list + + +let features = Feature.Set.of_list + +let merge xs ys : t = Map.merge xs ys ~f:(fun ~key:_ -> function + | `Left v | `Right v -> Some v + | `Both (x,y) -> Some (Feature.Set.union x y) ) + +let sexp_of_context (name,values) = + Sexp.List (List.map (name :: Set.to_list values) + ~f:(fun x -> Sexp.Atom x)) + +let sexp_of (cs : t) = + Sexp.List (Atom "context" :: + (Map.to_alist cs |> List.map ~f:sexp_of_context)) + +let value = function + | {data=Atom x} -> x + | s -> expect_atom s + +let parse_name = function + | {data=Atom x} as s -> + let n = String.length x in + if n < 2 then x + else if x.[0] = '"' + then if x.[n-1] = '"' + then String.sub ~pos:1 ~len:(n-2) x + else fail Unterminated_quote s + else x + | s -> fail Expect_atom s + + +let context_of_tree = function + | {data=List (x :: xs)} -> + parse_name x, Feature.Set.of_list (List.map xs ~f:parse_name) + | s -> expect_list s + + +let push cs name vs = + Map.update cs name ~f:(function + | None -> vs + | Some vs' -> Set.union vs vs') + + +let parse : tree list -> t = + List.fold ~init:Name.Map.empty ~f:(fun cs tree -> + let (name,vs) = context_of_tree tree in + push cs name vs) + +let add cs cs' = + Map.fold cs ~init:cs' ~f:(fun ~key:name ~data:vs cs' -> + push cs' name vs) + +let t = Attribute.register + ~name:"context" + ~add + ~parse + +let pp ppf ctxt = + Sexp.pp_hum ppf (sexp_of ctxt) + + +(* [C <= C'] iff for each class c in C, there is a class c' in C' + such that c >= c', where c >= c' is a superset operation. 
+ + This implies that the set of classes in C is a subset of the set of + classes in C', since all missing classes can be introduced as + classes with an empty feature sets. Intuitively, that denotes that + a definition implicitly states that it is applicable to all + instances of a missing class. +*) +let (<=) ctxt ctxt' = + let sups = Map.merge ctxt ctxt' ~f:(fun ~key:_ -> function + | `Left _ -> None + | `Right features' -> + if Set.is_empty features' then None else Some features' + | `Both (features,features') -> + if (Set.is_subset features' ~of_:features) + then None else Some features') in + Map.is_empty @@ sups + + +type porder = Less | Same | Equiv | More + +let compare c1 c2 = + match c1 <= c2, c2 <= c1 with + | true, false -> Less + | true, true -> Same + | false,false -> Equiv + | false,true -> More diff --git a/lib/bap_lisp/bap_lisp__context.mli b/lib/bap_lisp/bap_lisp__context.mli new file mode 100644 index 000000000..2e4482c55 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__context.mli @@ -0,0 +1,283 @@ +(** Primus Lisp Type Class Contexts. + + @primus-lisp-internals@ + + {1 Introduction} + + Type context, not to be confused with the evaluation environment, + defines a set of features that describe the current state of the + world. The context is a more or less static property (to some + definition of static). The context is defined by the program that + we are analyzing (a project structure) and by explicit + proclamations. Thus, once a lisp feature (or a set of features) is + loaded the context is fixed and for each definition we should be + able to find an implementation that suits the context. We (will + soon) extend the notion of a context with observation classes, + that brings more dynamic flavor to the context systems, as an + observation instance is created in the runtime. However, the type + of an instance (i.e., the class) is still fixed at the compile + time, thus it is a static feature. 
+ + {1 Application} + + The context type class system allows polymorphic definitions. That + means, that a definition (i.e., a meaning of a name) can have + multiple implementations (represented with different + formulae). This is done by specifying a context of applicability + of a definition. During the compile time the most suitable + definition is used (i.e., a definition that is applicable and is + the most specific from the set of all applicable definitions). + + The property of being most suitable is defined by the partial + ordering defined on the set of context classes, in other words, by + using the subtype (or inclusion) polymorphism. + + + {1 Representation} + + + We represent a context as a map of sets. The powerset lattice + gives a subset partial order, with the empty set being the + superclass (the least constrained context). Every definition in + Primus Lisp has a context of applicability, i.e., it is defined + only in a specific context. By default it is an empty set. + + A definition is considered only if the current context is a subset of + the definition context. For example, if a definition context is + [(arch arm)] then it is applicable when the current context is [(arch + arm v7)] or just [(arch arm)], and not applicable if it is [(arch + x86)]. + + + {1 Work in progress} + + {2 Observation classes} + + An observation class is a particular kind of a context class, that + allows a user to interact with the Primus observation + system. Basically the Observation class is a production rule, that + defines new observations that are built from other observations. + + The syntax is an abuse of the defclass stanza, i.e., + + (defclass uaf-memory-problem (?alloc ?free ?use) + (memcheck-acquire ?alloc ?ptr) + (memcheck-release ?free ?ptr) + (memcheck-violate ?use ?ptr)) + + + The idea is that it defines a new class of observations, that are + derived from three other classes: memcheck-acquire, + memcheck-release, and memcheck-violate.
An additional constraint + is that the `?ptr` field of these three events should be equal. + + The class above is represented as the following BARE rule: + + (((memcheck-acquire ?alloc ?ptr) + (memcheck-release ?free ?ptr) + (memcheck-violate ?use ?ptr)) + ((uaf-memory-problem ?alloc ?free ?use))) + + + Thus a production rule generates a new observation. With respect + to the subtyping a newly generated rule is always a subtype of + the derived rules, so it plays well with the class notion. The + only drawback here is that we extend our notion of subtyping from + a purely structural one to a nominal one. + + The defclass definition is planned to provide even more than just + BARE substitutions, but also: + + - when clauses that will evaluate arbitrary Primus expressions for + an extra constraint generation + + - defparameter that will compute parameters also using arbitrary Primus + expressions. + + The expressions called from a body of an observation class are + evaluated in the context of this class (and all its + superclasses). Thus it would be possible to define the behavior of + the definition based on a particular slot of a class, for example + + {v + (defclass chunk-allocated (?ptr ?len) + (call-return ?name ?args ?ptr) + (defparameter ?len (compute-allocated-size ?args))) + + (defun compute-allocated-size (n m) + (declare (context (observe (call-return calloc)))) + (multiply n m)) + v} + + Another planned feature is to provide a mechanism for callbacks + invocation for each defined event, (ab)using the defmethod syntax: + + {v + (defmethod chunk-allocated (ptr len) + (memcheck-acquire ptr len)) + + v} + + This, however, requires each slot of an observation class to be + a value, not a symbol. So, we need to think more about it, maybe + slots that are values should be defined differently. + + + {2 Method combination or super notation} + + The idea of the subtyping polymorphism is to provide a mechanism + for extensible refinement of a definition.
I.e., we start with a + generic definition and then extend it (without modification) in + some more specific context. However, currently, there is no + mechanism to access a more general definition from the context of + a more specific one. I.e., we don't have the [super] notation, + upcasting, or some method combination mechanism. So, we can't + actually _refine_ the definition, but rather we are forced to + _redefine_ it. This is a severe limitation that we should try to + overcome. + + Possible solutions: + + - The super notation. Conventional programming languages allow one to + tag a definition at the point of application with some super + tag, that basically disables the dynamic dispatching. Though it + is usually rather intuitive to most of the users, it is hard to + map this solution to Primus Lisp, as we don't have the nominal + subtyping, but rather the structural one. So a derivative + context doesn't really know which base context (if any) it + refines. Moreover, given the partial order, it may have several + least super types, and since we don't have names for them, we + can't really pick one. So, it is tempting to mark this approach + as invalid. + + + - Another approach would be to provide a mechanism to upcast + (generalize) current context, so that a more general definition + can be still accessible from a more specific definition. Though + this sounds reasonable, it could be hard to implement it + correctly. If we will introduce some syntactic notation, that + will allow a user to specify a different context of invocation, + then it will complicate the notion of the context a + lot. Moreover, if the upcasting operator will become a term, it + may even make the context to be a dynamic property that is hard + to verify statically. However, with a careful approach, this + still can be a valid solution. A possible syntax would be: + [(f #(new-context) args...)]. This syntax allows a user to + specify an arbitrary context, including downcasting.
We can, of + course, limit this by throwing a type error. But this will + complicate the notion of a context, and will disable programs + that are valid, i.e., it actually makes sense to derive a + definition from a definition that is not applicable, by using it + and amending in some way. One can object that such an + approach just shows that an abstraction was missed, and the + commonality between two definitions should be moved as a least + upper bound definition of both of them. But this will, in fact, + contradict the open/closed principle, as extracting the common + functionality would require editing the existing code. + + + - Method combination. The method combination is a CLOS mechanism + common in Common Lisps. Instead of calling the most specific + instance we are actually calling methods of all instances and + combining them according to some rule. This approach is very + natural to the Common Lisp style of methods that are defined + externally to the class, and looks like the most appealing + solution. However, there are a few problems: + + 1. A definition is parametrized not by _a_ type context, but + by a product of type contexts, thus we essentially have a + multi-method here. The applicable methods should be combined + in some deterministic and human understandable (predictable) + order. In CLOS (though it depends on the chosen method + combination) the applicable methods are sorted in the + precedence order. Where the order of direct superclasses of + a class is defined syntactically by the list of + super-classes of a class (since CLOS has nominal subtyping), + and the order between different classes is specified by the + order of corresponding method parameters. Since, in Primus + Lisp we have structural subtyping we do not have an explicit + order between the direct super classes of a class, thus in + which order the methods are combined is hard to + determine. We can, however, rely on the alphabetical order, + to make it at least predictable.
+ + 2. If we will provide the `call-next-method` function, that + will basically upcast the context, and switch to the more + general definition, then the question is whether the called + method will call the next method itself. It is like a + cooperative generalization. Something that I personally + don't like, as whether a method will call the parent method + is defined in the code. + + 3. The call-next-method implies that each of our definitions is + now a method, not a definition. And that a definition is spread + across an open set of definitions, that are combined in some + dynamic order. + + + - Do not introduce any mechanism but rely on the advice mechanism + instead. In Common Lisp terminology this will only leave us + with auxiliary methods, except that the after method allows us + to override the return value. Thus, if someone would like to + partially override a definition, it would be possible to use + the after method to get the result of a parent computation. Since + definitions added via the advice mechanism do not compete with + master methods we don't have the problem of the method + ordering. + + 1. we have plenty of method combinators, that are cleaner + than (call-next-method) but still quite flexible. + + 2. :before, :after, :around, and their while/until variants, + are easily implementable. (We can implement the around + using the same approach as the (call-next-method), except + that we can use a much more concrete (call-the-advised)). +*) + +open Core_kernel + +module Attribute = Bap_lisp__attribute + +type t + +val t : t Attribute.t + +val empty : t + + + +(** [cx <= cx'] is true if [cx] is same as [cx'] or if [cx] is more + specific. Where a context [c] is the same as context [c'] if [c] + is as specific as [c'], i.e., no more, no less. + + This is a partial order relation, i.e., it is possible that both + contexts are neither same, nor one is less than the other, nor + vice versa.
+ + Examples: + + {v + + ((arch arm v7)) <= ((arch arm)) => true + ((arch arm)) <= ((arch arm v7) (compiler gcc)) => false + v} +*) +val (<=) : t -> t -> bool + + +(** Partial ordering between context classes. + + We define the partial order in terms of how generic a + definition is. + +*) +type porder = + | Less (** less generic: c1 <= c2 && not(c2 <= c1) *) + | Same (** exactly the same: c1 <= c2 && c2 <= c1 *) + | Equiv (** not comparable : not(c1 <= c2) && not(c2 <= c1) *) + | More (** more generic : not(c1 <= c2) && c2 <= c1 *) + +val compare : t -> t -> porder +val pp : Format.formatter -> t -> unit + +val merge : t -> t -> t diff --git a/lib/bap_lisp/bap_lisp__def.ml b/lib/bap_lisp/bap_lisp__def.ml new file mode 100644 index 000000000..82c8755df --- /dev/null +++ b/lib/bap_lisp/bap_lisp__def.ml @@ -0,0 +1,180 @@ +open Core_kernel +open Format + +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute +module Loc = Bap_lisp__loc +module Index = Bap_lisp__index +module Type = Bap_lisp__type + + +type attrs = Attribute.set + +type meta = { + name : string; + docs : string; + attrs : attrs; +} [@@deriving fields] + +type func = { + args : var list; + body : ast; +} [@@deriving fields] + +type meth = func + +type macro = { + param : string list; + subst : tree; +} [@@deriving fields] + +type subst = { + elts : tree list; +} [@@deriving fields] + + +type const = { + value : string; +} + +type para = { + default : ast; +} + + +type primitive = { + types : Type.signature; +} + +type 'a spec = {meta : meta; code : 'a} +type 'a t = 'a spec indexed +type 'a def = ?docs:string -> ?attrs:attrs -> string -> 'a + +let name {data={meta}} = name meta +let docs {data={meta}} = docs meta +let attributes {data={meta}} = attrs meta + +let field f t = f t.data.code + +let create data tree = { + data; + id = tree.id; + eq = tree.eq; +} + +module Func = struct + let args = field args + let body = field body + let create ?(docs="") ?(attrs=Attribute.Set.empty) name args body =
+ create { + meta = {name; docs; attrs}; + code = {args; body} + } + + let with_body t body = { + t with data = { + t.data with code = {t.data.code with body}} + } + +end + +module Meth = Func + +module Para = struct + let create : 'a def = + fun ?(docs="") ?(attrs=Attribute.Set.empty) name default -> + create { + meta = {name; docs; attrs}; + code = {default}; + } + + let default p = p.data.code.default + let with_default t default = { + t with data = { + t.data with code = {default} + } + } +end + +module Macro = struct + type error += Bad_subst of tree * tree list + let args = field param + let body = field subst + let create ?(docs="") ?(attrs=Attribute.Set.empty) name param subst = + create { + meta = {name; docs; attrs}; + code = {param; subst} + } + + let take_rest xs ys = + let rec take xs ys zs = match xs,ys with + | [],[] -> Some zs + | [x], (_ :: _ :: _ as rest) -> Some ((x,rest)::zs) + | x :: xs, y :: ys -> take xs ys ((x,[y])::zs) + | _ :: _, [] | [],_ -> None in + match take xs ys []with + | Some [] -> Some (0,[]) + | Some ((_,rest) :: _ as bs) -> + Some (List.length rest, List.rev bs) + | None -> None + + let bind macro cs = take_rest macro.data.code.param cs + + let find = List.Assoc.find ~equal:String.equal + + let unknown = Eq.null + + let subst bs body = + let rec sub : tree -> tree list = function + | {data=List xs; id} -> + [{data=List (List.concat_map xs ~f:sub); id; eq=unknown}] + | {data=Atom x} as atom -> match find bs x with + | None -> [atom] + | Some cs -> cs in + match body with + | {data=List xs; id} -> + {data=List (List.concat_map xs ~f:sub); id; eq=unknown} + | {data=Atom x} as atom -> match find bs x with + | None -> atom + | Some [x] -> x + | Some xs -> raise (Fail (Bad_subst (atom,xs))) + + let apply macro cs = subst cs macro.data.code.subst +end + +module Const = struct + let create ?(docs="") ?(attrs=Attribute.Set.empty) name ~value = + create { + meta = {name; docs; attrs}; + code = {value} + } + let value p = {data=Atom 
p.data.code.value; id=p.id; eq=p.eq} +end + +module Subst = struct + type syntax = Ident | Ascii | Hex + + let body = field elts + + + let create ?(docs="") ?(attrs=Attribute.Set.empty) name elts = + create { + meta = {name; docs; attrs}; + code = {elts} + } + +end + +module Primitive = struct + let create ?(docs="") name types = { + data = { + meta = {name;docs; attrs=Attribute.Set.empty}; + code={types}; + }; + id = Id.null; + eq = Eq.null; + } + + let signature p = p.data.code.types +end diff --git a/lib/bap_lisp/bap_lisp__def.mli b/lib/bap_lisp/bap_lisp__def.mli new file mode 100644 index 000000000..78b0c4b69 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__def.mli @@ -0,0 +1,74 @@ +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute +module Type = Bap_lisp__type + + + +type 'a spec +type 'a t = 'a spec indexed +type func +type meth +type macro +type subst +type const +type primitive +type para + +type attrs = Attribute.set + +val name : 'a t -> string +val docs : 'a t -> string +val attributes : 'a t -> attrs + +type 'a def = ?docs:string -> ?attrs:attrs -> string -> 'a + +module Func : sig + val create : (var list -> ast -> tree -> func t) def + val args : func t -> var list + val body : func t -> ast + val with_body : func t -> ast -> func t +end + +module Meth : sig + val create : (var list -> ast -> tree -> meth t) def + val args : meth t -> var list + val body : meth t -> ast + val with_body : meth t -> ast -> meth t +end + +module Para : sig + val create : (ast -> tree -> para t) def + val default : para t -> ast + val with_default : para t -> ast -> para t +end + +module Macro : sig + val create : (string list -> tree -> tree -> macro t) def + val args : macro t -> string list + val body : macro t -> tree + val bind : macro t -> tree list -> (int * (string * tree list) list) option + + (** [apply m bs] returns the body of [m] where any occurrence of a + variable [x] is substituted with [y] if [x,[y]] is in the list + of bindings [bs].
+ + The identity of the returned tree is the same as the identity of + the macro body.*) + val apply : macro t -> (string * tree list) list -> tree +end + +module Const : sig + val create : (value:string -> tree -> const t) def + val value : const t -> tree +end + +module Subst : sig + val create : (tree list -> tree -> subst t) def + val body : subst t -> tree list +end + +module Primitive : sig + val create : ?docs:string -> string -> Type.signature -> primitive t + val signature : primitive t -> Type.signature +end diff --git a/lib/bap_lisp/bap_lisp__doc.ml b/lib/bap_lisp/bap_lisp__doc.ml new file mode 100644 index 000000000..f10e68352 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__doc.ml @@ -0,0 +1,66 @@ +open Core_kernel +open Format + +module Lisp = struct + module Def = Bap_lisp__def + module Program = Bap_lisp__program +end + +module type Element = sig + type t + val pp : formatter -> t -> unit +end + +module Category = String +module Name = String +module Descr = String + +type index = (string * (string * string) list) list + + +let unquote s = + if String.is_prefix s ~prefix:{|"|} && + String.is_suffix s ~suffix:{|"|} + then String.sub s ~pos:1 ~len:(String.length s - 2) + else s + +let dedup_whitespace str = + let buf = Buffer.create (String.length str) in + let push = Buffer.add_char buf in + String.fold str ~init:`white ~f:(fun state c -> + let ws = Char.is_whitespace c in + if not ws then push c; + match state,ws with + | `white,true -> `white + | `white,false -> `black + | `black,true -> push c; `white + | `black,false -> `black) |> ignore; + Buffer.contents buf + +let normalize_descr s = + dedup_whitespace (unquote (String.strip s)) + +let normalize xs = + List.Assoc.map xs ~f:normalize_descr |> + String.Map.of_alist_reduce ~f:(fun x y -> + if x = "" then y else if y = "" then x + else if x = y then x + else sprintf "%s\nOR\n%s" x y) |> + Map.to_alist + + +let describe prog item = + Lisp.Program.get prog item |> List.map ~f:(fun x -> + Lisp.Def.name x, 
Lisp.Def.docs x) |> normalize + + +let generate_index p (* signals *) = Lisp.Program.Items.[ + "Macros", describe p macro; + "Substitutions", describe p subst; + "Constants", describe p const; + "Functions", describe p func; + "Methods", describe p meth; + "Parameters", describe p para; + "Primitives", describe p primitive; + (* "Signals", normalize signals; *) + ] diff --git a/lib/bap_lisp/bap_lisp__doc.mli b/lib/bap_lisp/bap_lisp__doc.mli new file mode 100644 index 000000000..83492e7ea --- /dev/null +++ b/lib/bap_lisp/bap_lisp__doc.mli @@ -0,0 +1,14 @@ +open Bap_knowledge +open Format + +module type Element = sig + type t + val pp : formatter -> t -> unit +end + +module Category : Element +module Name : Element +module Descr : Element +type index = (Category.t * (Name.t * Descr.t) list) list + +val generate_index : Bap_lisp__program.t -> index diff --git a/lib/bap_lisp/bap_lisp__index.ml b/lib/bap_lisp/bap_lisp__index.ml new file mode 100644 index 000000000..d1d8e9400 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__index.ml @@ -0,0 +1,21 @@ +open Core_kernel + +module type S = sig + type t [@@deriving sexp_of] + val null : t + val next : t -> t + val pp : Format.formatter -> t -> unit + include Comparable.S_plain with type t := t +end + +module Make() : S = struct + let null = Int63.zero + let next = Int63.succ + include Int63 +end + +type ('a,'i,'e) interned = { + data : 'a; + id : 'i; + eq : 'e; +} diff --git a/lib/bap_lisp/bap_lisp__index.mli b/lib/bap_lisp/bap_lisp__index.mli new file mode 100644 index 000000000..5072291c2 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__index.mli @@ -0,0 +1,17 @@ +open Core_kernel + +type ('a,'i,'e) interned = { + data : 'a; + id : 'i; + eq : 'e; +} + +module type S = sig + type t [@@deriving sexp_of] + val null : t + val next : t -> t + val pp : Format.formatter -> t -> unit + include Comparable.S_plain with type t := t +end + +module Make() : S diff --git a/lib/bap_lisp/bap_lisp__loc.ml b/lib/bap_lisp/bap_lisp__loc.ml new file mode 
100644 index 000000000..80b620ea8 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__loc.ml @@ -0,0 +1,49 @@ +open Core_kernel +open Format + +type range = Parsexp.Positions.range [@@deriving compare, sexp_of] +type loc = { + file : string; + range : range; +} [@@deriving compare, sexp_of] + + +let pp ppf {file; range={start_pos=s; end_pos=e}} = + let len = e.offset - s.offset in + fprintf ppf "File %S, line %d, characters %d-%d" + file s.line s.col (s.col+len) + +let merge_pos merge p1 p2 = Parsexp.Positions.{ + line = merge p1.line p2.line; + col = merge p1.col p2.col; + offset = merge p1.offset p2.offset; + } + +let merge p1 p2 = + if p1.file <> p2.file + then invalid_arg "Loc: can't merge locations from different files"; + Parsexp.Positions.{ + p1 with range = { + start_pos = merge_pos min p1.range.start_pos p2.range.start_pos; + end_pos = merge_pos max p1.range.end_pos p2.range.end_pos; + } + } + +let shift_pos p off = Parsexp.Positions.{ + p with + col = p.col + off; + offset = p.offset + off; + } + +let nth_char p off = Parsexp.Positions.{ + p with range = { + start_pos = shift_pos p.range.start_pos off; + end_pos = shift_pos p.range.end_pos (off + 1)} + } + + +include Comparable.Make_plain(struct + type t = loc [@@deriving compare, sexp_of] + end) + +type t = loc [@@deriving compare, sexp_of] diff --git a/lib/bap_lisp/bap_lisp__loc.mli b/lib/bap_lisp/bap_lisp__loc.mli new file mode 100644 index 000000000..307d48048 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__loc.mli @@ -0,0 +1,21 @@ +open Core_kernel + +(* a region in a file *) +type range = Parsexp.Positions.range [@@deriving sexp_of] + + +(* a region in the specified file *) +type loc = { + file : string; + range : range; +} [@@deriving compare, sexp_of] + +type t = loc [@@deriving compare, sexp_of] + +val merge : t -> t -> t + +val nth_char : t -> int -> t + +val pp : Format.formatter -> t -> unit + +include Comparable.S_plain with type t := t diff --git a/lib/bap_lisp/bap_lisp__parse.ml 
b/lib/bap_lisp/bap_lisp__parse.ml new file mode 100644 index 000000000..0456a8d2b --- /dev/null +++ b/lib/bap_lisp/bap_lisp__parse.ml @@ -0,0 +1,739 @@ +open Core_kernel +open Format + +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute +module Context = Bap_lisp__context +module Def = Bap_lisp__def +module Var = Bap_lisp__var +module Word = Bap_lisp__word +module Loc = Bap_lisp__loc +module Resolve = Bap_lisp__resolve +module Program = Bap_lisp__program +module Type = Bap_lisp__type + +type defkind = Func | Macro | Const | Subst | Meth | Para + +type format_error = + | Expect_digit + | Illegal_escape of char + | Unterminated of [`Esc | `Exp] + +type parse_error = + | Bad_var_literal of Var.read_error + | Bad_word_literal of Word.read_error + | Bad_form of string + | Bad_app + | Bad_let_binding + | Bad_param_list + | Bad_macro_param_list + | Bad_require + | Unknown_toplevel + | Bad_toplevel + | Bad_def of defkind + | Bad_ascii + | Bad_hex + | Unknown_subst_syntax + | Unresolved of defkind * string * Resolve.resolution + + +exception Parse_error of parse_error * tree list +exception Attribute_parse_error of Attribute.error * tree list * tree + +type source = + | Cmdline + | Module of Loc.t + +type error = + | Parse_error of parse_error * Loc.t + | Format_error of format_error * Loc.t + | Sexp_error of string * source * Source.error + | Unresolved_feature of string * source + | Unknown_attr of string * Loc.t + +exception Fail of error + +let loc src trees = match trees with + | [] -> Source.loc src Id.null + | t :: ts -> + List.fold ts ~init:(Source.loc src t.id) ~f:(fun loc t -> + Loc.merge loc (Source.loc src t.id)) + +let is_quoted s = + let n = String.length s in + n > 1 && s.[0] = '"' && s.[n - 1] = '"' + +let is_symbol s = + String.length s > 1 && s.[0] = '\'' + +let unqoute s = + if is_quoted s + then String.sub ~pos:1 ~len:(String.length s - 2) s + else s + +let symbol s = + if is_symbol s + then String.subo ~pos:1 s + else s + +module Parse = 
struct + open Program.Items + + let fails err s = raise (Parse_error (err,s)) + let fail err s = fails err [s] + let bad_form op got = fail (Bad_form op) got + let nil = {exp=Bitvec.zero; typ=Type.word 1} + + + let expand prog cs = + List.concat_map cs ~f:(function + | {data=List _} as cs -> [cs] + | {data=Atom x} as atom -> + match Resolve.subst prog subst x () with + | None -> [atom] + | Some (Error s) -> fail (Unresolved (Subst,x,s)) atom + | Some (Ok (d,())) -> + Def.Subst.body d |> List.map ~f:(fun tree -> + {tree with id = atom.id})) + + let let_var : tree -> var = function + | {data=List _} as s -> fail Bad_let_binding s + | {data=Atom x; id; eq} as s -> match Var.read id eq x with + | Error e -> fail (Bad_var_literal e) s + | Ok var -> var + + + let fmt prog fmt tree = + let fmt = unqoute fmt in + let fail err off = + let pos = + Loc.nth_char (loc (Program.sources prog) [tree]) (off+1) in + raise (Fail (Format_error (err, pos))) in + let unescape off c = + try Scanf.unescaped (sprintf "\\%c" c) + with Scanf.Scan_failure _ -> + fail (Illegal_escape c) off in + let nil = `Lit [] in + let str cs = String.of_char_list (List.rev cs) in + let push_nothing = ident in + let push s xs = s :: xs in + let push_lit s = push (Lit s) in + let push_pos x = push (Pos (Int.of_string (Char.to_string x))) in + let push_chars cs = push_lit (str cs) in + let rec parse off spec state = + let lit parse xs = function + | '\\' -> parse (push_chars xs) `Esc + | '$' -> parse (push_chars xs) `Exp + | x -> parse push_nothing (`Lit (x::xs)) in + let esc parse = function + | '$' -> parse (push_lit "$") nil + | c -> parse (push_lit (unescape off c)) nil in + let exp parse = function + | c when Char.is_digit c -> parse (push_pos c) nil + | _ -> fail Expect_digit off in + let step push state = parse (off+1) (push spec) state in + if Int.(off < String.length fmt) then match state with + | `Lit xs -> lit step xs fmt.[off] + | `Esc -> esc step fmt.[off] + | `Exp -> exp step fmt.[off] + else 
List.rev @@ match state with + | `Lit xs -> push_chars xs spec + | (`Esc|`Exp) as state -> fail (Unterminated state) off in + parse 0 [] nil + + + let parse prog tree = + let rec exp : tree -> ast = fun tree -> + let cons data : ast = {data; id=tree.id; eq=tree.eq} in + + let if_ = function + | c::e1::e2 :: es -> cons (Ite (exp c,exp e1,seq e2 es)) + | _ -> bad_form "if" tree in + + let let_ = function + | {data=List bs} :: e :: es -> + List.fold_right bs ~init:(seq e es) ~f:(fun b e -> + match b with + | {data=List [v; x]; id; eq} -> + {data=Let (let_var v,exp x,e); id; eq} + | s -> fail Bad_let_binding s) + | _ -> bad_form "let" tree in + + let while_ = function + | c :: e :: es -> cons (Rep (exp c, seq e es)) + | _ -> bad_form "while" tree in + + let msg = function + | {data=Atom msg} as tree :: es when is_quoted msg -> + cons (Msg (fmt prog msg tree, exps es)) + | _ -> bad_form "msg" tree in + + let set = function + | [v; e] -> cons (Set (let_var v, exp e)) + | _ -> bad_form "set" tree in + + let prog_ es = cons (Seq (exps es)) in + + let error = function + | [{data=Atom msg}] when is_quoted msg -> cons (Err msg) + | _ -> bad_form "error" tree in + + let forms = [ + "if", if_; + "let", let_; + "prog", prog_; + "while", while_; + "msg", msg; + "set", set; + "error", error; + ] in + + let macro op args = match Resolve.macro prog macro op args with + | None -> cons (App (Dynamic op, exps args)) + | Some (Ok (macro,bs)) -> exp (Def.Macro.apply macro bs) + | Some (Error err) -> fail (Unresolved (Macro,op,err)) tree in + + let list : tree list -> ast = function + | [] -> cons (Int {data=nil; id=tree.id; eq=tree.eq}) + | {data=List _} as s :: _ -> fail Bad_app s + | {data=Atom op} :: exps -> + match List.Assoc.find ~equal:String.equal forms op with + | None -> macro op (expand prog exps) + | Some form -> form exps in + + let sym ({id;eq;data=r} as s) = + if is_symbol r then cons (Sym { s with data = symbol s.data}) + else match Var.read id eq r with + | Error e -> 
fail (Bad_var_literal e) tree + | Ok v -> cons (Var v) in + + let lit ({id; eq; data=r} as t) = match Word.read id eq r with + | Ok x -> cons (Int x) + | Error Not_an_int -> sym t + | Error other -> fail (Bad_word_literal other) tree in + + let start : tree -> ast = function + | {data=List xs} -> list xs + | {data=Atom x} as t -> match Resolve.const prog const x () with + | None -> lit {t with data=x} + | Some Error err -> fail (Unresolved (Const,x,err)) t + | Some (Ok (const,())) -> exp (Def.Const.value const) in + start tree + and seq e es = {data=Seq ((exp e) :: exps es); id=Id.null; eq=Eq.null} + and exps : tree list -> ast list = fun xs -> List.map xs ~f:exp in + exp tree + + let params = function + | {data=List vars} -> List.map ~f:let_var vars + | s -> fail Bad_param_list s + + let atom = function + | {data=Atom x} -> x + | tree -> fail Bad_macro_param_list tree + + let metaparams = function + | {data=List vars} -> List.map ~f:atom vars + | s -> fail Bad_macro_param_list s + + + let parse_declaration attrs tree = + try Attribute.parse attrs tree + with Attribute.Bad_syntax (err,trees) -> + raise (Attribute_parse_error (err,trees,tree)) + + let parse_declarations attrs = + List.fold ~init:attrs ~f:Attribute.parse + + let ascii xs = (* turns quoted string atoms into one hex-literal atom per character *) + let rec loop xs acc = match xs with + | [] -> acc + | {data=Atom x} as s :: xs when is_quoted x -> + let x = try Scanf.unescaped (unqoute x) (* the atom still carries its enclosing quotes, which Scanf.unescaped rejects *) + with Scanf.Scan_failure _ -> fail Bad_ascii s in + String.fold x ~init:acc ~f:(fun acc x -> + {data=x; id = s.id; eq = Eq.null} :: acc) |> + loop xs + | here :: _ -> fail Bad_ascii here in + List.rev_map (loop xs []) ~f:(fun c -> + {c with data=Atom (sprintf "%#02x" (Char.to_int c.data))}) + + let is_odd x = x mod 2 = 1 + + let hex xs = + let rec loop xs acc = match xs with + | [] -> List.rev acc + | {data=List _} as here :: _ -> fail Bad_hex here + | {data=Atom x} as s :: xs -> + let x = if is_odd (String.length x) then "0" ^ x else x in + String.foldi x ~init:acc ~f:(fun i acc _ -> + if is_odd i +
then {s with data=Atom (sprintf "0x%c%c" x.[i-1] x.[i])} :: acc + else acc) |> + loop xs in + loop xs [] + + let reader = function + | None -> ident + | Some {data=Atom ":ascii"} -> ascii + | Some {data=Atom ":hex"} -> hex + | Some here -> fail Unknown_subst_syntax here + + let is_keyarg = function + | {data=Atom s} -> Char.(s.[0] = ':') + | _ -> false + + let constrained prog attrs = + match Attribute.Set.get attrs Context.t with + | None -> prog + | Some constraints -> + Program.with_context prog @@ + Context.merge (Program.context prog) constraints + + let defun ?docs ?(attrs=[]) name p body prog gattrs tree = + let attrs = parse_declarations gattrs attrs in + let es = List.map ~f:(parse (constrained prog attrs)) body in + Program.add prog func @@ Def.Func.create ?docs ~attrs name (params p) { + data = Seq es; + id = tree.id; + eq = tree.eq; + } tree + + let defmethod ?docs ?(attrs=[]) name p body prog gattrs tree = + let attrs = parse_declarations gattrs attrs in + let es = List.map ~f:(parse (constrained prog attrs)) body in + Program.add prog meth @@ Def.Meth.create ?docs ~attrs name (params p) { + data = Seq es; + id = tree.id; + eq = tree.eq; + } tree + + let defmacro ?docs ?(attrs=[]) name ps body prog gattrs tree = + Program.add prog macro @@ + Def.Macro.create ?docs + ~attrs:(parse_declarations gattrs attrs) name + (metaparams ps) + body tree + + let defparameter ?docs ?(attrs=[]) name body prog gattrs tree = + let attrs = parse_declarations gattrs attrs in + Program.add prog para @@ + Def.Para.create ?docs + ~attrs name (parse (constrained prog attrs) body) tree + + let defsubst ?docs ?(attrs=[]) name body prog gattrs tree = (* registers a substitution; an optional leading keyword selects the reader *) + let syntax, body = match body with + | s :: rest when is_keyarg s -> Some s, rest (* the keyword picks the reader and is not part of the substituted body *) + | _ -> None, body in + Program.add prog subst @@ + Def.Subst.create ?docs + ~attrs:(parse_declarations gattrs attrs) name + (reader syntax body) tree + + let defconst ?docs ?(attrs=[]) name body prog gattrs tree = + Program.add prog const @@ + Def.Const.create ?docs +
~attrs:(parse_declarations gattrs attrs) name ~value:body tree + + let toplevels = String.Set.of_list [ + "declare"; + "defconstant"; + "defparameter"; + "defmacro"; + "defsubst"; + "defun"; + "defmethod"; + "require"; + ] + + let declaration gattrs s = match s with + | {data=List ({data=Atom "declare"} :: attrs)} -> + parse_declarations gattrs attrs + | {data=List ({data=Atom toplevel} as here :: _)} -> + if Set.mem toplevels toplevel then gattrs + else fail Unknown_toplevel here + | _ -> fail Bad_toplevel s + + + let stmt gattrs state s = match s with + | {data = List ( + {data=Atom "defun"} :: + {data=Atom name} :: + params :: + {data=Atom docs} :: + {data=List ({data=Atom "declare"} :: attrs)} :: + body) + } when is_quoted docs -> + defun ~docs ~attrs name params body state gattrs s + | {data = List ( + {data=Atom "defun"} :: + {data=Atom name} :: + params :: + {data=Atom docs} :: + body) + } when is_quoted docs -> + defun ~docs name params body state gattrs s + | {data = List ( + {data=Atom "defun"} :: + {data=Atom name} :: + params :: + {data=List ({data=Atom "declare"} :: attrs)} :: + body) + } -> + defun ~attrs name params body state gattrs s + | {data = List ( + {data=Atom "defun"} :: + {data=Atom name} :: + params :: + body) + } -> + defun name params body state gattrs s + | {data=List ({data=Atom "defun"} :: _)} -> fail (Bad_def Func) s + | {data = List ( + {data=Atom "defmethod"} :: + {data=Atom name} :: + params :: + {data=Atom docs} :: + {data=List ({data=Atom "declare"} :: attrs)} :: + body) + } when is_quoted docs -> + defmethod ~docs ~attrs name params body state gattrs s + | {data = List ( + {data=Atom "defmethod"} :: + {data=Atom name} :: + params :: + {data=Atom docs} :: + body) + } when is_quoted docs -> + defmethod ~docs name params body state gattrs s + | {data = List ( + {data=Atom "defmethod"} :: + {data=Atom name} :: + params :: + {data=List ({data=Atom "declare"} :: attrs)} :: + body) + } -> + defmethod ~attrs name params body state 
gattrs s + | {data = List ( + {data=Atom "defmethod"} :: + {data=Atom name} :: + params :: + body) + } -> + defmethod name params body state gattrs s + | {data=List ({data=Atom "defmethod"} :: _)} -> fail (Bad_def Meth) s + | {data = List [ + {data=Atom "defparameter"}; + {data=Atom name}; + body + ]} -> + defparameter name body state gattrs s + | {data = List [ + {data=Atom "defparameter"}; + {data=Atom name}; + body; + {data=List ({data=Atom "declare"} :: attrs)} + ]} -> + defparameter ~attrs name body state gattrs s + | {data = List [ + {data=Atom "defparameter"}; + {data=Atom name}; + body; + {data=Atom docs}; + ]} -> + defparameter ~docs name body state gattrs s + | {data = List [ + {data=Atom "defparameter"}; + {data=Atom name}; + body; + {data=Atom docs}; + {data=List ({data=Atom "declare"} :: attrs)} + ]} -> + defparameter ~attrs ~docs name body state gattrs s + | {data=List ({data=Atom "defparameter"} :: _)} -> fail (Bad_def Para) s + | _ -> state + + + let meta gattrs state s = match s with + | {data=List [ + {data=Atom "defconstant"}; + {data=Atom name}; + {data=Atom body}; + {data=Atom docs}; + {data=List ({data=Atom "declare"} :: attrs)}; + ]} when is_quoted docs -> + defconst ~docs ~attrs name body state gattrs s + | {data=List [ + {data=Atom "defconstant"}; + {data=Atom name}; + {data=Atom body }; + {data=Atom docs}; + ]} when is_quoted docs -> + defconst ~docs name body state gattrs s + | {data=List [ + {data=Atom "defconstant"}; + {data=Atom name}; + {data=Atom body}; + {data=List ({data=Atom "declare"} :: attrs)}; + ]} -> + defconst ~attrs name body state gattrs s + | {data=List [ + {data=Atom "defconstant"}; + {data=Atom name}; + {data=Atom body }; + ]} -> + defconst name body state gattrs s + + | {data=List [ + {data=Atom "defmacro"}; + {data=Atom name}; + params; + {data=Atom docs}; + {data=List ({data=Atom "declare"} :: attrs)}; + body]} when is_quoted docs -> + defmacro ~docs ~attrs name params body state gattrs s + | {data=List [ + 
{data=Atom "defmacro"}; + {data=Atom name}; + params; + {data=Atom docs}; + body]} when is_quoted docs -> + defmacro ~docs name params body state gattrs s + | {data=List [ + {data=Atom "defmacro"}; + {data=Atom name}; + params; + {data=List ({data=Atom "declare"} :: attrs)}; + body]} -> + defmacro ~attrs name params body state gattrs s + | {data=List [ + {data=Atom "defmacro"}; + {data=Atom name}; + params; + body]} -> + defmacro name params body state gattrs s + + | {data=List ( + {data=Atom "defsubst"} :: + {data=Atom name} :: + {data=Atom docs} :: + {data=List ({data=Atom "declare"} :: attrs)} :: + body)} when is_quoted docs -> + defsubst ~docs ~attrs name body state gattrs s + | {data=List ( + {data=Atom "defsubst"} :: + {data=Atom name} :: + {data=Atom docs} :: + body)} when is_quoted docs -> + defsubst ~docs name body state gattrs s + | {data=List ( + {data=Atom "defsubst"} :: + {data=Atom name} :: + {data=List ({data=Atom "declare"} :: attrs)} :: + body)} -> + defsubst ~attrs name body state gattrs s + | {data=List ( + {data=Atom "defsubst"} :: + {data=Atom name} :: + body)} -> + defsubst name body state gattrs s + | {data=List ({data=Atom "defsubst"}::_)} -> fail (Bad_def Subst) s + | {data=List ({data=Atom "defmacro"}::_)} -> fail (Bad_def Macro) s + | {data=List ({data=Atom "defconstant"}::_)} -> fail (Bad_def Const) s (* malformed defconstant forms are reported, not silently skipped *) + | _ -> state + + let declarations = + List.fold ~init:Attribute.Set.empty ~f:declaration + + let source constraints source = + let init = Program.with_context Program.empty constraints in + let init = Program.with_sources init source in + let state = Source.fold source ~init ~f:(fun _ trees state -> + List.fold trees ~init:state ~f:(meta (declarations trees))) in + Source.fold source ~init:state ~f:(fun _ trees state -> + List.fold trees ~init:state ~f:(stmt (declarations trees))) +end + +module Load = struct + let file_of_feature paths feature = + let name = feature ^ ".lisp" in + List.find_map paths ~f:(fun path -> + Sys.readdir path |>
Array.find_map ~f:(fun file -> + if String.(file = name) + then Some (Filename.concat path file) + else None)) + + + let is_loaded p name = Option.is_some (Source.find p name) + + let load_tree paths p feature loc = + match file_of_feature paths feature with + | None -> + raise (Fail (Unresolved_feature (feature,loc))) + | Some name -> match Source.find p name with + | Some _ -> p + | None -> match Source.load p name with + | Ok p -> p + | Error err -> raise (Fail (Sexp_error (feature,loc,err))) + + let load_trees paths p features = + Map.fold ~init:p features ~f:(fun ~key ~data p -> + load_tree paths p key data) + + let parse_require tree = match tree with + | {data=List [{data=Atom "require"}; {data=Atom name}]} -> + Some (Ok name) + | {data=List ({data=Atom "require"} :: xs)} -> Some (Error xs) + | _ -> None + + let required paths p = + Source.fold p ~init:String.Map.empty ~f:(fun _ trees required -> + List.fold trees ~init:required ~f:(fun required tree -> + match parse_require tree with + | None -> required + | Some (Error trees) -> + raise (Fail (Parse_error (Bad_require,loc p trees))) + | Some (Ok name) -> + let pos = Module (Source.loc p tree.id) in + match file_of_feature paths name with + | None -> raise (Fail (Unresolved_feature (name,pos))) + | Some file -> match Source.find p file with + | None -> Map.set required name pos + | Some _ -> required)) + + let transitive_closure paths p = + let rec fixpoint p = + let required = required paths p in + if Map.is_empty required then p + else fixpoint (load_trees paths p required) in + fixpoint p + + let features_of_list = + List.fold ~init:String.Map.empty ~f:(fun fs f -> + Map.set fs ~key:f ~data:Cmdline) + + let features ?(paths=[Filename.current_dir_name]) ctxt fs = + let source = + load_trees paths Source.empty (features_of_list fs) |> + transitive_closure paths in + try + Parse.source ctxt source + with Parse_error (err,trees) -> + raise (Fail (Parse_error (err, loc source trees))) +end + +let program 
?paths proj features = + try Ok (Load.features ?paths proj features) + with Fail e -> Error e + +let string_of_typ_error = function + | Type.Empty -> "empty string can't be used as type expression" + | Type.Not_sexp -> "type expression is not a well-formed sexp" + | Type.Bad_sort -> "the sort definition is not recognized, expects: +sort ::= + | (? starting with a lowercase letter ?) + | (? starting with the upper case letter ?) + | () +sort-param ::= | +" + +let string_of_var_error = function + | Var.Empty -> "empty string can't be used as a variable name" + | Var.Not_a_var -> "not a valid identifier" + | Var.Bad_format -> "variable name contains extra `:' symbol" + | Var.Bad_type e -> string_of_typ_error e + +let string_of_word_error = function + | Word.Empty -> "an empty string" + | Word.Not_an_int -> "doesn't start with a number" + | Word.Bad_literal -> + "must start with a digit and contain no more than one `:' symbol" + | Word.Unclosed -> "unmatching single quote in a character literal" + | Word.Bad_type e -> string_of_typ_error e + +let string_of_form_syntax = function + | "if" -> "(if ...)" + | "let" -> "(let (( ) ...) ...)" + | "while" -> "(while ...)" + | "msg" -> "(msg ...)" + | "set" -> "(set )" + | "prog" -> "(prog ...)" + | "error" -> {|(error "")|} + | _ -> assert false + +let string_of_defkind = function (* human-readable name of a definition kind *) + | Func -> "function" + | Meth -> "method" + | Para -> "parameter" + | Macro -> "macro" + | Const -> "constant" + | Subst -> "substitution" + + +let string_of_def_syntax = function + | Func -> "(defun ( ...) [] [] ...)" + | Meth -> "(defmethod ( ...) [] [] ...)" + | Para -> "(defparameter [] [])" + | Macro -> "(defmacro ( ...)
[] [] )" + | Const -> "(defconstant [] [] )" + | Subst -> "(defsubst [] [] [:] ...)" + +let pp_parse_error ppf err = match err with + | Bad_var_literal e -> + fprintf ppf "bad variable literal - %s" (string_of_var_error e) + | Bad_word_literal e -> + fprintf ppf "bad word literal - %s" (string_of_word_error e) + | Bad_form name -> + fprintf ppf "bad %s syntax - expected %s" name (string_of_form_syntax name) + | Bad_app -> + fprintf ppf "head of the list in the application form shall be an atom" + | Bad_let_binding -> + fprintf ppf "expected a variable literal" + | Bad_param_list -> + fprintf ppf "expected a list of variables" + | Bad_macro_param_list -> + fprintf ppf "expected a list of atoms" + | Bad_require -> + fprintf ppf "expected (require )" + | Unknown_toplevel -> + fprintf ppf "unknown toplevel form" + | Bad_toplevel -> + fprintf ppf "bad toplevel syntax" + | Bad_ascii | Bad_hex -> + fprintf ppf "expected a list of atoms" + | Unknown_subst_syntax -> + fprintf ppf "unknown substitution syntax" + | Unresolved (k,n,r) -> + fprintf ppf "unable to resolve %s `%s', because %a" + (string_of_defkind k) n Resolve.pp_resolution r + | Bad_def def -> + fprintf ppf "bad %s definition, expect %s" + (string_of_defkind def) (string_of_def_syntax def) + +let pp_request ppf req = match req with + | Cmdline -> + fprintf ppf "requested by a user" + | Module loc -> + fprintf ppf "requested in %a" Loc.pp loc + +let pp_format_error ppf err = match err with + | Expect_digit -> fprintf ppf "expected digit" + | Unterminated `Esc -> fprintf ppf "unterminated escape sequence" + | Unterminated `Exp -> fprintf ppf "unterminated reference" + | Illegal_escape c -> fprintf ppf "illegal escape character '%c'" c + +let pp_error ppf err = match err with (* prints a loader or parser error *) + | Parse_error (err,loc) -> + fprintf ppf "%a@\nParse error: %a" Loc.pp loc pp_parse_error err + | Sexp_error (name,req,err) -> + fprintf ppf "%a@\nOccurred when parsing feature %s %a" + Source.pp_error err name pp_request req + |
Unresolved_feature (name,req) -> + fprintf ppf "Error: no implementation provided for feature `%s' %a" + name pp_request req + | Unknown_attr (attr,loc) -> + fprintf ppf "%a@\nError: unknown attribute %s@\n" Loc.pp loc attr + | Format_error (err,loc) -> + fprintf ppf "%a@\nFormat error: %a" Loc.pp loc pp_format_error err + +let pp_program = Program.pp diff --git a/lib/bap_lisp/bap_lisp__parse.mli b/lib/bap_lisp/bap_lisp__parse.mli new file mode 100644 index 000000000..2085a0854 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__parse.mli @@ -0,0 +1,11 @@ + +module Context = Bap_lisp__context +module Program = Bap_lisp__program + +type error + +val program : ?paths:string list -> Context.t -> string list -> + (Program.t,error) result + +val pp_error : Format.formatter -> error -> unit +val pp_program : Format.formatter -> Program.t -> unit diff --git a/lib/bap_lisp/bap_lisp__program.ml b/lib/bap_lisp/bap_lisp__program.ml new file mode 100644 index 000000000..a5b49eb98 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__program.ml @@ -0,0 +1,820 @@ +open Core_kernel +open Graphlib.Std +open Regular.Std +open Monads.Std +open Bap_core_theory +open Bap_lisp__types +open Format + + +module Lisp = struct + module Context = Bap_lisp__context + module Var = Bap_lisp__var + module Type = Bap_lisp__type +end + +module Def = Bap_lisp__def + +type t = { + context : Lisp.Context.t; + sources : Source.t; + primits : Def.primitive Def.t list; + macros : Def.macro Def.t list; + substs : Def.subst Def.t list; + consts : Def.const Def.t list; + defs : Def.func Def.t list; + mets : Def.meth Def.t list; + pars : Def.para Def.t list; +} [@@deriving fields] + +type program = t + +let empty = { + context = Lisp.Context.empty; + sources = Source.empty; + primits = []; + defs = []; + mets = []; + pars = []; + macros=[]; + substs=[]; + consts=[]; +} + +type 'a item = ([`Read | `Set_and_create ], t, 'a Def.t list) Fieldslib.Field.t_with_perm + +module Items = struct + let macro = Fields.macros + let subst = 
Fields.substs + let const = Fields.consts + let func = Fields.defs + let meth = Fields.mets + let para = Fields.pars + let primitive = Fields.primits +end + +let add p (fld : 'a item) x = + Field.fset fld p (x :: Field.get fld p) + +let get p (fld : 'a item) = Field.get fld p + +let with_context p context = {p with context} +let with_sources p sources = {p with sources} + +let (++) = Map.merge ~f:(fun ~key:_ -> function + | `Both (id,_) | `Left id | `Right id -> Some id) + +let union init xs ~f = + List.fold xs ~init ~f:(fun vs x -> vs ++ f x) + +type node = + | Entry + | Defun of Id.t + | Exit +[@@deriving compare] + +module Callgraph = struct + module G = Graphlib.Make(struct + type t = node + include Opaque.Make(struct + type t = node [@@deriving compare] + let hash = Hashtbl.hash + end) + end)(Unit) + + let empty = String.Set.empty + let (++) = Set.union + let call = String.Set.singleton + let union xs ~f = + List.map xs ~f |> + String.Set.union_list + + let rec calls = function + | {data=App ((Dynamic v),xs); } -> call v ++ union xs ~f:calls + | {data=(Var _ | Int _ | Sym _ | Err _)} -> empty + | {data=Ite (x,y,z)} -> calls x ++ calls y ++ calls z + | {data=(Seq xs | App (_,xs) | Msg (_,xs))} -> union xs ~f:calls + | {data=(Let (_,x,y) | Rep (x,y))} -> calls x ++ calls y + | {data=Set (_,x)} -> calls x + + + (** computes a mapping from name to id of definitions *) + let compute_ids defs = + List.fold defs ~init:String.Map.empty ~f:(fun ids def -> + Map.add_multi ids ~key:(Def.name def) ~data:def.id) + + let edge id id' = + G.Edge.create (Defun id) (Defun id') () + + let build_kernel defs = + let ids = + let ids = compute_ids defs in + fun name -> match Map.find ids name with + | None -> [] + | Some x -> x in + List.fold defs ~init:G.empty ~f:(fun g def -> + let g = G.Node.insert (Defun def.id) g in + Set.fold (calls (Def.Func.body def)) ~init:g ~f:(fun g name -> + List.fold (ids name) ~init:g ~f:(fun g id -> + G.Edge.insert (edge def.id id) g))) + + let close 
dir g = + let edge n = match dir with + | `In -> G.Edge.create Entry n () + | `Out -> G.Edge.create n Exit () in + G.nodes g |> Seq.fold ~init:g ~f:(fun g n -> + if G.Node.degree ~dir n g = 0 + then G.Edge.insert (edge n) g + else g) + + let build defs = + close `Out (close `In (build_kernel defs)) + + include G +end + +let pp_callgraph ppf g = + Graphlib.to_dot (module Callgraph) + ~formatter:ppf + ~string_of_node:(function + | Entry -> "" + | Exit -> "" + | Defun id -> asprintf "%a" Id.pp id) + g + +let pp_term pp_exp ppf = function + | {data={exp; typ=Any}} -> + fprintf ppf "%a" pp_exp exp + | {data={exp; typ}} -> + fprintf ppf "%a:%a" pp_exp exp Lisp.Type.pp typ +let pp_word = pp_term Bitvec.pp +let pp_var = pp_term String.pp + +let rec concat_prog xs = + List.concat_map xs ~f:(function + | {data=Seq xs} -> concat_prog xs + | x -> [x]) + +module Ast = struct + let rec pp ppf {data} = pp_exp ppf data + and pp_exp ppf = function + | Int x -> + pp_word ppf x + | Sym x -> + pp_print_string ppf x.data + | Var x -> + pp_var ppf x + | Ite (c,t,e) -> + fprintf ppf "@[<2>(if@ %a@;<1 2>%a@ %a)@]" pp c pp t pp_prog e + | Let (v,e,b) -> + fprintf ppf "@[(let@;<1 2>@[<2>(%a@ %a)@]@ %a)@]" pp_var v pp e pp b + | App (b,xs) -> + fprintf ppf "@[<2>(%a@ %a)@]" pp_binding b pp_exps xs; + | Seq [] -> fprintf ppf "()" + | Seq [x] -> pp ppf x + | Seq xs -> + fprintf ppf "@[<2>(prog@ @[%a@])@]" pp_exps (concat_prog xs) + | Set (v,x) -> + fprintf ppf "@[<2>(set@ %a@ %a)@]" pp_var v pp x + | Rep (c,b) -> + fprintf ppf "@[<2>(while@;<1 2>%a@ @[%a@])@]" pp c pp_prog b + | Msg (f,es) -> + fprintf ppf "@[<2>(msg@ \"%a\"@ %a)@]" pp_fmt f pp_exps es; + | Err msg -> + fprintf ppf "@[<2>(error@ %s)@]" msg + and pp_binding ppf = function + | Dynamic x -> fprintf ppf "%s" x + | Static _ -> fprintf ppf "" + and pp_exps ppf xs = pp_print_list ~pp_sep:pp_print_space pp ppf xs + and pp_fmt ppf xs = pp_print_list pp_fmt_elt ppf xs + and pp_fmt_elt ppf = function + | Lit s -> pp_print_string ppf s + 
| Pos n -> fprintf ppf "$%d" n + and pp_prog ppf = function + | {data=Seq xs} -> + fprintf ppf "%a" pp_exps (concat_prog xs) + | exp -> pp ppf exp +end + +let pp_def ppf d = + fprintf ppf "@[<2>(defun %s @[<2>(%a)@]@ %a)@]@," + (Def.name d) + (pp_print_list ~pp_sep:pp_print_space pp_var) (Def.Func.args d) + Ast.pp_prog (Def.Func.body d) + +let pp_met ppf d = + fprintf ppf "@[<2>(defmethod %s @[<2>(%a)@]@ %a)@]@," + (Def.name d) + (pp_print_list ~pp_sep:pp_print_space pp_var) (Def.Meth.args d) + Ast.pp_prog (Def.Meth.body d) + +(* [pp_par ppf d] prints the parameter definition [d] as a [defparameter] form: the name ([Def.name]), the default value (printed with [Ast.pp_prog]) and the docstring (quoted with [%S]). *) +let pp_par ppf d = + fprintf ppf "@[<2>(defparameter %s@,%a@,%S)@]" + (Def.name d) + Ast.pp_prog (Def.Para.default d) + (Def.docs d) + +let pp ppf {pars; mets; defs;} = + let pp_items pp items = + fprintf ppf "@[%a@]" (pp_print_list pp) items in + pp_items pp_par pars; + pp_items pp_met mets; + pp_items pp_def defs + + +module Use = struct + let empty = String.Map.empty + let union = union empty + let use = String.Map.singleton + + type t = { + calls : Id.t String.Map.t Id.Map.t; + vars : Id.t String.Map.t Id.Map.t; (* def -> var -> use *) + } + + let vars bound ast = + let use bound {exp=v} id = + if Set.mem bound v then String.Map.empty + else use v id in + let rec free bound = function + | {data=(Int _ | Err _ | Sym _)} -> empty + | {data=Var v; id} -> use bound v.data id + | {data=Ite (x,y,z)} -> free bound x ++ free bound y ++ free bound z + | {data=Let (v,x,y)} -> free bound x ++ free (Set.add bound v.data.exp) y + | {data=(Seq xs | App (_,xs) | Msg (_,xs))} -> union xs ~f:(free bound) + | {data=Set (v,x); id} -> use bound v.data id ++ free bound x + | {data=Rep (x,y)} -> free bound x ++ free bound y in + free bound ast + + let rec calls = function + | {data=App ((Dynamic v),_); id} -> use v id + | {data=(Var _ | Int _ | Sym _ | Err _)} -> empty + | {data=Ite (x,y,z)} -> calls x ++ calls y ++ calls z + | {data=(Seq xs | App (_,xs) | Msg (_,xs))} -> union xs ~f:calls + | {data=(Let (_,x,y) | Rep (x,y))} -> calls x ++ calls y + | 
{data=Set (_,x)} -> calls x + + let collect {defs} = + let init = {calls = Id.Map.empty; vars = Id.Map.empty} in + List.fold ~init defs ~f:(fun s def -> + let body = Def.Func.body def in + let bound = Def.Func.args def |> + List.map ~f:(fun {data={exp}} -> exp) |> + String.Set.of_list in + let vs = vars bound body in + let cs = calls body in + { + calls = Map.set s.calls ~key:def.id ~data:cs; + vars = Map.set s.vars ~key:def.id ~data:vs; + }) +end + +(** Assign fresh indices to trees that were produced by macros or that + ** have no indices at all. + ** + ** We first scan through all meta definitions (i.e., macros, substs, + ** and consts) to obtain a set of indices that we shall rewrite, and + ** then perform rewriting for all program definitions (defs, mets, + ** and pars). + ** + ** The newly generated Ids are derived from (i.e., associated with) their + ** base ids, so that if needed their origin can always be + ** established (except if their origin was the null identifier). + ** + ** Motivation: since we identify an ast by its identifier, we want + ** the trees produced by the term rewriting to have different + ** identifiers. Otherwise, they could be unified, for example in the + ** Type checker. 
+ **) +module Reindex = struct + module State = Monad.State.Make(Source)(Monad.Ident) + open State.Syntax + type 'a m = 'a Monad.State.T1(Source)(Monad.Ident).t + + let rec ids_of_trees trees = + List.fold trees ~init:Id.Set.empty ~f:(fun xs t -> match t with + | {data=Atom _; id} -> Set.add xs id + | {data=List ts;id} -> + Set.union (Set.add xs id) (ids_of_trees ts)) + + let ids_of_defs defs map reduce = + Id.Set.union_list [ + Id.Set.of_list @@ List.map defs ~f:(fun d -> d.id); + map defs ~f:reduce |> ids_of_trees + ] + + let macro_ids p = Id.Set.union_list [ + ids_of_defs p.macros List.map Def.Macro.body; + ids_of_defs p.consts List.map Def.Const.value; + ids_of_defs p.substs List.concat_map Def.Subst.body; + ] + + let derive from = + State.get () >>= fun src -> + let nextid = Id.next (Source.lastid src) in + State.put (Source.derived src ~from nextid) >>| fun () -> + nextid + + let reindex (get,set) macros def = + let rename t = + if Set.mem macros t.id || Id.null = t.id + then derive t.id >>| fun id -> {t with id} + else State.return t in + let rec map : ast -> ast m = fun t -> + rename t >>= fun t -> match t.data with + | Err _ -> State.return t + | Int x -> + rename x >>| fun x -> + {t with data = Int x} + | Sym s -> + rename s >>| fun s -> + {t with data = Sym s} + | Var v -> + rename v >>| fun v -> + {t with data = Var v} + | Ite (x,y,z) -> + map x >>= fun x -> + map y >>= fun y -> + map z >>| fun z -> + {t with data = Ite (x,y,z)} + | Let (c,x,y) -> + rename c >>= fun c -> + map x >>= fun x -> + map y >>| fun y -> + {t with data = Let (c,x,y)} + | Rep (x,y) -> + map x >>= fun x -> + map y >>| fun y -> + {t with data = Rep (x,y)} + | App (b,xs) -> + map_all xs >>| fun xs -> + {t with data = App (b,xs)} + | Msg (f,xs) -> + map_all xs >>| fun xs -> + {t with data = Msg (f,xs)} + | Seq xs -> + map_all xs >>| fun xs -> + {t with data = Seq xs} + | Set (v,x) -> + rename v >>= fun v -> + map x >>| fun x -> + {t with data = Set (v,x)} + and map_all xs = 
State.List.map xs ~f:map in + map (get def) >>| set def + + let reindex_all p = + let def = Def.Func.body,Def.Func.with_body in + let met = Def.Meth.body,Def.Meth.with_body in + let par = Def.Para.default,Def.Para.with_default in + let macros = macro_ids p in + State.List.map p.defs ~f:(reindex def macros) >>= fun defs -> + State.List.map p.mets ~f:(reindex met macros) >>= fun mets -> + State.List.map p.pars ~f:(reindex par macros) >>= fun pars -> + State.return (defs,mets,pars) + + let program p = + let (defs,mets,pars),sources = + State.run (reindex_all p) p.sources in + {p with defs; mets; pars; sources} + +end + +module Typing = struct + (* An expression in Primus Lisp gradual type system may have several + types, e.g., [(if c 123 'hello)] is a well-typed expression that + has type int+sym. The int+sym type is a disjunctive type, + or a polytype. Our type system is _soft_ as we have type Any, + that denotes a disjunction (join in our parlance) of all + types. The set of type expressions forms a lattice with the Any + type representing the Top element (all possible types). The Bot + type is an empty disjunction. + + Our type inference system is a mixture of flow based and + inference based analysis. We infer a type of a function, + and find a fixed point solution for a set of (possibly + mutually recursive) functions. The inferred type is the upper + approximation of a program behavior, i.e., it doesn't guarantee + an absence of runtime errors, though it guarantees some + consistency of the program static properties. + *) + + type sort = Theory.Value.Sort.Top.t + [@@deriving compare, sexp_of] + + (* Type value (aka type). A program value could be either a symbol + or a bitvector with the given width. All types have the same + runtime representation (modulo bitwidth). 
*) + type tval = + | Tsym + | Grnd of sort + [@@deriving compare, sexp_of] + + module Tval = Comparable.Make_plain(struct + type t = tval [@@deriving compare, sexp_of] + end) + + (* type variables + we allow users to specify type variables, the rest of the type + variables are created by using term identifiers of corresponding + program terms (thus we don't need to create fresh type variables).*) + type tvar = + | Name of string + | Tvar of Id.t + [@@deriving compare, sexp_of] + + module Tvar = Comparable.Make_plain(struct + type t = tvar [@@deriving compare, sexp_of] + end) + + (** Typing environment. + + Typing constraint is built as a composition of rules, where each + rule is a function of type [gamma -> gamma], so the rules can be + composed with the function composition operator. + *) + module Gamma : sig + type t [@@deriving compare, sexp_of] + + type rule = t -> t + + val empty : t + + (** [get gamma exp] returns a type of the expression [exp]. + If [None] is returned, then the expression doesn't have any + statical constraints, so its type is [Any]. If some set is + returned, then this set denotes a disjunction of types, that a + term can have during program evaluation. If this set is empty, + then the term is ill-typed. + *) + val get : t -> Id.t -> Tval.Set.t option + + + val merge : t -> t -> t + + (** [exps gamma] returns a list of typed expressions. *) + val exps : t -> Id.t list + + (** [constr exp typ] expression [exp] shall have type [typ] *) + val constr : Id.t -> typ -> rule + + (** [meet x y] types of [x] and [y] shall have the same type *) + val meet : Id.t -> Id.t -> rule + + (** [join x ys] expression [x] shall have a type that is a + disjunction of the types of expressions specified by the [ys] list. *) + val join : Id.t -> Id.t list -> rule + + end = struct + (* typing environment. + + [vars] associates each program term with a type variable. 
It is a + disjoint set that partitions the set of program terms into + equivalence classes, such that two terms belonging to the same + set will have the same type. + + [vals] is the typing environment that associates each type + variable with the sum of type values (ground types). If an type + variable is not mapped in [vals] then it is assumed to has type + Top (i.e., it is a set of all possible types). An empty set + denotes the bottom type, i.e., all expressions that has that type + are ill-typed.*) + type t = { + vars : tvar Id.Map.t; + vals : Tval.Set.t Tvar.Map.t; + } [@@deriving compare, sexp_of] + + type rule = t -> t + + let empty = { + vars = Id.Map.empty; + vals = Tvar.Map.empty; + } + + let exps {vars} = Map.keys vars + + let merge g g' = { + vars = Map.merge g.vars g'.vars ~f:(fun ~key:_ -> function + | `Left t | `Right t -> Some t + | `Both (_,t') -> Some t'); + vals = Map.merge g.vals g'.vals ~f:(fun ~key:_ -> function + | `Left ts | `Right ts -> Some ts + | `Both (_,ts) -> Some ts) + + } + + + let add_var id t g = + {g with vars = Map.set g.vars ~key:id ~data:t} + + let add_val t v g = + {g with vals = Map.set g.vals ~key:t ~data:v} + + let unify t1 t2 g = + let t = Tvar.min t1 t2 in + match Map.find g.vals t1, Map.find g.vals t2 with + | None, None -> g + | None, Some v | Some v, None -> add_val t v g + | Some v, Some v' -> add_val t (Set.inter v v') g + + let meet id1 id2 g = + let t,g = match Map.find g.vars id1, Map.find g.vars id2 with + | None,None -> Tvar.min (Tvar id1) (Tvar id2),g + | None,Some t | Some t,None -> t,g + | Some u,Some v -> Tvar.min u v, unify u v g in + add_var id2 t @@ + add_var id1 t g + + + let get g id = match Map.find g.vars id with + | None -> None + | Some rep -> Map.find g.vals rep + + + let inter_list = function + | [] -> None + | x :: xs -> Some (List.fold ~init:x xs ~f:Set.inter) + + let join id ids g = + List.filter_map ids ~f:(get g) |> + inter_list |> function + | None -> g + | Some vs -> match Map.find g.vars 
id with + | None -> { + vars = Map.set g.vars ~key:id ~data:(Tvar id); + vals = Map.set g.vals ~key:(Tvar id) ~data:vs; + } + | Some v -> { + g with vals = Map.update g.vals v ~f:(function + | None -> vs + | Some vs' -> Set.union vs vs') + } + + let constr_name id n g = + let g' = add_var id (Name n) g in + match Map.find g.vars id with + | None -> g' + | Some v -> Map.fold g'.vars ~init:g ~f:(fun ~key:id' ~data:v' g -> + if Tvar.equal v v' + then meet id id' g + else g) + + let constr_grnd id t g = + let v = match Map.find g.vars id with + | None -> Tvar id + | Some v -> v in + let g = add_var id v g in + let t = Tval.Set.singleton t in + {g with + vals = Map.update g.vals v ~f:(function + | None -> t + | Some t' -> Set.inter t t') + } + + let constr id t g = match t with + | Any -> g + | Name n -> constr_name id n g + | Symbol -> constr_grnd id Tsym g + | Type n -> constr_grnd id (Grnd n) g + + end + + type signature = Lisp.Type.signature = { + args : typ list; + rest : typ option; + ret : typ; + } + + type t = { + ctxt : Lisp.Context.t; + globs : sort String.Map.t; + prims : signature String.Map.t; + funcs : Def.func Def.t list; + } + + + let pp_args ppf args = + pp_print_list Lisp.Type.pp ppf args + + let pp_signature ppf {args; rest; ret} = + fprintf ppf "@[(%a" pp_args args; + Option.iter rest ~f:(fun rest -> + fprintf ppf "&rest %a" Lisp.Type.pp rest); + fprintf ppf ")@] => (%a)" Lisp.Type.pp ret + + let pp_tval ppf = function + | Tsym -> fprintf ppf "sym" + | Grnd s -> fprintf ppf "%a" Theory.Value.Sort.pp s + + let pp_plus ppf () = pp_print_char ppf '+' + let pp_tvals ppf tvals = + if Set.is_empty tvals + then fprintf ppf "nil" + else fprintf ppf "%a" + (pp_print_list ~pp_sep:pp_plus pp_tval) + (Set.elements tvals) + + let apply_signature appid ts g {args; rest; ret} = + let rec apply g ts ns = + match ts,ns with + | ts,[] -> Some g,ts + | [],_ -> None,[] + | t :: ts, n :: ns -> apply (Gamma.constr t.id n g) ts ns in + match apply g ts args with + | None,_ 
-> None + | Some g,ts -> + let g = Gamma.constr appid ret g in + match ts with + | [] -> Some g + | ts -> match rest with + | None -> None + | Some typ -> + Some (List.fold ts ~init:g ~f:(fun g t -> + Gamma.constr t.id typ g)) + + let type_of_expr g expr : typ = + match Gamma.get g expr.id with + | None -> Any + | Some ts -> match Set.elements ts with + | [Tsym] -> Symbol + | [Grnd n] -> Type n + | _ -> Any + + let type_of_exprs gamma exprs = + List.map exprs ~f:(type_of_expr gamma) + + let signature_of_gamma def gamma = { + rest = None; + ret = type_of_expr gamma (Def.Func.body def); + args = type_of_exprs gamma (Def.Func.args def); + } + + let signatures glob gamma name = + match Map.find glob.prims name with + | Some sign -> [sign] + | None -> List.fold glob.funcs ~init:[] ~f:(fun sigs def -> + if Def.name def = name + then signature_of_gamma def gamma :: sigs + else sigs) + + let join_gammas xs _why_is_it_ignored = xs + + let apply glob id name args gamma = + signatures glob gamma name |> + List.filter_map ~f:(apply_signature id args gamma) |> + List.reduce ~f:join_gammas |> function + | None -> gamma + | Some gamma -> gamma + + let last xs = match List.rev xs with + | {id} :: _ -> id + | _ -> assert false + + let constr_glob {globs} vars var gamma = + if Map.mem vars var.data.exp then gamma + else match Map.find globs var.data.exp with + | None -> gamma + | Some n -> Gamma.constr var.id (Type n) gamma + + + let push vars {data; id} = + Map.set vars ~key:data.exp ~data:id + + let varclass vars v = match Map.find vars v.data.exp with + | None -> v.id + | Some id -> id + + let pp_binding ppf (v,id) = + fprintf ppf "%s:%a" v Id.pp id + + let pp_sep ppf () = + fprintf ppf ", " + + let pp_vars ppf vs = + fprintf ppf "{%a}" (pp_print_list ~pp_sep pp_binding) @@ + Map.to_alist vs + + let sort s = Type (Theory.Value.Sort.forget s) + + let (++) f g x = f (g x) + + let infer_ast glob bindings ast : Gamma.t -> Gamma.t = + let rec infer vs expr = + match expr with + | 
{data=Sym _; id} -> + Gamma.constr id Symbol + | {data=Int x; id} -> + Gamma.constr id x.data.typ + | {data=Var v; id} -> + Gamma.meet v.id (varclass vs v) ++ + Gamma.meet id v.id ++ + constr_glob glob vs v ++ + Gamma.constr v.id v.data.typ + | {data=Ite (x,y,z); id} -> + Gamma.join id [y.id; z.id] ++ + infer vs x ++ + infer vs y ++ + infer vs z + | {data=Let (v,x,y); id} -> + Gamma.meet y.id id ++ + infer (push vs v) y ++ + Gamma.meet x.id v.id ++ + infer vs x ++ + Gamma.constr v.id v.data.typ + | {data=App ((Dynamic name),xs); id} -> + apply glob id name xs ++ + reduce vs xs + | {data=Seq []} -> ident + | {data=Seq xs; id} -> + Gamma.meet (last xs) id ++ + reduce vs xs + | {data=Set (v,x); id} -> + Gamma.join id [v.id; x.id] ++ + Gamma.constr v.id v.data.typ ++ + infer vs x + | {data=Rep (c,x); id} -> + Gamma.meet c.id id ++ + infer vs c ++ + infer vs x + | {data=Msg (_,xs); id} -> + Gamma.constr id (sort Theory.Bool.t) ++ + reduce vs xs + | {data=Err _} -> ident + | {data=App (Static _,_)} -> ident + and reduce vs = function + | [] -> ident + | x :: xs -> infer vs x ++ reduce vs xs in + infer bindings ast + + let find_func funcs id = + List.find funcs ~f:(fun f -> f.id = id) + + let transfer glob node gamma = + match node with + | Entry | Exit -> + gamma + | Defun id -> match find_func glob.funcs id with + | None -> gamma + | Some f -> + let args = Def.Func.args f in + let vars = List.fold args ~init:String.Map.empty ~f:push in + let gamma = List.fold args ~init:gamma ~f:(fun gamma v -> + Gamma.constr v.id v.data.typ gamma) in + let gamma = infer_ast glob vars (Def.Func.body f) gamma in + gamma + + let make_globs = + Seq.fold ~init:String.Map.empty ~f:(fun vars v -> + let data = Theory.Value.Sort.forget (Theory.Var.sort v) in + Map.set vars ~key:(Theory.Var.name v) ~data) + + let make_prims {primits} = + List.fold primits ~init:String.Map.empty ~f:(fun ps p -> + Map.set ps + ~key:(Def.name p) + ~data:(Def.Primitive.signature p)) + + let gamma_equal g1 g2 = 
Gamma.compare g1 g2 = 0 + + let infer vars p : Gamma.t = + let glob = { + ctxt = p.context; + prims = make_prims p; + globs = make_globs vars; + funcs = p.defs; + } in + let g = Callgraph.build p.defs in + let init = Solution.create Callgraph.Node.Map.empty Gamma.empty in + let equal = gamma_equal in + let fp = + Graphlib.fixpoint (module Callgraph) ~rev:true ~start:Exit + ~equal ~merge:Gamma.merge ~init ~f:(transfer glob) g in + Solution.get fp Entry + + (* The public interface *) + module Type = struct + type error = Loc.t * Source.Id.t + let check vars p : error list = + let p = Reindex.program p in + let gamma = infer vars p in + List.fold (Gamma.exps gamma) ~init:Loc.Map.empty ~f:(fun errs exp -> + assert (exp <> Source.Id.null); + if Source.has_loc p.sources exp + then match Gamma.get gamma exp with + | None -> errs + | Some ts -> + if Set.is_empty ts + then Map.set errs ~key:(Source.loc p.sources exp) ~data:exp + else errs + else errs) |> + Map.to_alist + + let pp_error ppf (loc,id) = + fprintf ppf "%a@\nType error - expression is ill-typed: %a" + Loc.pp loc Id.pp id + end +end + + + +module Context = Lisp.Context +module Type = Typing.Type diff --git a/lib/bap_lisp/bap_lisp__program.mli b/lib/bap_lisp/bap_lisp__program.mli new file mode 100644 index 000000000..a47e9395b --- /dev/null +++ b/lib/bap_lisp/bap_lisp__program.mli @@ -0,0 +1,36 @@ +open Core_kernel +open Bap_core_theory +open Bap_lisp__types +module Def = Bap_lisp__def +module Context = Bap_lisp__context + +type t +type program = t +type 'a item + +val empty : t +val add : t -> 'a item -> 'a Def.t -> t +val get : t -> 'a item -> 'a Def.t list +val context : t -> Context.t +val sources : t -> Source.t +val with_sources : t -> Source.t -> t +val with_context : t -> Context.t -> t + + +module Items : sig + val macro : Def.macro item + val subst : Def.subst item + val const : Def.const item + val func : Def.func item + val meth : Def.meth item + val para : Def.para item + val primitive : Def.primitive 
item +end + +module Type : sig + type error + val check : 'a Theory.Var.t Sequence.t -> program -> error list + val pp_error : Format.formatter -> error -> unit +end + +val pp : Format.formatter -> t -> unit diff --git a/lib/bap_lisp/bap_lisp__resolve.ml b/lib/bap_lisp/bap_lisp__resolve.ml new file mode 100644 index 000000000..37aa7277f --- /dev/null +++ b/lib/bap_lisp/bap_lisp__resolve.ml @@ -0,0 +1,220 @@ +open Core_kernel +open Format +open Bap_lisp__types + +module Attribute = Bap_lisp__attribute +module Context = Bap_lisp__context +module Def = Bap_lisp__def +module Loc = Bap_lisp__loc +module Program = Bap_lisp__program + +open Bap_lisp__attributes + + + +type stage = Loc.Set.t +type resolution = { + constr : Context.t; + stage1 : stage; (* definitions with the given name *) + stage2 : stage; (* definitions applicable to the ctxt *) + stage3 : stage; (* lower bounds of all definitions *) + stage4 : stage; (* infinum *) + stage5 : stage; (* overload *) +} + + +type ('t,'a,'b) resolver = + Program.t -> 't Program.item -> string -> 'a -> + ('b,resolution) result option + +type ('t,'a,'b) one = ('t,'a,'t Def.t * 'b) resolver +type ('t,'a,'b) many = ('t,'a,('t Def.t * 'b) list) resolver + + +type exn += Failed of string * Context.t * resolution + +let interns d name = Def.name d = name +let externs def name = + match Attribute.Set.get (Def.attributes def) External.t with + | None -> false + | Some names -> List.mem ~equal:String.equal names name + + + +(* all definitions with the given name *) +let stage1 has_name defs name = + List.filter defs ~f:(fun def -> has_name def name) + +let context def = + match Attribute.Set.get (Def.attributes def) Context.t with + | Some cx -> cx + | None -> Context.empty + + +let compare_def d1 d2 = + Context.(compare (context d1) (context d2)) + +(* all definitions that satisfy the [ctxts] constraint *) +let stage2 (global : Context.t) defs = + List.filter defs ~f:(fun def -> Context.(global <= context def)) + +(* returns a set of 
lower bounds from the given set of definitions. *) +let stage3 s2 = + List.fold s2 ~init:[] ~f:(fun cs d -> match cs with + | [] -> [d] + | c :: cs -> match compare_def d c with + | Same | Equiv -> d :: c :: cs + | More -> c :: cs + | Less -> [d]) + +(* ensures that all definitions belong to the same context class. + + if any two definitions are equivalent but not the same, then we + drop all definitions, since if we have more than one definition at + this stage, then the only remaining method of refinement is + overloading, and we do not want to allow the last stage to choose + from equivalent definitions based on their type. For example: + + Suppose we have two definitions with the following types: + + [d1 : ((arch arm) (compiler gcc)) => (i32)] + + and + + [d2 : ((arch arm) (container elf)) => (i32 i32)] + + And we apply it to a single argument [(d x)], and the context is + + [((arch arm) (compiler gcc) (container elf) ..)], then we have two + perfectly valid definitions that are applicable to the current context, + and we can't choose one or another based on the number of + arguments. 
+*) +let stage4 = function + | [] -> [] + | x :: xs -> + if List.for_all xs ~f:(fun y -> compare_def x y = Same) + then x::xs + else [] + +let overload_macro code (s3) = + List.filter_map s3 ~f:(fun def -> + Option.(Def.Macro.bind def code >>| fun (n,bs) -> n,def,bs)) |> + List.sort ~compare:(fun (n,_,_) (m,_,_) -> Int.ascending n m) |> function + | [] -> [] + | ((n,_,_) as c) :: cs -> List.filter_map (c::cs) ~f:(fun (m,d,bs) -> + Option.some_if (n = m) (d,bs)) + +let all_bindings f = + List.for_all ~f:(fun (v,x) -> + f v.data.typ x) + +let overload_defun typechecks args s3 = + let open Option in + List.filter_map s3 ~f:(fun def -> + List.zip (Def.Func.args def) args >>= fun bs -> + if all_bindings typechecks bs + then Some (def,bs) else None) + +let zip_tail xs ys = + let rec zip zs xs ys = match xs,ys with + | [],[] -> zs, None + | x,[] -> zs, Some (First x) + | [],y -> zs, Some (Second y) + | x :: xs, y :: ys -> zip ((x,y)::zs) xs ys in + let zs,tail = zip [] xs ys in + List.rev zs,tail + + +let overload_meth typechecks args s3 = + List.filter_map s3 ~f:(fun m -> + match zip_tail (Def.Meth.args m) args with + | bs,None + | bs, Some (Second _) when all_bindings typechecks bs -> + Some (m,bs) + | _ -> None) + +let overload_primitive s3 = List.map s3 ~f:(fun s -> s,()) + +let locs prog defs = + let src = Program.sources prog in + List.map defs ~f:(fun def -> + Source.loc src def.id) |> Loc.Set.of_list + +let one = function + | [x] -> Some x + | _ -> None + +let many xs = Some xs + +let run choose namespace overload prog item name = + let ctxts = Program.context prog in + let defs = Program.get prog item in + let s1 = stage1 namespace defs name in + let s2 = stage2 ctxts s1 in + let s3 = stage3 s2 in + let s4 = stage4 s3 in + let s5 = overload s4 in + match choose s5 with + | Some f -> Some (Ok f) + | None -> match s1 with + | [] -> None + | _ -> Some( Error { + constr = ctxts; + stage1 = locs prog s1; + stage2 = locs prog s2; + stage3 = locs prog s3; + stage4 = 
locs prog s4; + stage5 = locs prog (List.map s5 ~f:fst); + }) + +let extern typechecks prog item name args = + run one externs (overload_defun typechecks args) prog item name + +let defun typechecks prog item name args = + run one interns (overload_defun typechecks args) prog item name + +let meth typechecks prog item name args = + run many interns (overload_meth typechecks args) prog item name + +let macro prog item name code = + run one interns (overload_macro code) prog item name + +let primitive prog item name () = + run one interns overload_primitive prog item name + +let subst prog item name () = + run one interns overload_primitive prog item name + +let const = subst + +let pp_stage ppf stage = + if Set.is_empty stage + then fprintf ppf "No definitions@\n" + else Set.iter stage ~f:(fprintf ppf "%a@\n" Loc.pp) + +let pp_reason ppf res = + if Set.is_empty res.stage5 + then fprintf ppf "no suitable definitions were found.@\n" + else fprintf ppf "several equally applicable definitions were found.@\n" + +let pp_resolution ppf res = + pp_reason ppf res; + fprintf ppf "The following candidates were considered:@\n"; + fprintf ppf "All definitions with the given name:@\n"; + pp_stage ppf res.stage1; + fprintf ppf "All definitions applicable to the given context:@\n"; + pp_stage ppf res.stage2; + fprintf ppf "Definitions that are most specific to the given context:@\n"; + pp_stage ppf res.stage3; + if Set.equal res.stage3 res.stage4 + then + fprintf ppf "All definitions applicable to the specified arguments:@\n%a" + pp_stage res.stage5 + else + fprintf ppf + "Overloading was not applied, since the above definitions \ + belong to different context classes@\n"; + fprintf ppf "Note: the definitions were considered in the \ + following context:@\n%a" + Context.pp res.constr diff --git a/lib/bap_lisp/bap_lisp__resolve.mli b/lib/bap_lisp/bap_lisp__resolve.mli new file mode 100644 index 000000000..b3910f94b --- /dev/null +++ b/lib/bap_lisp/bap_lisp__resolve.mli @@ -0,0 +1,23 
@@ +open Bap_lisp__types +module Context = Bap_lisp__context +module Def = Bap_lisp__def +module Program = Bap_lisp__program + +type resolution + +type ('t,'a,'b) resolver = + Program.t -> 't Program.item -> string -> 'a -> + ('b,resolution) result option + +type ('t,'a,'b) one = ('t,'a,'t Def.t * 'b) resolver +type ('t,'a,'b) many = ('t,'a,('t Def.t * 'b) list) resolver + +val extern : (typ -> 'a -> bool) -> (Def.func, 'a list, (var * 'a) list) one +val defun : (typ -> 'a -> bool) -> (Def.func, 'a list, (var * 'a) list) one +val meth : (typ -> 'a -> bool) -> (Def.meth, 'a list, (var * 'a) list) many +val macro : (Def.macro, tree list, (string * tree list) list) one +val primitive : (Def.primitive, unit, unit) one +val subst : (Def.subst, unit, unit) one +val const : (Def.const, unit, unit) one + +val pp_resolution: Format.formatter -> resolution -> unit diff --git a/lib/bap_lisp/bap_lisp__source.ml b/lib/bap_lisp/bap_lisp__source.ml new file mode 100644 index 000000000..b9ece17a1 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__source.ml @@ -0,0 +1,189 @@ +open Core_kernel + +module Cst = Parsexp.Cst +module Loc = Bap_lisp__loc +module Index = Bap_lisp__index + +module Id = Index.Make() +module Eq = Index.Make() + +type error = Bad_sexp of (string * Parsexp.Parse_error.t) + +type ('a,'i,'e) interned = ('a,'i,'e) Index.interned = { + data : 'a; + id : 'i; + eq : 'e; +} +type 'a indexed = ('a,Id.t,Eq.t) interned + +type tree = token indexed +and token = Atom of string | List of tree list + +type t = { + lastid : Id.t; + lasteq : Eq.t; + hashed : Eq.t Sexp.Map.t; + equivs : Eq.t Id.Map.t; + origin : string Id.Map.t; + ranges : Loc.range Id.Map.t; + source : tree list String.Map.t; + rclass : Id.t Id.Map.t; +} + +let empty = { + lastid = Id.null; + lasteq = Eq.null; + hashed = Sexp.Map.empty; + equivs = Id.Map.empty; + origin = Id.Map.empty; + ranges = Id.Map.empty; + source = String.Map.empty; + rclass = Id.Map.empty; +} + +let nextid p = { + p with lastid = Id.next 
p.lastid +} + +let nexteq p = {p with lasteq = Eq.next p.lasteq} + +let rec repr s id = match Map.find s.rclass id with + | None -> id + | Some id' -> if id = id' then id else repr s id' + +let hashcons p sexp = + match Map.find p.hashed sexp with + | Some eq -> p,eq + | None -> + let p = nexteq p in + {p with hashed = Map.set p.hashed ~key:sexp ~data:p.lasteq}, + p.lasteq + +let unify p eq = + {p with equivs = Map.set p.equivs ~key:p.lastid ~data:eq} + + +let nopos = Parsexp.Positions.beginning_of_file +let norange = Parsexp.Positions.make_range_incl + ~start_pos:nopos + ~last_pos:nopos + +let getrange pos parents child = match pos with + | None -> norange + | Some pos -> + Parsexp.Positions.find_sub_sexp_in_list_phys + pos parents ~sub:child |> function + | None -> norange + | Some range -> range + +let add_range p data = + {p with ranges = Map.set p.ranges ~key:p.lastid ~data} + +let of_cst p sexps = + let newterm p r s = + let p = nextid p in + let p = add_range p r in + hashcons p (Cst.Forget.t s) in + let rec of_sexp p s = match s with + | Cst.Comment _ -> p,None + | Cst.Sexp (Atom {atom; loc; unescaped} as s) -> + let p,eq = newterm p loc s in + let x = Option.value unescaped ~default:atom in + unify p eq, Some {data=Atom x; id=p.lastid; eq} + | Cst.Sexp (List {elements=xs; loc} as s) -> + let p,data = List.fold ~init:(p,[]) ~f:(fun (p,xs) sexp -> + match of_sexp p sexp with + | p,None -> p,xs + | p,Some x -> p,(x::xs)) xs in + let p,eq = newterm p loc s in + unify p eq,Some {data = List (List.rev data); id=p.lastid; eq} in + let p,trees = List.fold sexps ~init:(p,[]) ~f:(fun (p,xs) x -> + match of_sexp p x with + | p,None -> p,xs + | p,Some x -> p,(x::xs)) in + p,List.rev trees + + +let of_sexps ?pos p sexps = + let getrange = getrange pos sexps in + let newterm p s = + let p = nextid p in + let p = add_range p (getrange s) in + hashcons p s in + let rec of_sexp p s = match s with + | Sexp.Atom x -> + let p,eq = newterm p s in + unify p eq,{data=Atom x; 
id=p.lastid; eq} + | Sexp.List xs -> + let p,data = List.fold ~init:(p,[]) ~f:(fun (p,xs) sexp -> + let p,x = of_sexp p sexp in + p,(x::xs)) xs in + let p,eq = newterm p s in + unify p eq,{data = List (List.rev data); id=p.lastid; eq} in + let p,trees = List.fold sexps ~init:(p,[]) ~f:(fun (p,xs) x -> + let p,x = of_sexp p x in + p,(x::xs)) in + p,List.rev trees + + +let add_origin origins origin trees = + let rec add origins token = + let origins = Map.set origins ~key:token.id ~data:origin in + match token.data with + | Atom _ -> origins + | List tokens -> List.fold tokens ~init:origins ~f:add in + List.fold ~init:origins ~f:add trees + +let load p filename = + let source = In_channel.read_all filename in + match Parsexp.Many_cst.parse_string source with + | Error err -> Error (Bad_sexp (filename,err)) + | Ok cst -> + let p,tree = of_cst p cst in + let origin = add_origin p.origin filename tree in + Ok { + p with + origin; + source = Map.set p.source ~key:filename ~data:tree + } + +let find p filename = Map.find p.source filename +let range p id = match Map.find p.ranges (repr p id) with + | None -> norange + | Some rng -> rng + +let filename p id = match Map.find p.origin (repr p id) with + | None -> "/unknown/" + | Some file -> file + +let loc p tree = Loc.{ + file = filename p tree; + range = range p tree; + } + +let has_loc p id = Map.mem p.origin (repr p id) + +let lastid s = s.lastid +let lasteq s = s.lasteq + +let fold p ~init ~f = Map.fold ~init p.source ~f:(fun ~key ~data user -> + f key data user) + +let derived p ~from id = + if Id.null = from then { + p with lastid = Id.max id p.lastid + } else { + p with + lastid = Id.max id p.lastid; + rclass = Map.set p.rclass ~key:id ~data:from; + } + +let pp_error ppf (Bad_sexp (filename,err)) = + Parsexp.Parse_error.report ppf ~filename err + +let rec sexp_of_tree = function + | {data=List xs} -> Sexp.List (List.map xs ~f:sexp_of_tree) + | {data=Atom x} -> Sexp.Atom x + +let pp_tree ppf t = + Sexp.pp_hum ppf 
(sexp_of_tree t) diff --git a/lib/bap_lisp/bap_lisp__source.mli b/lib/bap_lisp/bap_lisp__source.mli new file mode 100644 index 000000000..a3303cb4c --- /dev/null +++ b/lib/bap_lisp/bap_lisp__source.mli @@ -0,0 +1,87 @@ +(** A repository of s-expressions. + + The repository holds information about s-expressions loaded + from files. The s-expression representation is indexed, and with + each sub-tree of the s-expression we associate two indexes: the + identity index and the equality index. + + The identity index uniquely identifies a tree, and location + information is associated with each identity. The equality index + represents structural equality and is in fact the ordinal number + of the tree equivalence class in the quotient set induced by + structural equality. In other words, if two trees have the same + equality index, then they are structurally equal. + + The source repository implements hashconsing, i.e., all trees with + the same structure share the same memory regions. + + Note: despite the name, the module doesn't have any dependencies on + Primus Lisp, and works purely on the s-expression level. Later, we + will move it into a separate library, along with the index and + location modules. *) + + +module Index = Bap_lisp__index +module Loc = Bap_lisp__loc +module Id : Index.S +module Eq : Index.S + +type error +type t +type 'a indexed = ('a,Id.t,Eq.t) Index.interned +type tree = token indexed +and token = Atom of string | List of tree list + + + +(** [empty] is the empty source repository. *) +val empty : t + + +(** [load source filename] loads the source code from the given + [filename]. The source code should be a sequence of well-formed + s-expressions. + + The [filename] should be an explicit path. +*) +val load : t -> string -> (t,error) result + + +(** [find source filename] returns a list of trees loaded from a file + with the given [filename]. 
*) +val find : t -> string -> tree list option + + +(** [loc source id] returns a location information for the identity + with the provided [id]. + + If there is no such information, then a bogus location is + returned. *) +val loc : t -> Id.t -> Loc.t + +(** [has_loc source id] if the location information is associated + with the given [id] *) +val has_loc : t -> Id.t -> bool + +(** [filename source id] returns the name of a file from which an + identity with the given [id] is originating. + + If the identity is not known to the source code repository, then + a bogus filename is returned. *) +val filename : t -> Id.t -> string + + +(** [fold source ~init ~f] iterates over all files loaded into the + [source] repository. *) +val fold : t -> init:'a -> f:(string -> tree list -> 'a -> 'a) -> 'a + +val derived : t -> from:Id.t -> Id.t -> t + +val lastid : t -> Id.t + +val lasteq : t -> Eq.t + + +val pp_error : Format.formatter -> error -> unit + +val pp_tree : Format.formatter -> tree -> unit diff --git a/lib/bap_lisp/bap_lisp__type.ml b/lib/bap_lisp/bap_lisp__type.ml new file mode 100644 index 000000000..8b7484857 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__type.ml @@ -0,0 +1,123 @@ +open Core_kernel +open Bap_core_theory +open Bap_lisp__types +module Context = Bap_lisp__context +type context = Context.t + +type signature = { + args : typ list; + rest : typ option; + ret : typ; +} + +let symbol_size = 63 +let sort s = Type (Theory.Value.Sort.forget s) +let bool = sort Theory.Bool.t +let word n = sort (Theory.Bitv.define n) +let sym = Symbol +let var n = Name n + +type read_error = Empty | Not_sexp | Bad_sort + +(* let rec parse_sort : Sexp.t -> (_,read_error) result = function + * | Atom "Bool" -> Ok Bool + * | Atom s -> Ok (Cons (s,[])) + * | List (Atom name :: ps) -> + * Result.map (parse_params ps) ~f:(fun ps -> Sort.Cons (name,ps)) + * | List _ -> Error Bad_sort + * and parse_params ps = Result.all (List.map ps ~f:parse_param) + * and parse_param = function + * |
Atom x when Char.is_digit x.[0] -> Ok (Index (int_of_string x)) + * | x -> Result.map (parse_sort x) ~f:(fun x -> Sort.Sort x) *) + + +let read s = + if String.length s < 1 then Error Empty + else + if Char.is_lowercase s.[0] + then Ok (Name s) + else match Sexp.of_string s with + | exception _ -> Error Not_sexp + | s -> + failwith "Sort parsing is not implemented yet" +(* Result.map (parse_sort s) ~f:(fun s -> Type s) *) + +let any = Any + + +let signature ?rest args ret = { + ret; + rest; + args; +} + +module Check = struct + let sort typ s = match typ with + | Any | Name _ -> true + | Symbol -> false + | Type s' -> + let s = Theory.Value.Sort.forget s in + Theory.Value.Sort.Top.compare s s' = 0 +end + +module Spec = struct + (* module Type = struct + * include Lisp.Program.Type + * type t = arch -> Lisp.Type.t + * type signature = arch -> Lisp.Type.signature + * + * type parameters = [ + * | `All of t + * | `Gen of t list * t + * | `Tuple of t list + * ] + * + * module Spec = struct + * let any _ = Lisp.Type.any + * let var s _ = Lisp.Type.var s + * let sym _ = Lisp.Type.sym + * let word n _ = Lisp.Type.word n + * let int arch = + * Lisp.Type.word (Size.in_bits (Arch.addr_size arch)) + * let bool = word 1 + * let byte = word 8 + * let a : t = var "a" + * let b : t = var "b" + * let c : t = var "c" + * let d : t = var "d" + * + * let tuple ts = `Tuple ts + * let unit = tuple [] + * let one t = tuple [t] + * let all t = `All t + * + * let (//) : [`Tuple of t list] -> [`All of t] -> parameters = + * fun (`Tuple ts) (`All t) -> `Gen (ts,t) + * + * let (@->) (dom : [< parameters]) (cod : t) : signature = + * let args,rest = match dom with + * | `All t -> [],Some t + * | `Tuple ts -> ts,None + * | `Gen (ts,t) -> ts, Some t in + * fun arch -> + * let args = List.map args ~f:(fun t -> t arch) in + * let cod = cod arch in + * let rest = Option.map rest ~f:(fun t -> t arch) in + * Lisp.Type.signature args ?rest cod + * + * end + * end *) + +end + +let pp ppf t = match t 
with + | Any | Symbol -> () + | Name s -> Format.fprintf ppf "%s" s + | Type _ -> () + + +include Comparable.Make(struct + type t = typ [@@deriving sexp, compare] + end) + +type t = typ [@@deriving sexp,compare] diff --git a/lib/bap_lisp/bap_lisp__type.mli b/lib/bap_lisp/bap_lisp__type.mli new file mode 100644 index 000000000..edf765fc0 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__type.mli @@ -0,0 +1,66 @@ +open Core_kernel +open Bap_core_theory + +open Bap_lisp__types + +module Context = Bap_lisp__context + +type t = typ [@@deriving compare, sexp] +type context = Context.t +type signature = { + args : typ list; + rest : typ option; + ret : typ; +} + +type read_error = Empty | Not_sexp | Bad_sort + + +val symbol_size : int +val read : string -> (t,read_error) result +val bool : t +val word : int -> t +val any : t +val sym : t +val var : string -> t + +val signature : ?rest:t -> t list -> t -> signature + + +val pp : Format.formatter -> t -> unit + +module Check : sig + val sort : t -> 'a Theory.Value.sort -> bool +end + + +(* module Spec : sig + * type t + * + * type parameters = [ + * | `All of t + * | `Gen of t list * t + * | `Tuple of t list + * ] + * + * val any : t + * val var : string -> t + * val sym : t + * val int : t + * val bool : t + * val byte : t + * val word : int -> t + * val a : t + * val b : t + * val c : t + * val d : t + * + * val tuple : t list -> [`Tuple of t list] + * val all : t -> [`All of t] + * val one : t -> [`Tuple of t list] + * val unit : [`Tuple of t list] + * val (//) : [`Tuple of t list] -> [`All of t] -> parameters + * val (@->) : [< parameters] -> t -> signature + * end *) + +include Comparable.S_plain with type t := t diff --git a/lib/bap_lisp/bap_lisp__types.ml b/lib/bap_lisp/bap_lisp__types.ml new file mode 100644 index 000000000..bf67a3c25 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__types.ml @@ -0,0 +1,57 @@ +open Core_kernel +open Bap_core_theory + +module Index = Bap_lisp__index +module Loc = Bap_lisp__loc +module Source = 
Bap_lisp__source + +module Id = Source.Id +module Eq = Source.Eq +type ('a,'i,'e) interned = ('a,'i,'e) Index.interned = { + data : 'a; + id : 'i; + eq : 'e; +} [@@deriving compare, sexp] + +type 'a indexed = ('a,Id.t,Eq.t) interned [@@deriving compare] + +type typ = + | Any + | Symbol + | Name of string + | Type of Theory.Value.Sort.Top.t [@@deriving sexp, compare] +type 'a term = {exp : 'a; typ : typ} [@@deriving compare] +type word = Bitvec.t term indexed [@@deriving compare] +type var = string term indexed [@@deriving compare] +type sym = string indexed [@@deriving compare] +type loc = Loc.t + +type error = .. +exception Fail of error + + +type tree = Source.tree +type token = Source.token = + | Atom of string + | List of tree list + + +type ast = exp indexed +and exp = + | Int of word + | Var of var + | Sym of sym + | Ite of ast * ast * ast + | Let of var * ast * ast + | App of binding * ast list + | Seq of ast list + | Set of var * ast + | Rep of ast * ast + | Msg of fmt list * ast list + | Err of string +and fmt = + | Lit of string + | Pos of int +and binding = + | Dynamic of string + | Static of var list * ast diff --git a/lib/bap_lisp/bap_lisp__var.ml b/lib/bap_lisp/bap_lisp__var.ml new file mode 100644 index 000000000..7320675a7 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__var.ml @@ -0,0 +1,32 @@ +open Core_kernel +open Bap_lisp__types +open Format + +module Type = Bap_lisp__type + +let to_string = function + | {data={exp;typ = Any}} -> exp + | {data={exp;typ}} -> asprintf "%s:%a" exp Type.pp typ + +let sexp_of_var v = Sexp.Atom (to_string v) + +type read_error = Empty | Not_a_var | Bad_format + | Bad_type of Type.read_error + +let read id eq = function + | "" -> Error Empty + | x when Char.is_digit x.[0] || x.[0] = '\'' || x.[0] = '"' -> + Error Not_a_var + | x -> match String.split x ~on:':' with + | [] -> assert false + | _::_::_::_ -> Error Bad_format + | [x] -> Ok {data={exp=x; typ=Any}; id; eq} + | [x;sz] -> match Type.read sz with + | Error e -> 
Error (Bad_type e) + | Ok typ -> Ok {data={exp=x; typ}; id; eq} + +include Comparable.Make_plain(struct + type t = var [@@deriving compare,sexp_of] + end) + +type t = var [@@deriving compare,sexp_of] diff --git a/lib/bap_lisp/bap_lisp__var.mli b/lib/bap_lisp/bap_lisp__var.mli new file mode 100644 index 000000000..8ee3f73e5 --- /dev/null +++ b/lib/bap_lisp/bap_lisp__var.mli @@ -0,0 +1,15 @@ +open Core_kernel +open Bap_lisp__types + +module Type = Bap_lisp__type + +type t = var [@@deriving compare, sexp_of] +include Comparable.S_plain with type t := t +val to_string : t -> string + + +type read_error = Empty | Not_a_var | Bad_format + | Bad_type of Type.read_error + + +val read : Id.t -> Eq.t -> string -> (t,read_error) result diff --git a/lib/bap_lisp/bap_lisp__word.ml b/lib/bap_lisp/bap_lisp__word.ml new file mode 100644 index 000000000..b70499dae --- /dev/null +++ b/lib/bap_lisp/bap_lisp__word.ml @@ -0,0 +1,79 @@ +open Core_kernel +open Bap_lisp__types + +module Type = Bap_lisp__type + +type t = word [@@deriving compare] + +type read_error = Empty | Not_an_int | Unclosed | Bad_literal + | Bad_type of Type.read_error + +let char_of_string s = + try Ok Char.(Bitvec.M8.int @@ to_int @@ of_string s) + with _ -> Error Bad_literal + +let read_char str = char_of_string (String.subo ~pos:1 str) + +let read_int str = + try Ok (Bitvec.of_string (String.strip str)) + with _ -> Error Bad_literal + +let char id eq s = + Result.map (read_char s) ~f:(fun exp -> + {data={exp; typ = Type.word 8}; id; eq}) + +let int ?typ id eq s = + Result.bind (read_int s) ~f:(fun exp -> + match typ with + | None -> Ok {data={exp;typ=Any}; id; eq} + | Some s -> match Type.read s with + | Error e -> Error (Bad_type e) + | Ok typ -> Ok {data={exp;typ}; id; eq}) + +let base x = + if String.length x < 3 then 10 + else + let i = if x.[0] = '-' then 1 else 0 in + match x.[i], x.[i+1] with + | '0',('b'|'B') -> 2 + | '0',('o'|'O') -> 8 + | '0',('x'|'X') -> 16 + | _ -> 10 + +let minimum_bitwidth ~base 
digits = + Float.to_int @@ + Float.round_up @@ + float digits *. (log (float base) /. log 2.) + +let is_hex = function + | '0'..'9' | 'a'..'f' | 'A'..'F' -> true + | _ -> false + +let infer_width x = + let base = base x in + let sign_bit = if x.[0] = '-' then 1 else 0 in + let is_digit = if base = 16 then is_hex else Char.is_digit in + let len = String.count x ~f:is_digit in + let return n = String.concat [ + x; ":"; + string_of_int (minimum_bitwidth ~base n + sign_bit) + ] in + if base = 10 then return len + else return (len - 1) (* for the base designator *) + +let read id eq x = + if String.is_empty x then Error Empty + else if x.[0] = '?' + then char id eq x + else if Char.is_digit x.[0] || + String.length x > 1 && + Char.is_digit x.[1] && + x.[0] = '-' + then match String.split x ~on:':' with + | [x;typ] -> int ~typ id eq x + | [x] -> int id eq @@ infer_width x + | _ -> Error Bad_literal + else Error Not_an_int + +let sexp_of_word {data={exp}} = Sexp.Atom (Bitvec.to_string exp) +let sexp_of_t = sexp_of_word diff --git a/lib/bap_lisp/bap_lisp__word.mli b/lib/bap_lisp/bap_lisp__word.mli new file mode 100644 index 000000000..d5ab9f66f --- /dev/null +++ b/lib/bap_lisp/bap_lisp__word.mli @@ -0,0 +1,10 @@ +open Core_kernel +open Bap_lisp__types + +module Type = Bap_lisp__type + +type t = word [@@deriving compare, sexp_of] +type read_error = Empty | Not_an_int | Unclosed | Bad_literal + | Bad_type of Type.read_error + +val read : Id.t -> Eq.t -> string -> (t,read_error) result diff --git a/lib/bap_piqi/bir_piqi.ml b/lib/bap_piqi/bir_piqi.ml index b39ad27c3..b0f94edae 100644 --- a/lib/bap_piqi/bir_piqi.ml +++ b/lib/bap_piqi/bir_piqi.ml @@ -62,7 +62,7 @@ module Get = struct end module Put = struct - let tid = Tid.name + let tid = Tid.to_string let exp = Bil_piqi.piqi_of_exp let var = Bil_piqi.piqi_of_var diff --git a/lib/bap_primus/bap_primus.mli b/lib/bap_primus/bap_primus.mli index 575e87241..215e1229a 100644 --- a/lib/bap_primus/bap_primus.mli +++ 
b/lib/bap_primus/bap_primus.mli @@ -36,9 +36,9 @@ module Std : sig built from basic building blocks, with minimal coupling between them. The central component is the Interpreter itself. It evaluates a program and interacts with three other components: - - Linker - - Env - - Memory + - Linker + - Env + - Memory The Linker is responsible for linking code into the program abstraction. The [Env] component defines the environment @@ -938,6 +938,18 @@ module Std : sig @since 1.5 *) val segfault : addr observation + (** [cfi_violation x] occurs when the CFI is not preserved. + The control flow integrity (CFI) is violated when a call + doesn't return to an expected place. This might be an + indicator of malicious code or an improper control flow + graph. + + After the observation is made the [cfi_violation] trap is + signaled, which could be handled via the + [cfi_violation_handler]. + @since 1.7 *) + val cfi_violation : addr observation + (** is raised when a computation is halted *) type exn += Halt @@ -958,7 +970,7 @@ module Std : sig observation. @since 1.5 - *) + *) val pagefault_handler : string (** [division_by_zero_hanlder] is a trap handler for @@ -968,9 +980,18 @@ module Std : sig undefined. @since 1.5 - *) + *) val division_by_zero_handler : string + + + (** [cfi_violation_handler] is the name of a trap handler for the + [Cfi_violation] exception. If it is linked into the machine, + then it will be invoked when the cfi-violation trap is signaled. + If it returns normally, then the result of the faulty operation is + undefined. *) + val cfi_violation_handler : string + (** Make(Machine) makes an interpreter that computes in the given [Machine]. *) module Make (Machine : Machine.S) : sig @@ -1070,7 +1091,7 @@ module Std : sig | `tid of tid | `addr of addr | `symbol of string - ] [@@deriving bin_io, compare, sexp] + ] [@@deriving bin_io, compare, sexp] (** Call tracing. @@ -1088,7 +1109,7 @@ module Std : sig observations.
However, the Primus Lisp Interpreter provides call observations only when an externally visible function is called, e.g., malloc, free. - *) + *) module Trace : sig (** occurs when a subroutine is called. @@ -1099,7 +1120,7 @@ module Std : sig Example, (call (malloc 4)) - *) + *) val call : (string * value list) observation (** occurs just before a subroutine returns. @@ -1113,7 +1134,7 @@ module Std : sig Example, (call-return (malloc 4 0xDEADBEEF)) - *) + *) val return : (string * value list) observation (** occurs when an externally linked primus stub is called. @@ -1131,7 +1152,7 @@ module Std : sig Use [Machine.Observation.make] function, where [Machine] is a module implementing [Machine.S] interface, to provide observations. - *) + *) (** the statement that makes [call] observations. *) val call_entered : (string * value list) statement @@ -1165,7 +1186,7 @@ module Std : sig @since 1.5 - *) + *) val unresolved_handler : string module Name : Regular.S with type t = name @@ -1178,9 +1199,9 @@ module Std : sig machine and performs a computation using this machine.*) module type Code = functor (Machine : Machine.S) -> sig - (** [exec] computes the code. *) - val exec : unit Machine.t - end + (** [exec] computes the code. *) + val exec : unit Machine.t + end (** code representation *) @@ -1235,7 +1256,7 @@ module Std : sig with the given [name]. @since 1.5.0 - *) + *) val resolve_symbol : name -> string option m @@ -1243,7 +1264,7 @@ module Std : sig with the given [name]. @since 1.5.0 - *) + *) val resolve_tid : name -> tid option m @@ -1385,6 +1406,11 @@ module Std : sig (** [next iter] switches the internal state of [iter] to the next state and returns the current value *) val next : t -> int Machine.t + + + (** [word iter bitwidth] constructs a word of the given [bitwidth], + with bytes obtained from consecutive calls to [next].*) + val word : t -> int -> word Machine.t end end @@ -1429,11 +1455,54 @@ module Std : sig - (** Virtual memory.
+ (** Machine Memory. - The virtual memory is a byte addressable machine memory.*) + Provides storage facilities. A machine can have multiple memories, + e.g., RAM, ROM, HDD, cache, register files, etc. They are all accessed + via the unified memory interface using [get] and [set] primitives which + read and store bytes from the current memory. The current memory could + be switched with the [switch] operation and its descriptor could be + queried using the [memory] operation. + + Each memory device has an independent address space and address bus width + (which could be different from the virtual memory address size). + Each memory could be segmented and can have its own TLB, which is usually + implemented via the [pagefault] handlers. + *) module Memory : sig + + (** abstract memory descriptor, see [Descriptor.t] *) + type memory + + (** Abstract memory descriptor. + + A descriptor uniquely identifies a memory device by its name. + In addition, it holds meta information about memory address + and data bus sizes. + + *) + module Descriptor : sig + type t = memory [@@deriving compare, sexp_of] + + (** [create ~addr_size:m ~data_size:n name] constructs a + memory descriptor for a storage [name] with [m] lines in + the address bus, and [n] bits in data. *) + val create : addr_size:int -> data_size:int -> string -> memory + + + (** [unknown ~addr_size:m ~data_size:n] constructs a + memory descriptor for an arbitrary storage with [m] lines in + the address bus, and [n] bits in data. *) + val unknown : addr_size:int -> data_size:int -> memory + + + (** [name memory] returns [memory] identifier. *) + val name : memory -> string + + include Comparable.S with type t := t + end + (** occurs when a memory operation for the given addr cannot be satisfied. *) type exn += Pagefault of addr @@ -1443,11 +1512,20 @@ module Std : sig module Make(Machine : Machine.S) : sig + (** [switch memory] switches the memory module to [memory].
+ + All consecutive operations until the next switch will affect + only this memory. *) + val switch : memory -> unit Machine.t + + + (** [memory] a descriptor of currently active [memory] *) + val memory : memory Machine.t + (** [get a] loads a byte from the address [a]. raises the [Pagefault] machine exception if [a] is not mapped. - - *) + *) val get : addr -> value Machine.t @@ -1488,12 +1566,15 @@ module Std : sig produce values generated by a generator (defaults to a [Generator.Random.Seeded.byte]). + If [init] is provided then the region is initialized. + An attempt to write to a readonly segment, or an attempt to execute non-executable segment will generate a segmentation fault. (TODO: provide more fine-granular traps).*) val allocate : ?readonly:bool -> ?executable:bool -> + ?init:(addr -> word Machine.t) -> ?generator:Generator.t -> addr -> int -> unit Machine.t @@ -2321,11 +2402,11 @@ ident ::= ?any atom that is not recognized as a ? {[ Category 1: - - Element1 Name, Element1 Description; - - Element2 Name, Element2 Description; - ... - Category2: - - ... + - Element1 Name, Element1 Description; + - Element2 Name, Element2 Description; + ... + Category2: + - ... ]} All entries are sorted in alphabetic order. 
diff --git a/lib/bap_primus/bap_primus_env.ml b/lib/bap_primus/bap_primus_env.ml index 8275be32a..d4be0c56b 100644 --- a/lib/bap_primus/bap_primus_env.ml +++ b/lib/bap_primus/bap_primus_env.ml @@ -69,7 +69,6 @@ let inspect_environment {values;random} = | None -> assert false) in Sexp.List bindings -let word = Word.of_int ~width:8 module Make(Machine : Machine) = struct @@ -88,15 +87,6 @@ module Make(Machine : Machine) = struct s with values = Map.set s.values ~key:var ~data:x }) - let gen_word gen width = - assert (width > 0); - let rec next x = - if Word.bitwidth x >= width - then Machine.return (Word.extract_exn ~hi:(width-1) x) - else Generator.next gen >>= fun y -> - next (Word.concat x (word y)) in - Generator.next gen >>| word >>= next - let null = Machine.get () >>| Project.arch >>| Arch.addr_size >>= fun s -> Value.zero (Size.in_bits s) @@ -105,13 +95,12 @@ module Make(Machine : Machine) = struct match Map.find t.values var with | Some res -> Machine.return res | None -> match Var.typ var with - | Type.Mem (_,_) -> null + | Type.Mem (_,_) | Type.Unk -> null | Type.Imm width -> match Map.find t.random var with | None -> Machine.raise (Undefined_var var) | Some gen -> - gen_word gen width >>= Value.of_word >>= fun x -> - set var x >>= fun () -> - !!x + Generator.word gen width >>= Value.of_word >>= fun x -> + set var x >>| fun () -> x let has var = Machine.Local.get state >>| fun t -> diff --git a/lib/bap_primus/bap_primus_generator.ml b/lib/bap_primus/bap_primus_generator.ml index af406c9b3..a1075fca1 100644 --- a/lib/bap_primus/bap_primus_generator.ml +++ b/lib/bap_primus/bap_primus_generator.ml @@ -156,4 +156,13 @@ module Make(Machine : Machine) = struct Machine.current () >>= fun id -> call key (init (Machine.Id.hash id)) | Some iter -> call key iter + + let word gen width = + let word = Word.of_int ~width:8 in + assert (width > 0); + let rec loop x = + if Word.bitwidth x >= width + then Machine.return (Word.extract_exn ~hi:(width-1) x) + else next gen 
>>= fun y -> loop (Word.concat x (word y)) in + next gen >>| word >>= loop end diff --git a/lib/bap_primus/bap_primus_generator.mli b/lib/bap_primus/bap_primus_generator.mli index 6740f4ef1..377285da1 100644 --- a/lib/bap_primus/bap_primus_generator.mli +++ b/lib/bap_primus/bap_primus_generator.mli @@ -26,4 +26,5 @@ end module Make( Machine : Machine) : sig val next : t -> int Machine.t + val word : t -> int -> word Machine.t end diff --git a/lib/bap_primus/bap_primus_interpreter.ml b/lib/bap_primus/bap_primus_interpreter.ml index 8b2db20ff..5fb0f3fe6 100644 --- a/lib/bap_primus/bap_primus_interpreter.ml +++ b/lib/bap_primus/bap_primus_interpreter.ml @@ -5,12 +5,25 @@ open Bap_c.Std open Format open Bap_primus_types -module Observation = Bap_primus_observation -module State = Bap_primus_state -module Linker = Bap_primus_linker +module Primus = struct + module Env = Bap_primus_env + module Linker = Bap_primus_linker + module Machine = Bap_primus_machine + module Memory = Bap_primus_memory + module Observation = Bap_primus_observation + module State = Bap_primus_state + module Value = Bap_primus_value +end + +open Primus + open Bap_primus_sexp +let memory_switch,switching_memory = + let inspect = Primus.Memory.Descriptor.sexp_of_t in + Observation.provide ~inspect "memory-switch" + let enter_term, term_entered = Observation.provide ~inspect:sexp_of_tid "enter-term" let leave_term, term_left = @@ -64,6 +77,9 @@ let halting,will_halt = let division_by_zero,will_divide_by_zero = Observation.provide ~inspect:sexp_of_unit "division-by-zero" +let cfi_violation,cfi_will_diverge = + Observation.provide ~inspect:sexp_of_word "cfi-violation" + let segfault, will_segfault = Observation.provide ~inspect:sexp_of_word "segfault" @@ -172,9 +188,16 @@ let sexp_of_name = function | `tid tid -> Sexp.Atom (Tid.name tid) | `addr addr -> Sexp.Atom (Addr.string_of_value addr) +type scope = { + stack : (var * value) list; + bound : int Var.Map.t; +} + type state = { addr : addr; curr : 
pos; + lets : scope; (* lexical context *) + prompts : tid list; } let sexp_of_state {curr} = @@ -184,7 +207,12 @@ let null proj = let size = Arch.addr_size (Project.arch proj) in Addr.zero (Size.in_bits size) -let state = Bap_primus_machine.State.declare +let empty_scope = { + stack = []; + bound = Var.Map.empty; +} + +let state = Primus.Machine.State.declare ~uuid:"14a17161-173b-46da-9e95-7819104cc220" ~name:"interpreter" ~inspect:sexp_of_state @@ -193,13 +221,17 @@ let state = Bap_primus_machine.State.declare Pos.{ addr = null proj; curr = Top {me=prog; up=Nil}; + lets = empty_scope; + prompts = []; }) type exn += Halt type exn += Division_by_zero type exn += Segmentation_fault of addr +type exn += Cfi_violation of addr type exn += Runtime_error of string + let () = Exn.add_printer (function | Runtime_error msg -> @@ -208,21 +240,23 @@ let () = | Segmentation_fault x -> Some (asprintf "Segmentation fault at %a" Addr.pp_hex x) | Division_by_zero -> Some "Division by zero" + | Cfi_violation where -> + Some (asprintf "CFI violation at %a" Addr.pp_hex where) | _ -> None) let division_by_zero_handler = "__primus_division_by_zero_handler" let pagefault_handler = "__primus_pagefault_handler" - +let cfi_violation_handler = "__primus_cfi_violation_handler" module Make (Machine : Machine) = struct open Machine.Syntax module Eval = Eval.Make(Machine) - module Memory = Bap_primus_memory.Make(Machine) - module Env = Bap_primus_env.Make(Machine) + module Memory = Primus.Memory.Make(Machine) + module Env = Primus.Env.Make(Machine) module Code = Linker.Make(Machine) - module Value = Bap_primus_value.Make(Machine) + module Value = Primus.Value.Make(Machine) type 'a m = 'a Machine.t @@ -247,12 +281,12 @@ module Make (Machine : Machine) = struct Env.set v x >>= fun () -> !!on_written (v,x) + let get v = !!on_reading v >>= fun () -> Env.get v >>= fun r -> !!on_read (v,r) >>| fun () -> r - let call_when_provided name = Code.is_linked (`symbol name) >>= fun provided -> if provided 
then Code.exec (`symbol name) >>| fun () -> true @@ -316,10 +350,60 @@ module Make (Machine : Machine) = struct value (if Word.is_one cond.value then yes.value else no.value) >>= fun r -> !!on_ite ((cond, yes, no), r) >>| fun () -> r + + let get_lexical scope v = + List.Assoc.find_exn scope ~equal:Var.equal v + + let update_lexical f = + Machine.Local.update state ~f:(fun s -> { + s with lets = f s.lets + }) + + let push_lexical v x = update_lexical @@ fun s -> { + stack = (v,x) :: s.stack; + bound = Map.update s.bound v ~f:(function + | None -> 1 + | Some n -> n + 1) + } + + let pop_lexical v = update_lexical @@ fun s -> { + stack = List.tl_exn s.stack; + bound = Map.change s.bound v ~f:(function + | None -> None + | Some 1 -> None + | Some n -> Some (n-1)) + } + + let memory typ name = match typ with + | Type.Imm _ | Type.Unk as t -> failwithf "type error - load from %s:%a" + name Type.pps t () + | Type.Mem (ks,vs) -> + let ks = Size.in_bits ks and vs = Size.in_bits vs in + Primus.Memory.Descriptor.create ks vs name + + let rec memory_of_storage : exp -> _ = function + | Var v -> memory (Var.typ v) (Var.name v) + | Unknown (_,typ) -> memory typ "unknown" + | Store (s,_,_,_,_) + | Ite (_,s,_) + | Let (_,_,s) -> memory_of_storage s + | x -> failwithf "expression `%a' is not a storage" Exp.pps x () + + + let switch_memory m = + let m = memory_of_storage m in + Memory.memory >>= fun m' -> + if Primus.Memory.Descriptor.equal m m' + then Machine.return () + else + !!switching_memory m >>= fun () -> + Memory.switch m + + let rec eval_exp x = let eval = function - | Bil.Load (Bil.Var _, a,_,`r8) -> eval_load a - | Bil.Store (m,a,x,_,`r8) -> eval_store m a x + | Bil.Load (m,a,_,_) -> eval_load m a + | Bil.Store (m,a,x,_,_) -> eval_store m a x | Bil.BinOp (op, x, y) -> eval_binop op x y | Bil.UnOp (op,x) -> eval_unop op x | Bil.Var v -> eval_var v @@ -329,22 +413,30 @@ module Make (Machine : Machine) = struct | Bil.Extract (hi,lo,x) -> eval_extract hi lo x | Bil.Concat (x,y) ->
eval_concat x y | Bil.Ite (cond, yes, no) -> eval_ite cond yes no - | exp -> - invalid_argf "precondition failed: denormalized exp: %s" - (Exp.to_string exp) () in + | Bil.Let (v,x,y) -> eval_let v x y in !!exp_entered x >>= fun () -> eval x >>= fun r -> !!exp_left x >>| fun () -> r - and eval_load a = eval_exp a >>= load_byte + and eval_let v x y = + eval_exp x >>= fun x -> + push_lexical v x >>= fun () -> + eval_exp y >>= fun r -> + pop_lexical v >>| fun () -> r and eval_ite cond yes no = eval_exp yes >>= fun yes -> eval_exp no >>= fun no -> eval_exp cond >>= fun cond -> ite cond yes no - and eval_store m a x = + and eval_load m a = + eval_exp a >>= fun a -> + switch_memory m >>= fun () -> eval_storage m >>= fun () -> + load_byte a + and eval_store m a x = eval_exp a >>= fun a -> eval_exp x >>= fun x -> + switch_memory m >>= fun () -> + eval_storage m >>= fun () -> store_byte a x >>| fun () -> a and eval_binop op x y = eval_exp x >>= fun x -> @@ -353,7 +445,11 @@ module Make (Machine : Machine) = struct and eval_unop op x = eval_exp x >>= fun x -> unop op x - and eval_var = get + and eval_var v = + Machine.Local.get state >>= fun {lets} -> + if Map.mem lets.bound v + then Machine.return (get_lexical lets.stack v) + else get v and eval_int = const and eval_cast t s x = eval_exp x >>= fun x -> @@ -476,26 +572,70 @@ module Make (Machine : Machine) = struct Value.of_word addr >>= fun addr -> !!will_jump (cond,addr) + let exec_to_prompt dst = + Machine.Local.get state >>= function + | {prompts = p :: _} when Tid.equal p dst -> + Machine.return () + | _ -> Code.exec (`tid dst) + let label cond : label -> _ = function - | Direct t -> - will_jump_to_tid cond t >>= fun () -> - Code.exec (`tid t) + | Direct dst -> + will_jump_to_tid cond dst >>= fun () -> + exec_to_prompt dst | Indirect x -> eval_exp x >>= fun ({value} as dst) -> !!will_jump (cond,dst) >>= fun () -> - Code.exec (`addr value) + Code.resolve_tid (`addr value) >>= function + | None -> Code.exec (`addr value) 
+ | Some dst -> exec_to_prompt dst + - let call cond c = - label cond (Call.target c) >>= fun () -> - match Call.return c with - | Some t -> label cond t - | None -> failf "a non-return call returned" () + let resolve_return call = match Call.return call with + | None -> Machine.return None + | Some (Direct dst) -> Machine.return (Some dst) + | Some (Indirect dst) -> + eval_exp dst >>= fun {value=dst} -> + Code.resolve_tid (`addr dst) + + let push_prompt dst = match dst with + | None -> Machine.return () + | Some dst -> + Machine.Local.update state ~f:(fun s -> + {s with prompts = dst :: s.prompts}) + + let pop_prompt = + Machine.Local.update state ~f:(function + | {prompts=[]} as s -> s + | {prompts=_::prompts} as s -> {s with prompts}) + + let trap_cfi_violation callsite = + !!cfi_will_diverge callsite >>= fun () -> + call_when_provided cfi_violation_handler >>= function + | true -> Machine.return () + | false -> Machine.raise (Cfi_violation callsite) + + let call cond call = + Machine.Local.get state >>= fun {addr=callsite} -> + resolve_return call >>= fun ret -> + push_prompt ret >>= fun () -> + label cond (Call.target call) >>= fun () -> + Machine.Local.get state >>= function + | {prompts=[]} -> trap_cfi_violation callsite + | {prompts=p::_} -> match ret with + | None -> Machine.return () + | Some p' when Tid.(p <> p') -> + trap_cfi_violation callsite >>= fun () -> + pop_prompt >>= fun () -> + label cond (Direct p') + | Some p -> + pop_prompt >>= fun () -> + label cond (Direct p) let goto cond c = label cond c let interrupt n = !!will_interrupt n let jump cond t = match Jmp.kind t with - | Ret _ -> Machine.return () (* return from sub *) + | Ret dst -> label cond dst | Call c -> call cond c | Goto l -> goto cond l | Int (n,r) -> diff --git a/lib/bap_primus/bap_primus_interpreter.mli b/lib/bap_primus/bap_primus_interpreter.mli index 3a4b0d604..b70915971 100644 --- a/lib/bap_primus/bap_primus_interpreter.mli +++ b/lib/bap_primus/bap_primus_interpreter.mli @@ 
-8,6 +8,7 @@ val interrupt : int observation val division_by_zero : unit observation val segfault : addr observation val pagefault : addr observation +val cfi_violation : addr observation val loading : value observation val loaded : (value * value) observation @@ -60,9 +61,11 @@ val leave_jmp : jmp term observation type exn += Halt type exn += Division_by_zero type exn += Segmentation_fault of addr +type exn += Cfi_violation of addr val division_by_zero_handler : string val pagefault_handler : string +val cfi_violation_handler : string module Make (Machine : Machine) : sig diff --git a/lib/bap_primus/bap_primus_memory.ml b/lib/bap_primus/bap_primus_memory.ml index 7c4e49369..bcfe6575c 100644 --- a/lib/bap_primus/bap_primus_memory.ml +++ b/lib/bap_primus/bap_primus_memory.ml @@ -26,11 +26,43 @@ type region = type perms = {readonly : bool; executable : bool} type layer = {mem : region; perms : perms} -type t = { +type memory = { + name : string; + addr : int; + size : int; +} [@@deriving bin_io, sexp] + +let compare_memory {name=x} {name=y} = String.compare x y + +module Descriptor = struct + type t = memory [@@deriving bin_io, compare, sexp] + let create ~addr_size ~data_size name = { + addr = addr_size; + size = data_size; + name + } + + let unknown ~addr_size ~data_size = + create addr_size data_size "unknown" + + let name d = d.name + + + include Comparable.Make(struct + type t = memory [@@deriving bin_io, compare, sexp] + end) +end + +type state = { values : value Addr.Map.t; layers : layer list; } +type t = { + curr : Descriptor.t; + mems : state Descriptor.Map.t +} + let zero = Word.of_int ~width:8 0 let sexp_of_word w = Sexp.Atom (asprintf "%a" Word.pp_hex w) @@ -60,7 +92,7 @@ let sexp_of_layer {mem; perms={readonly; executable}} = ] |> String.concat ~sep:"" in Sexp.(List [sexp_of_mem mem; Atom flags]) -let inspect_memory {values; layers} = +let inspect_state {values; layers} = let values = Map.to_sequence values |> Seq.map ~f:(fun (key,value) -> Sexp.(List 
[sexp_of_word key; sexp_of_value value])) |> @@ -71,11 +103,31 @@ let inspect_memory {values; layers} = List [Atom "layers"; List layers]; ]) +let inspect_memory {curr; mems} = Sexp.List [ + Sexp.Atom curr.name; + Descriptor.Map.sexp_of_t inspect_state mems; + ] + + +let virtual_memory arch = + let module Target = (val target_of_arch arch) in + let mem = Target.CPU.mem in + match Var.typ mem with + | Type.Imm _ as t -> + invalid_argf "The CPU.mem variable %a:%a is not a storage" + Var.pps mem Type.pps t () + | Type.Mem (ks,vs) -> + let ks = Size.in_bits ks and vs = Size.in_bits vs in + Descriptor.create ks vs (Var.name mem) + + let state = Bap_primus_machine.State.declare ~uuid:"4b94186d-3ae9-48e0-8a93-8c83c747bdbb" ~inspect:inspect_memory - ~name:"memory" - (fun _ -> {values = Addr.Map.empty; layers = []}) + ~name:"memory" @@ fun p -> { + mems = Descriptor.Map.empty; + curr = virtual_memory (Project.arch p); + } let inside {base;len} addr = let high = Word.(base ++ len) in @@ -89,6 +141,10 @@ let find_layer addr = List.find ~f:(function let is_mapped addr {layers} = find_layer addr layers <> None +let empty_state = { + values = Addr.Map.empty; + layers = []; +} module Make(Machine : Machine) = struct open Machine.Syntax @@ -97,33 +153,96 @@ module Make(Machine : Machine) = struct module Value = Bap_primus_value.Make(Machine) let (!!) 
= Machine.Observation.make + + let memory = + Machine.Local.get state >>| fun s -> s.curr + + let switch curr = + Machine.Local.update state ~f:(fun s -> {s with curr}) + + let get_curr = + Machine.Local.get state >>| fun {curr; mems} -> + match Map.find mems curr with + | None -> empty_state + | Some s -> s + + let put_curr mem = + Machine.Local.get state >>= fun {curr; mems} -> + Machine.Local.put state { + curr; + mems = Map.set mems ~key:curr ~data:mem + } + let update state f = Machine.Local.get state >>= fun s -> - Machine.Local.put state (f s) + (* a missing entry means [empty_state], as in [get_curr] *) + let m = match Map.find s.mems s.curr with None -> empty_state | Some m -> m in + Machine.Local.put state { + s with + mems = Map.set s.mems ~key:s.curr ~data:(f m) + } let pagefault addr = Machine.raise (Pagefault addr) + let read_small mem addr size = + assert (size < 8 && size > 0); + let addr_size = Addr.bitwidth addr in + (* to address {n 2^m} bits we need {m+log(m)} addr space, + since {n < 8} (it is the word size in bits), we just add 3.*) + let width = addr_size + 3 in + let addr = Word.extract_exn ~hi:(width-1) addr in + let off = Addr.((addr - Memory.min_addr mem)) in + let off_in_bits = Word.(off * Word.of_int ~width size) in + let full_bytes = Word.(off_in_bits / Word.of_int ~width 8) in + let bit_off = + Word.to_int_exn @@ + Word.(off_in_bits - full_bytes * Word.of_int ~width 8) in + let leftover = Memory.length mem - Word.to_int_exn full_bytes in + let len = min leftover 2 in + let full_bytes = Word.extract_exn ~hi:(addr_size-1) full_bytes in + let from = Addr.(Memory.min_addr mem + full_bytes) in + let mem = ok_exn @@ Memory.view mem ~from ~words:len in + let data = Bigsubstring.to_string (Memory.to_buffer mem) in + let vec = Word.of_binary ~width:(len * 8) BigEndian data in + let hi = len * 8 - bit_off - 1 in + let lo = hi - size + 1 in + Word.extract_exn ~hi ~lo vec + + + (* we can't use Bap.Std.Memory here as we need arbitrary lengths *) + let read_word mem base size = + let start,next = match
Memory.endian mem with + | LittleEndian -> Addr.nsucc base (size/8-1), Addr.pred + | BigEndian -> base, Addr.succ in + let rec read addr left = + let data = ok_exn @@ Memory.get ~addr mem in + if left <= 8 + then Word.extract_exn ~lo:(8-left) data + else Word.concat data @@ read (next addr) (left - 8) in + if size >= 8 then read start size + else read_small mem base size + let remembered {values; layers} addr value = - Machine.Local.put state { + put_curr { layers; values = Map.set values ~key:addr ~data:value; } >>| fun () -> value + let read addr {values;layers} = match find_layer addr layers with | None -> pagefault addr | Some layer -> match Map.find values addr with | Some v -> Machine.return v | None -> - let read_value = match layer.mem with + let read_value = + memory >>= fun {size} -> + match layer.mem with | Dynamic {value} -> Generate.next value >>= Value.of_int ~width:8 - | Static mem -> match Memory.get ~addr mem with - | Ok v -> Value.of_word v - | Error _ -> failwith "Primus.Memory.read" in + | Static mem -> Value.of_word (read_word mem addr size) in read_value >>= remembered {values; layers} addr - let write addr value {values;layers} = match find_layer addr layers with | None -> pagefault addr @@ -136,41 +255,52 @@ module Make(Machine : Machine) = struct let add_layer layer t = {t with layers = layer :: t.layers} let (++) = add_layer + let initialize values base len f = + Machine.Seq.fold (Seq.range 0 len) ~init:values ~f:(fun values i -> + let addr = Addr.(base ++ i) in + f addr >>= fun data -> + Value.of_word data >>| fun data -> + Map.set values ~key:addr ~data) + + + let allocate ?(readonly=false) ?(executable=false) + ?init ?(generator=Generator.Random.Seeded.byte) base len = - update state @@ add_layer { + get_curr >>| add_layer { perms={readonly; executable}; mem = Dynamic {base;len; value=generator} - } + } >>= fun s -> + match init with + | None -> put_curr s + | Some f -> + initialize s.values base len f >>= fun values -> + put_curr {s 
with values} let map ?(readonly=false) ?(executable=false) mem = update state @@ add_layer ({mem=Static mem; perms={readonly; executable}}) let add_text mem = map mem ~readonly:true ~executable:true let add_data mem = map mem ~readonly:false ~executable:false - let get addr = - Machine.Local.get state >>= read addr + let get addr = get_curr >>= read addr let set addr value = - if Value.bitwidth value <> 8 - then invalid_argf "Memory.set %a %a: value is not a byte" - Addr.pps addr Value.pps value (); - Machine.Local.get state >>= + get_curr >>= write addr value >>= - Machine.Local.put state + put_curr let load addr = get addr >>| Value.to_word let store addr value = Value.of_word value >>= set addr let is_mapped addr = - Machine.Local.get state >>| is_mapped addr + get_curr >>| is_mapped addr let is_writable addr = - Machine.Local.get state >>| fun {layers} -> + get_curr >>| fun {layers} -> find_layer addr layers |> function Some {perms={readonly}} -> not readonly | None -> false diff --git a/lib/bap_primus/bap_primus_memory.mli b/lib/bap_primus/bap_primus_memory.mli index a14c89df4..5df7d4234 100644 --- a/lib/bap_primus/bap_primus_memory.mli +++ b/lib/bap_primus/bap_primus_memory.mli @@ -1,12 +1,26 @@ +open Core_kernel open Bap.Std open Bap_primus_types module Generator = Bap_primus_generator type exn += Pagefault of addr +type memory +module Descriptor : sig + type t = memory [@@deriving compare, sexp_of] + + val create : addr_size:int -> data_size:int -> string -> memory + val unknown : addr_size:int -> data_size:int -> memory + val name : memory -> string + + include Comparable.S with type t := memory +end module Make(Machine : Machine) : sig + val switch : memory -> unit Machine.t + val memory : memory Machine.t + val load : addr -> word Machine.t val store : addr -> word -> unit Machine.t @@ -19,14 +33,17 @@ module Make(Machine : Machine) : sig val allocate : ?readonly:bool -> ?executable:bool -> + ?init:(addr -> word Machine.t) -> ?generator:Generator.t -> 
addr -> int -> unit Machine.t + val map : ?readonly:bool -> ?executable:bool -> mem -> unit Machine.t + val is_mapped : addr -> bool Machine.t val is_writable : addr -> bool Machine.t diff --git a/lib/bap_primus_machine/bap_primus_machine.ml b/lib/bap_primus_machine/bap_primus_machine.ml new file mode 100644 index 000000000..5cfc91b90 --- /dev/null +++ b/lib/bap_primus_machine/bap_primus_machine.ml @@ -0,0 +1,465 @@ +open Core_kernel +open Monads.Std +open Bap_knowledge + + +type outcome = + | Continue + | Stop + | Info of Info.t + +module Observer = struct + type ('f,'m) t = { + id : int; + observe: (outcome -> 'm) -> 'f + } +end + +module Inspector = struct + type 'm t = { + id : int; + inspect : Info.t -> 'm + } +end + +module Observation = struct + type info = Info.t + + + type ('f,'m) t = { + name : string; + inspect : ((info -> 'm) -> 'f) option; + observer : ('f,'m) Observer.t Univ_map.Multi.Key.t; + } + + let declare ?inspect ?(package="user") name = + let name = sprintf "%s:%s" package name in + let observer = Univ_map.Multi.Key.create ~name sexp_of_opaque in + {name; inspect; observer} +end + +module State = struct + type 'a t = { + key : 'a Univ_map.Key.t; + init : (unit -> unit) Knowledge.obj -> 'a knowledge; + inspect : ('a -> Info.t); + } + type 'a state = 'a t + + let declare ?(inspect=fun _ -> Info.of_string "") ? + (name="anonymous") init = + let sexp_of_t x = Info.sexp_of_t (inspect x) in + let key = Type_equal.Id.create ~name sexp_of_t in + {key; init; inspect} + + let inspect x = x.inspect + let name x = Type_equal.Id.name x.key +end + +module Exception : sig + type t = .. + val to_string : t -> string + val add_printer : (t -> string option) -> unit +end += struct + type t = exn = .. 
+ let to_string err = Caml.Printexc.to_string err + let add_printer pr = Caml.Printexc.register_printer pr +end + + +module Primus = struct + type id = Monad.State.Multi.id + + type project = unit -> unit + let project = Knowledge.Class.declare ~package:"primus" "project" () + + + type env = unit + type exn = Exception.t = .. + + let package = "primus" + + module PE = struct + type t = (unit, exn) Monad.Result.result + end + + module SM = struct + include Monad.State.Multi.T2(Knowledge) + include Monad.State.Multi.Make2(Knowledge) + end + + type exit_status = + | Normal + | Exn of Exception.t + + type 'a state = 'a State.t + type 'a t = (('a,exn) result,PE.t sm) Monad.Cont.t + and 'f observation = ('f, outcome t) Observation.t + and 'f observer = ('f, outcome t) Observer.t + and inspectors = outcome t Inspector.t list + and 'a sm = ('a,machine_state) SM.t + and machine_state = { + proj : project Knowledge.obj; + curr : unit -> unit t; + local : Univ_map.t; + global : Univ_map.t; + deathrow : id list; + observations : Univ_map.t; + inspectors : inspectors Map.M(String).t; + key : int; + } + + + type 'a c = 'a t + type 'a m = 'a Knowledge.t + type 'a e = project Knowledge.obj -> unit m + + + module C = Monad.Cont.Make(PE)(struct + type 'a t = 'a sm + include Monad.Make(struct + type 'a t = 'a sm + let return = SM.return + let bind m f = SM.bind m ~f + let map = `Custom SM.map + end) + end) + + module CM = Monad.Result.Make(Exception)(struct + type 'a t = ('a, PE.t sm) Monad.Cont.t + include C + end) + + type _ error = exn + open CM + + module Id = Monad.State.Multi.Id + + type 'a machine = 'a t + + + let exn_raised = Observation.declare ~package "machine-exception" + ~inspect:(fun k exn -> + k (Info.of_string (Exception.to_string exn))) + + let forked = Observation.declare ~package "machine-fork" + ~inspect:(fun k pid cid -> + k @@ Info.create "machine-fork" (pid,cid) + [%sexp_of: Id.t * Id.t]) + + let switched = Observation.declare ~package "machine-switch" + 
~inspect:(fun k parent child -> + k @@ Info.create "machine-switch" (parent,child) + [%sexp_of: Id.t * Id.t]) + + + let liftk x = CM.lift (C.lift (SM.lift x)) + (* lifts state monad to the outer monad *) + let lifts x = CM.lift (C.lift x) + let fact = liftk + + let with_global_context (f : (unit -> 'a t)) = + lifts (SM.current ()) >>= fun id -> + lifts (SM.switch SM.global) >>= fun () -> + f () >>= fun r -> + lifts (SM.switch id) >>| fun () -> + r + + let get_local () : _ t = lifts (SM.gets @@ fun s -> s.local) + let get_global () : _ t = with_global_context @@ fun () -> + lifts (SM.gets @@ fun s -> s.global) + + let set_local local = lifts @@ SM.update @@ fun s -> + {s with local} + + let set_global global = with_global_context @@ fun () -> + lifts (SM.update @@ fun s -> {s with global}) + + let project = lifts (SM.gets @@ fun s -> s.proj) + + type observed = outcome + module Observation : sig + type 'f t = 'f observation + type info = Info.t + type ctrl + type observed = outcome + + val declare : + ?inspect:((info -> observed machine) -> 'f) -> + ?package:string -> string -> + 'f observation + val provide : 'f observation -> f:('f -> observed machine) -> unit machine + val monitor : 'f observation -> f:(ctrl -> 'f) -> unit machine + val inspect : 'f observation -> f:(info -> observed machine) -> unit machine + val continue : ctrl -> observed machine + val stop : ctrl -> observed machine + end + = struct + type 'a m = 'a t + type 'f t = 'f observation + type observed = outcome + type info = Info.t + type ctrl = outcome -> outcome machine + + open Observation + + let declare = Observation.declare + let empty = Set.empty (module Int) + + let observations {observer} : _ observer list m = + lifts @@ SM.gets @@ fun s -> + Univ_map.Multi.find s.observations observer + + let set_observations {observer} obs = + lifts @@ SM.update @@ fun s -> { + s with + observations = Univ_map.Multi.set s.observations observer obs + } + + let set_inspectors {Observation.name} data = + 
lifts @@ SM.update @@ fun s -> { + s with + inspectors = Map.set s.inspectors ~key:name ~data + } + + let inspectors {name} = + lifts @@ SM.gets @@ fun {inspectors} -> + Map.find_multi inspectors name + + (* we don't want to use List.iter as it will create + extra allocations. + + Accumulates the kill set - ids of observers that + opted for unsubscription. *) + let rec loop kill k = function + | [] -> return kill + | {Observer.id; observe} :: fs -> + k (observe return) >>= function + | Stop -> loop (Set.add kill id) k fs + | _ -> loop kill k fs + + + let get_info {Observation.name; inspect} k = + let noinfo = Info.of_string name in + match inspect with + | None -> return noinfo + | Some inspect -> + k (inspect (fun data -> Info data)) >>| function + | Info data -> data + | _ -> noinfo + (* in case if the inspector didn't call our continuation *) + + let call_inspectors obs k = + inspectors obs >>= function + | [] -> return () + | fs -> get_info obs k >>= fun info -> + List.fold ~init:[] fs ~f:(fun fs ({inspect} as f) -> + inspect info >>= function + | Stop -> return fs + | _ -> return (f::fs)) >>= + set_inspectors obs + + let kill_observers observers kill = + Base.List.rev_filter observers ~f:(fun {Observer.id} -> + not (Set.mem kill id)) + + let provide observation ~f:k = + observations observation >>= fun obs -> + loop empty k obs >>= fun kill -> + if Set.is_empty kill then return () + else + kill_observers obs kill |> + set_observations observation + + let monitor {observer} ~f = + lifts @@ SM.update @@ fun s -> { + s with + key = s.key + 1; + observations = Univ_map.Multi.add s.observations observer + Observer.{ + id = s.key + 1; + observe = f; + } + } + + let inspect {name} ~f = + lifts @@ SM.update @@ fun s -> { + s with + key = s.key + 1; + inspectors = Map.add_multi s.inspectors ~key:name + ~data:Inspector.{ + id = s.key + 1; + inspect = f; + } + } + + let continue k = k Continue + let stop k = k Stop + end + + module type State = sig + val get : 'a state 
-> 'a machine + val put : 'a state -> 'a -> unit machine + val update : 'a state -> f:('a -> 'a) -> unit machine + end + module Make_state(S : sig + val get : unit -> Univ_map.t t + val set : Univ_map.t -> unit t + val typ : string + end) = struct + type 'a m = 'a t + let get : 'a state -> 'a machine = fun {State.key; init} -> + S.get () >>= fun states -> + match Univ_map.find states key with + | Some s -> return s + | None -> + project >>= fun proj -> + liftk (init proj) + + let put {State.key} x = + S.get () >>= fun states -> + S.set (Univ_map.set states key x) + + let update data ~f = + get data >>= fun s -> put data (f s) + end + + module Local = Make_state(struct + let typ = "local" + let get = get_local + let set = set_local + end) + + module Global = Make_state(struct + let typ = "global" + let get = get_global + let set = set_global + end) + + module State = State + + let get () = CM.return () + let put () : unit machine = CM.return () + let gets f = CM.return (f ()) + let update _ = CM.return () + let modify m _f = m + + + let provide c l x : unit t = liftk (Knowledge.provide c l x) + let collect c l : 'a t = liftk (Knowledge.collect c l) + let conflict c : 'a t = liftk (Knowledge.fail c) + let knowledge : 'a Knowledge.t -> 'a t = liftk + + let fork_state () = lifts (SM.fork ()) + let switch_state id : unit c = lifts (SM.switch id) + let store_curr k = + lifts (SM.update (fun s -> {s with curr = fun () -> k (Ok ())})) + + let lift x = lifts (SM.lift x) + let status x = lifts (SM.status x) + let forks () = lifts (SM.forks ()) + let ancestor x = lifts (SM.ancestor x) + let parent () = lifts (SM.parent ()) + let global = SM.global + let current () = lifts (SM.current ()) + + let notify_fork pid = + current () >>= fun cid -> + Observation.provide forked ~f:(fun observe -> + observe pid cid) + + let sentence_to_death id = + with_global_context (fun () -> + lifts @@ SM.update (fun s -> { + s with deathrow = id :: s.deathrow + })) + + let execute_sentenced = + 
with_global_context (fun () -> + lifts @@ SM.get () >>= fun s -> + lifts @@ SM.List.iter s.deathrow ~f:SM.kill >>= fun () -> + lifts @@ SM.put {s with deathrow = []}) + + let switch id : unit c = + C.call ~f:(fun ~cc:k -> + current () >>= fun pid -> + store_curr k >>= fun () -> + switch_state id >>= fun () -> + lifts (SM.get ()) >>= fun s -> + execute_sentenced >>= fun () -> + Observation.provide switched ~f:(fun observe -> + observe pid id) >>= fun () -> + s.curr ()) + + + let fork () : unit c = + C.call ~f:(fun ~cc:k -> + current () >>= fun pid -> + store_curr k >>= + fork_state >>= fun () -> + execute_sentenced >>= fun () -> + notify_fork pid) + + + let kill id = + if id = global then return () + else + current () >>= fun cid -> + if id = cid then sentence_to_death id + else lifts @@ SM.kill id + + let die next = + current () >>= fun pid -> + switch_state next >>= fun () -> + lifts (SM.get ()) >>= fun s -> + lifts (SM.kill pid) >>= fun () -> + s.curr () + + + let raise exn = + Observation.provide exn_raised ~f:(fun observe -> + observe exn) >>= fun () -> + fail exn + let catch = catch + + let empty proj = { + proj; + curr = return; + global = Univ_map.empty; + local = Univ_map.empty; + observations = Univ_map.empty; + inspectors = Map.empty (module String); + key = 0; + deathrow = []; + } + + + let finished = + Observation.declare ~package "fini" + ~inspect:(fun k () -> k @@ Info.of_string "fini") + + let inited = Observation.declare ~package "init" + ~inspect:(fun k () -> k @@ Info.of_string "init") + + let notify obs = Observation.provide obs ~f:(fun go -> go ()) + + + let run : type a. 
a t -> a e = fun comp proj -> + let finish = function + | Ok _ -> SM.return (Ok ()) + | Error err -> SM.return (Error err) in + let state = empty proj in + Knowledge.ignore_m @@ + SM.run (C.run comp finish) state + + module Syntax = struct + include CM.Syntax + let (-->) x p = collect p x + let (//) c s = liftk @@ Knowledge.Object.read c s + let (>>>) x f = Observation.monitor x ~f + end + + include (CM : Monad.S with type 'a t := 'a t + and module Syntax := Syntax) +end diff --git a/lib/bap_primus_machine/bap_primus_machine.mli b/lib/bap_primus_machine/bap_primus_machine.mli new file mode 100644 index 000000000..364c3064e --- /dev/null +++ b/lib/bap_primus_machine/bap_primus_machine.mli @@ -0,0 +1,230 @@ +open Core_kernel +open Monads.Std +open Bap_knowledge + +(** Primus - A non-deterministic interpreter. + + +*) + +module Primus : sig + open Knowledge + type 'a machine + + (** The Machine Exception. + + The exn type is an extensible variant, and components + usually register their own error constructors. *) + type exn = .. + + (** [an observation] of a value of type [a].*) + type 'a observation + + + type observed + + (** Machine exit status. + A machine may terminate normally, or abnormally with the + specified exception. *) + type exit_status = + | Normal + | Exn of exn + + + + (** the machine computation *) + type 'a t = 'a machine + + type 'a state + + type project + + + (** Machine identifier type. *) + type id = Monad.State.Multi.id + + (** [raise exn] raises the machine exception [exn], initiating + an abnormal control flow *) + val raise : exn -> 'a t + + + (** [catch x f] creates a computation that is equal to [x] if + it terminates normally, and to [f e] if [x] terminates + abnormally with the exception [e].
*) + val catch : 'a t -> (exn -> 'a t) -> 'a t + + val collect : ('a,'p) slot -> 'a obj -> 'p t + val provide : ('a,'p) slot -> 'a obj -> 'p -> unit t + val project : project obj t + + val die : id -> unit t + + val conflict : conflict -> 'a t + + + (** [fact x] makes the fact [x] determined in the current machine. + + This is the [pure] function w.r.t. the non-determinism, also + known as lift, since it lifts the inner knowledge monad into the + outer machine monad. + *) + val fact : 'a knowledge -> 'a t + + + (** [run comp project] runs the Primus system. *) + val run : unit t -> project obj -> unit knowledge + + + (** Computation State *) + module State : sig + (** ['a t] is a type of state that holds a value of type + ['a], and is constructed from the project object, see + [declare]. *) + type 'a t = 'a state + type 'a state = 'a t + + + + (** [declare ~inspect ~name make] declares a state with + the given [name]. The name is not required to be unique. + + [make] takes the project object and computes the value of + the state. The optional [inspect] function renders a value + of the state as [Info.t]. + *) + val declare : + ?inspect:('a -> Info.t) -> + ?name:string -> + (project obj -> 'a Knowledge.t) -> 'a t + + (** [inspect state value] introspects given [value] of the state. *) + val inspect : 'a t -> 'a -> Info.t + + (** [name state] a state name that was given during the construction. *) + val name : 'a t -> string + end + + + + (** An interface to the state. + + An interface gives an access to operations that query and + modify machine state. *) + module type State = sig + (** [get state] extracts the state. *) + val get : 'a state -> 'a machine + + (** [put state x] saves a machine state *) + val put : 'a state -> 'a -> unit machine + + (** [update state ~f] updates a state using function [f]. *) + val update : 'a state -> f:('a -> 'a) -> unit machine + end + + (** Observations interface.
+ + An observation is a named event, that can occur during the + program execution. Observations could be provided (usually + by components that are implementing a particular primitive), + and observed (i.e., a component could be notified every time + an observation is made). In other words, the Observation module + provides a publish/subscribe service. + + The observation system uses the continuation passing style to + enable a polymorphic event system which doesn't rely on boxed + types, such as tuples and records, to deliver observation + (events) to subscribers. + + Each observation is parametrized by a type of a function which + is used to provide the observation. For a concrete example, + let's take the [sum] observation, which occurs every time + a sum of two values is computed. This observation will have + three arguments (for simplicity let's assume that they have type + [int]) and (this is true for all observation types) will have + the return type [observed machine], so the final type of the [sum] is + [int -> int -> int -> observed machine]. + + {3 Providing Observations} + + Observations are provided using the [Observation.provide] + function, which takes a function, that will be called with one + parameter, which is a function itself and has type ['f]. We + call this function [observe], since it is actually the observer + which is being notified. Here is an example, using our [sum] + observation: + + {[ + Observation.provide sum ~f:(fun observe -> + observe 1 2 3) + ]} + + + {3 Monitoring Observations} + + It is possible to register a function, which will be called + every time an observation is made via the [provide] function. + The monitor has a little bit more complicated type, as beyond + the actual payload (arguments of the observation), it takes a + [ctrl] instance, which should be used to return from the + observation, via [Observation.continue] or [Observation.stop] + functions.
+ + *) + module Observation : sig + type 'f t = 'f observation + type info = Info.t + type ctrl + + val declare : + ?inspect:((info -> observed machine) -> 'f) -> + ?package:string -> string -> + 'f observation + + (** [provide obs f] provides the observation of [obs]. + + The function [f] takes one argument a function, + which accepts + + + *) + val provide : 'f observation -> f:('f -> observed machine) -> unit machine + val monitor : 'f observation -> f:(ctrl -> 'f) -> unit machine + val inspect : 'f observation -> f:(info -> observed machine) -> unit machine + + val continue : ctrl -> observed machine + val stop : ctrl -> observed machine + end + + (** [exn_raised exn] occurs every time an abnormal control flow + is initiated *) + val exn_raised : (exn -> observed machine) observation + + + (** Computation Syntax.*) + module Syntax : sig + include Monad.Syntax.S with type 'a t := 'a t + + (** [x-->p] is [collect p x] *) + val (-->) : 'a obj -> ('a,'p) slot -> 'p t + + (** [c // s] is [Object.read c s] *) + val (//) : ('a,_) cls -> string -> 'a obj t + + (** [event >>> action] is the same as + [Observation.monitor event action] *) + val (>>>) : 'f observation -> (Observation.ctrl -> 'f) -> unit t + end + + + + include Monad.State.Multi.S with type 'a t := 'a t + and type id := id + and module Syntax := Syntax + + (** Local state of the machine. *) + module Local : State + + + (** Global state shared across all machine clones. *) + module Global : State +end diff --git a/lib/bap_sema/bap_sema_lift.ml b/lib/bap_sema/bap_sema_lift.ml index 39abfdf04..1a9db7784 100644 --- a/lib/bap_sema/bap_sema_lift.ml +++ b/lib/bap_sema/bap_sema_lift.ml @@ -1,361 +1,281 @@ +open Bap_core_theory + open Core_kernel open Bap_types.Std open Graphlib.Std open Bap_image_std open Bap_disasm_std open Bap_ir -open Format - - -(* A note about lifting call instructions. - - We're labeling calls with an expected continuation, that should be - derived from the BIL. 
But instead we lift calls in a rather - speculative way, thus breaking the abstraction of the BIL, that - desugars calls into two pseudo instructions: - - - - - A correct way of doing things would be to find a live write to the - place that is used to store return address (ABI specific), and put - this expression as an expected return address (aka continuation). - - But a short survey into existing instruction sets shows, that call - instructions doesn't allow to store something other then next - instruction, e.g., `call` in x86, `bl` in ARM, `jal` in MIPS, - `call` and `jumpl` in SPARC (although the latter allows to choose - arbitrary register to store return address). That's all is not to - say, that it is impossible to encode a call with return address - different from a next instruction, that's why it is called a - speculation. -*) - -type linear = - | Label of tid - | Instr of Ir_blk.elt +let update_jmp jmp ~f = + f (Ir_jmp.dst jmp) (Ir_jmp.alt jmp) @@ fun ~dst ~alt -> + Ir_jmp.reify + ~tid:(Term.tid jmp) + ?cnd:(Ir_jmp.guard jmp) + ?dst ?alt () -let fall_of_block cfg block = +let intra_fall cfg block = Seq.find_map (Cfg.Node.outputs block cfg) ~f:(fun e -> match Cfg.Edge.label e with - | `Fall -> Some (Cfg.Edge.dst e) + | `Fall -> Option.some @@ + Ir_jmp.resolved (Tid.for_addr (Block.addr (Cfg.Edge.dst e))) | _ -> None) -let label_of_fall cfg block = - Option.map (fall_of_block cfg block) ~f:(fun blk -> - Label.indirect Bil.(int (Block.addr blk))) - -let annotate_insn term insn = Term.set_attr term Disasm.insn insn -let annotate_addr term addr = Term.set_attr term address addr - -let linear_of_stmt ?addr return insn stmt : linear list = - let (~@) t = match addr with - | None -> t - | Some addr -> annotate_addr (annotate_insn t insn) addr in - let goto ?cond id = - `Jmp ~@(Ir_jmp.create_goto ?cond (Label.direct id)) in - let jump ?cond exp = - let target = Label.indirect exp in - if Insn.(is return) insn - then Ir_jmp.create_ret ?cond target - else if 
Insn.(is call) insn - then - Ir_jmp.create_call ?cond (Call.create ?return ~target ()) - else Ir_jmp.create_goto ?cond target in - let jump ?cond exp = Instr (`Jmp ~@(jump ?cond exp)) in - let cpuexn ?cond n = - let landing = Tid.create () in - let takeoff = Tid.create () in - let exn = `Jmp ~@(Ir_jmp.create_int ?cond n landing) in - match return with - | None -> [ - Instr exn; - Label landing; - (* No code was found that follows the interrupt, - so this is a no-return interrupt *) - ] - | Some lab -> [ - Instr (goto takeoff); - Label landing; - Instr (`Jmp ~@(Ir_jmp.create_goto lab)); - Label takeoff; - Instr exn; - ] in - - let rec linearize = function - | Bil.Move (lhs,rhs) -> - [Instr (`Def ~@(Ir_def.create lhs rhs))] - | Bil.If (_, [],[]) -> [] - | Bil.If (cond,[],no) -> linearize Bil.(If (lnot cond, no,[])) - | Bil.If (cond,yes,[]) -> - let yes_label = Tid.create () in - let tail = Tid.create () in - Instr (goto ~cond yes_label) :: - Instr (goto tail) :: - Label yes_label :: - List.concat_map yes ~f:linearize @ - Instr (goto tail) :: - Label tail :: [] - | Bil.If (cond,yes,no) -> - let yes_label = Tid.create () in - let no_label = Tid.create () in - let tail = Tid.create () in - Instr (goto ~cond yes_label) :: - Instr (goto no_label) :: - Label yes_label :: - List.concat_map yes ~f:linearize @ - Instr (goto tail) :: - Label no_label :: - List.concat_map no ~f:linearize @ - Instr (goto tail) :: - Label tail :: [] - | Bil.Jmp exp -> [jump exp] - | Bil.CpuExn n -> cpuexn n - | Bil.Special _ -> [] - | Bil.While (cond,body) -> - let header = Tid.create () in - let tail = Tid.create () in - let finish = Tid.create () in - Instr (goto tail) :: - Label header :: - List.concat_map body ~f:linearize @ - Instr (goto tail) :: - Label tail :: - Instr (goto ~cond header) :: - Instr (goto finish) :: - Label finish :: [] in - linearize stmt - - -let lift_insn ?addr fall init insn = - List.fold (Insn.bil insn) ~init ~f:(fun init stmt -> - List.fold (linear_of_stmt ?addr fall 
insn stmt) ~init - ~f:(fun (bs,b) -> function - | Label lab -> - Ir_blk.Builder.result b :: bs, - Ir_blk.Builder.create ~tid:lab () - | Instr elt -> - Ir_blk.Builder.add_elt b elt; bs,b)) - -let has_jump_under_condition bil = - with_return (fun {return} -> - let enter_control ifs = if ifs = 0 then ifs else return true in - Bil.fold (object - inherit [int] Stmt.visitor - method! enter_if ~cond:_ ~yes:_ ~no:_ x = x + 1 - method! leave_if ~cond:_ ~yes:_ ~no:_ x = x - 1 - method! enter_jmp _ ifs = enter_control ifs - method! enter_cpuexn _ ifs = enter_control ifs - end) ~init:0 bil |> fun (_ : int) -> false) - -let is_conditional_jump jmp = - Insn.(may affect_control_flow) jmp && - has_jump_under_condition (Insn.bil jmp) - -let has_called block addr = - let finder = - object inherit [unit] Stmt.finder - method! enter_jmp e r = - match e with - | Bil.Int a when Addr.(a = addr) -> r.return (Some ()) - | _ -> r - end in - Bil.exists finder (Insn.bil (Block.terminator block)) - -let fall_of_symtab symtab block = - Option.( - symtab >>= fun symtab -> - match Symtab.enum_calls symtab (Block.addr block) with - | [] -> None - | calls -> - List.find_map calls - ~f:(fun (n,e) -> Option.some_if (e = `Fall) n) >>= fun name -> - Symtab.find_by_name symtab name >>= fun (_,entry,_) -> - Option.some_if Block.(block <> entry) entry >>= fun callee -> - let addr = Block.addr callee in - Option.some_if (not (has_called block addr)) () >>= fun () -> - let bldr = Ir_blk.Builder.create () in - let call = Call.create ~target:(Label.indirect Bil.(int addr)) () in - let () = Ir_blk.Builder.add_jmp bldr (Ir_jmp.create_call call) in - Some (Ir_blk.Builder.result bldr)) - -let blk ?symtab cfg block : blk term list = - let fall_to_fn = fall_of_symtab symtab block in - let fall_label = - match label_of_fall cfg block, fall_to_fn with - | None, Some b -> Some (Label.direct (Term.tid b)) - | fall_label,_ -> fall_label in - List.fold (Block.insns block) ~init:([],Ir_blk.Builder.create ()) - ~f:(fun 
init (mem,insn) -> - let addr = Memory.min_addr mem in - lift_insn ~addr fall_label init insn) |> - fun (bs,b) -> - let fall = - let jmp = Block.terminator block in - if Insn.(is call) jmp && not (is_conditional_jump jmp) - then None else match fall_label with - | None -> None - | Some dst -> Some (`Jmp (Ir_jmp.create_goto dst)) in - Option.iter fall ~f:(Ir_blk.Builder.add_elt b); - let b = Ir_blk.Builder.result b in - let blocks = match fall_to_fn with - | None -> b :: bs - | Some b' -> b' :: b :: bs in - List.rev blocks |> function - | [] -> assert false - | b::bs -> Term.set_attr b address (Block.addr block) :: bs - -let resolve_jmp ~local addrs jmp = - let update_kind jmp addr make_kind = - Option.value_map ~default:jmp - (Hashtbl.find addrs addr) - ~f:(fun id -> Ir_jmp.with_kind jmp (make_kind id)) in - match Ir_jmp.kind jmp with - | Ret _ | Int _ -> jmp - | Goto (Indirect (Bil.Int addr)) -> - update_kind jmp addr (fun id -> - if local then Goto (Direct id) - else - Call (Call.create ~target:(Direct id) ())) - | Goto _ -> jmp - | Call call -> - let jmp,call = match Call.target call with - | _ when local -> jmp, call - | Indirect (Bil.Int addr) -> - let new_call = ref call in - let jmp = update_kind jmp addr - (fun id -> - new_call := Call.with_target call (Direct id); - Call !new_call) in - jmp, !new_call - | _ -> jmp,call in - match Call.return call with - | Some (Indirect (Bil.Int addr)) when Hashtbl.mem addrs addr -> - update_kind jmp addr - (fun id -> Call (Call.with_return call (Direct id))) - | Some (Indirect (Bil.Int _)) -> - Ir_jmp.with_kind jmp @@ Call (Call.with_noreturn call) - | _ -> jmp - -(* remove all jumps that are after unconditional jump *) -let remove_false_jmps blk = - Term.enum jmp_t blk |> Seq.find ~f:(fun jmp -> - Exp.(Ir_jmp.cond jmp = (Bil.Int Word.b1))) |> function - | None -> blk - | Some last -> - Term.after jmp_t blk (Term.tid last) |> Seq.map ~f:Term.tid |> - Seq.fold ~init:blk ~f:(Term.remove jmp_t) - -let unbound _ = true - -let 
lift_sub ?symtab entry cfg = - let addrs = Addr.Table.create () in - let recons acc b = - let addr = Block.addr b in - let blks = blk ?symtab cfg b in - Option.iter (List.hd blks) ~f:(fun blk -> - Hashtbl.add_exn addrs ~key:addr ~data:(Term.tid blk)); - acc @ blks in - let blocks = Graphlib.reverse_postorder_traverse - (module Cfg) ~start:entry cfg in - let blks = Seq.fold blocks ~init:[] ~f:recons in - let n = let n = List.length blks in Option.some_if (n > 0) n in - let sub = Ir_sub.Builder.create ?blks:n () in - List.iter blks ~f:(fun blk -> - Ir_sub.Builder.add_blk sub - (Term.map jmp_t blk ~f:(resolve_jmp ~local:true addrs))); +(* a subroutine could be called implicitly via a fallthrough, + therefore it will not be reified into jmp terms automatically, + so we need to do some work here. + We look into the symtable for an implicit call, and if + such exists and it is physically present in the symtab, then we + return a tid for the address, otherwise we return a tid for the + subroutine name. It is important to return the tid for address, + so that we can compare tids in the [insert_call] function. 
*) +let inter_fall symtab block = + let open Option.Monad_infix in + symtab >>= fun symtab -> + Symtab.implicit_callee symtab (Block.addr block) >>| fun name -> + Ir_jmp.resolved @@ + match Symtab.find_by_name symtab name with + | None -> Tid.for_name name + | Some (_,entry,_) -> Tid.for_addr (Block.addr entry) + +module IrBuilder = struct + + let def_only blk = Term.length jmp_t blk = 0 + + (* concat two def-only blocks *) + let append_def_only b1 b2 = + let b = Ir_blk.Builder.init ~same_tid:true ~copy_defs:true b1 in + Term.enum def_t b2 |> Seq.iter ~f:(Ir_blk.Builder.add_def b); + Term.enum jmp_t b2 |> Seq.iter ~f:(Ir_blk.Builder.add_jmp b); + Ir_blk.Builder.result b + + let append xs ys = match xs, ys with + | [],xs | xs,[] -> xs + | x :: xs, y :: ys when def_only x -> + List.rev_append ys (append_def_only x y :: xs) + | xs, ys -> List.rev_append ys xs + + let ir_of_insn insn = KB.Value.get Term.slot insn + + let set_attributes ?mem insn blks = + let addr = Option.map ~f:Memory.min_addr mem in + let set_attributes k b = + Term.map k b ~f:(fun t -> + let t = Term.set_attr t Disasm.insn insn in + Option.value_map addr ~f:(Term.set_attr t address) + ~default:t) in + List.map blks ~f:(fun blk -> + set_attributes jmp_t blk |> + set_attributes def_t) + + let lift_insn ?mem insn blks = + append blks @@ + set_attributes ?mem insn (ir_of_insn insn) + + let with_first_blk_addressed addr = function + | [] -> [] + | b :: bs -> Term.set_attr b address addr :: bs + + let turn_into_call ret blk = + Term.map jmp_t blk ~f:(update_jmp ~f:(fun dst _ jmp -> + jmp ~dst:ret ~alt:dst)) + + let landing_pad return jmp = + match Ir_jmp.kind jmp with + | Int (_,pad) -> + let pad = Ir_blk.create ~tid:pad () in + let pad = match return with + | None -> pad + | Some dst -> Term.append jmp_t pad (Ir_jmp.reify ~dst ()) in + Some pad + | _ -> None + + let with_landing_pads return bs = match bs with + | [] -> [] + | b :: bs as blks -> + let pads = List.fold ~init:[] blks ~f:(fun pads b -> + 
Term.enum jmp_t b |> + Seq.fold ~init:pads ~f:(fun pads jmp -> + match landing_pad return jmp with + | Some pad -> pad :: pads + | None -> pads)) in + b :: List.rev_append pads bs + + let resolves_equal x y = + match Ir_jmp.resolve x, Ir_jmp.resolve y with + | First x, First y -> Tid.equal x y + | _ -> false + + let insert_inter_fall alt blk = + [Term.append jmp_t blk @@ Ir_jmp.reify ~alt ()] + + let is_last_jump_nonconditional blk = + match Term.last jmp_t blk with + | None -> false + | Some jmp -> match Ir_jmp.cond jmp with + | Bil.Int x -> Word.equal Word.b1 x + | _ -> false + + let fall_if_possible blk fall = + if is_last_jump_nonconditional blk + then blk + else Term.append jmp_t blk fall + + let blk ?symtab cfg block : blk term list = + let tid = Tid.for_addr (Block.addr block) in + let blks = + Block.insns block |> + List.fold ~init:[Ir_blk.create ~tid ()] ~f:(fun blks (mem,insn) -> + lift_insn ~mem insn blks) in + let fall = intra_fall cfg block in + let blks = with_landing_pads fall blks in + let x = Block.terminator block in + let is_call = Insn.(is call x) + and is_barrier = Insn.(is barrier x) in + with_first_blk_addressed (Block.addr block) @@ + List.rev @@ match blks,fall with + | [],_ -> [] + | blks,_ when is_barrier -> blks + | x::xs, Some dst -> + if is_call + then turn_into_call fall x :: xs + else fall_if_possible x (Ir_jmp.reify ~dst ()) :: xs + | x::xs, None -> + let x = if is_call then turn_into_call fall x else x in + match inter_fall symtab block with + | Some dst when Term.length jmp_t x = 0 -> + insert_inter_fall dst x @ xs + | _ -> x::xs +end + +let blk cfg block = IrBuilder.blk cfg block + +let lift_sub ?symtab ?tid entry cfg = + let sub = Ir_sub.Builder.create ?tid ~blks:32 () in + Graphlib.reverse_postorder_traverse (module Cfg) ~start:entry cfg |> + Seq.iter ~f:(fun block -> + let blks = IrBuilder.blk ?symtab cfg block in + List.iter blks ~f:(Ir_sub.Builder.add_blk sub)); let sub = Ir_sub.Builder.result sub in Term.set_attr sub 
address (Block.addr entry) -let create_synthetic name = - let sub = Ir_sub.create ~name () in - Tid.set_name (Term.tid sub) name; - Term.(set_attr sub synthetic ()) +(* Rewires some intraprocedural jmps into interprocedural. + + - If a jmp is unresolved, we can't really prove that the control + flow won't leave the boundaries of its subroutine, therefore we + conservatively turn it into a call. + + - If a jmp is resolved, but its destination is not within + the boundaries of its function, then we reclassify it as + interprocedural. + + - If a jmp already has an alternate route, to keep things simple, + we do not touch it. *) +let alternate_nonlocal sub jmp = + let needs_alternation = + Option.is_none (Ir_jmp.alt jmp) && match Ir_jmp.dst jmp with + | None -> false (* already nonlocal *) + | Some dst -> match Ir_jmp.resolve dst with + | Second _ -> true (* all unresolved are potentially calls *) + | First dst -> Option.is_none (Term.find blk_t sub dst) in + if needs_alternation + then update_jmp jmp ~f:(fun dst _ jmp -> jmp ~dst:None ~alt:dst) + else jmp + +(* On the local level of lifting all resolved jmps are pointing to + basic blocks. When we lift on the program (global) level, we + need to link all resolved calls to subroutines. + + The [sub_of_blk] mapping maps tids of entries (basic blocks) to + tids of their corresponding subroutines. If an intraprocedural + edge points to an entry of a subroutine we relink it to that + subroutine. + + If a jump is an interprocedural edge, that points to an entry of + a subroutine we turn it into a tail call. This could be a real tail + call, a self recursive call, or an error from the previous steps of + lifting. In any case, it will guarantee that the entry block has + the local indegree zero. 
+ + Finally, if both jump edges are not pointing to the entries and + the interprocedural edge is resolved, then we have a jump to an + external function, so we relink it to a label that corresponds + to the name of the external function. Note, this step relies on + the [alternate_nonlocal] pass, described above. *) +let link_call symtab addr sub_of_blk jmp = + let open Option.Monad_infix in + let resolve dst = dst jmp >>| Ir_jmp.resolve >>= function + | Second _ -> None + | First tid -> Some tid in + let sub_of_dst dst = + resolve dst >>= Hashtbl.find sub_of_blk >>| Ir_jmp.resolved in + let external_callee () = + addr >>= Symtab.explicit_callee symtab >>| Tid.for_name >>| + Ir_jmp.resolved in + match sub_of_dst Ir_jmp.dst, sub_of_dst Ir_jmp.alt with + | _, (Some _ as alt) -> + update_jmp jmp ~f:(fun dst _ jmp -> jmp ~dst ~alt) + | Some _ as alt, None -> + update_jmp jmp ~f:(fun _ _ jmp -> jmp ~dst:None ~alt) + | None,None -> match resolve Ir_jmp.alt with + | None -> jmp + | Some (_:tid) -> match external_callee () with + | Some alt -> update_jmp jmp ~f:(fun dst _ jmp -> + jmp ~dst ~alt:(Some alt)) + | None -> jmp + + +let insert_synthetic prog = + Term.enum sub_t prog |> + Seq.fold ~init:prog ~f:(fun prog sub -> + Term.enum blk_t sub |> + Seq.fold ~init:prog ~f:(fun prog blk -> + Term.enum jmp_t blk |> + Seq.fold ~init:prog ~f:(fun prog jmp -> + match Ir_jmp.alt jmp with + | None -> prog + | Some dst -> match Ir_jmp.resolve dst with + | Second _ -> prog + | First dst -> + if Option.is_some (Term.find sub_t prog dst) + then prog + else + Term.append sub_t prog @@ + Ir_sub.create ~tid:dst ()))) -let indirect_target jmp = - match Ir_jmp.kind jmp with - | Ret _ | Int _ | Goto _ -> None - | Call call -> match Call.target call with - | Indirect (Bil.Int a) -> Some a - | _ -> None - -let is_indirect_call jmp = Option.is_some (indirect_target jmp) - -let with_address t ~f ~default = - Option.value_map ~default ~f (Term.get_attr t address) - -let with_address_opt t ~f 
~default = - let g a = Option.value (f a) ~default in - with_address t ~f:g ~default - -let update_unresolved symtab unresolved exts sub = - let iter cls t ~f = Term.to_sequence cls t |> Seq.iter ~f in - let symbol_exists name = - Option.is_some (Symtab.find_by_name symtab name) in - let is_known a = Option.is_some (Symtab.find_by_start symtab a) in - let is_unknown name = not (symbol_exists name) in - let add_external (name,_) = - if is_unknown name then - Hashtbl.update exts name ~f:(function - | None -> create_synthetic name - | Some x -> x) in - iter blk_t sub ~f:(fun blk -> - iter jmp_t blk ~f:(fun jmp -> - match indirect_target jmp with - | None -> () - | Some a when is_known a -> () - | _ -> - with_address blk ~default:() ~f:(fun addr -> - Hash_set.add unresolved addr; - Symtab.enum_calls symtab addr |> - List.iter ~f:add_external))) - -let resolve_indirect symtab exts blk jmp = - let update_target tar = - Option.some @@ - match Ir_jmp.kind jmp with - | Call c -> Ir_jmp.with_kind jmp (Call (Call.with_target c tar)) - | _ -> jmp in - let resolve_name (name,_) = - match Symtab.find_by_name symtab name with - | Some (_,b,_) -> update_target (Indirect (Int (Block.addr b))) - | _ -> match Hashtbl.find exts name with - | Some s -> update_target (Direct (Term.tid s)) - | None -> None in - with_address_opt blk ~default:jmp ~f:(fun addr -> - Symtab.enum_calls symtab addr |> - List.find_map ~f:resolve_name) let program symtab = let b = Ir_program.Builder.create () in - let addrs = Addr.Table.create () in - let externals = String.Table.create () in - let unresolved = Addr.Hash_set.create () in + let sub_of_blk = Hashtbl.create (module Tid) in + let tid_for_sub = + let tids = Hash_set.create (module Tid) () in + fun name -> + let tid = Tid.for_name name in + match Hash_set.strict_add tids tid with + | Ok () -> tid + | Error _ -> + let tid = Tid.create () in + Tid.set_name tid name; + Hash_set.strict_add_exn tids tid; + tid in Seq.iter (Symtab.to_sequence symtab) ~f:(fun 
(name,entry,cfg) -> let addr = Block.addr entry in - let sub = lift_sub ~symtab entry cfg in + let blk_tid = Tid.for_addr addr in + let sub_tid = tid_for_sub name in + let sub = lift_sub ~symtab ~tid:sub_tid entry cfg in Ir_program.Builder.add_sub b (Ir_sub.with_name sub name); - Tid.set_name (Term.tid sub) name; - Hashtbl.add_exn addrs ~key:addr ~data:(Term.tid sub); - update_unresolved symtab unresolved externals sub); - Hashtbl.iter externals ~f:(Ir_program.Builder.add_sub b); + Hashtbl.add_exn sub_of_blk ~key:blk_tid ~data:sub_tid;); let program = Ir_program.Builder.result b in - let has_unresolved blk = - with_address blk ~default:false ~f:(Hash_set.mem unresolved) in - Term.map sub_t program - ~f:(fun sub -> Term.map blk_t sub ~f:(fun blk -> - Term.map jmp_t (remove_false_jmps blk) - ~f:(fun j -> - let j = - if is_indirect_call j && has_unresolved blk then - resolve_indirect symtab externals blk j - else j in - resolve_jmp ~local:false addrs j))) - -let sub = lift_sub ?symtab:None -let blk = blk ?symtab:None + Term.map sub_t program ~f:(fun sub -> + Term.map blk_t sub ~f:(fun blk -> + let addr = Term.get_attr blk address in + Term.map jmp_t blk ~f:(fun jmp -> + jmp |> + alternate_nonlocal sub |> + link_call symtab addr sub_of_blk))) |> + insert_synthetic + +let sub blk cfg = lift_sub blk cfg let insn insn = - lift_insn None ([], Ir_blk.Builder.create ()) insn |> - function (bs,b) -> List.rev (Ir_blk.Builder.result b :: bs) + List.rev @@ IrBuilder.lift_insn insn [Ir_blk.create ()] diff --git a/lib/bap_sema/bap_sema_ssa.ml b/lib/bap_sema/bap_sema_ssa.ml index 7782a08a9..7baa9d626 100644 --- a/lib/bap_sema/bap_sema_ssa.ml +++ b/lib/bap_sema/bap_sema_ssa.ml @@ -54,7 +54,15 @@ let iterated_frontier f blks = if Set.equal idf idf' then idf' else fixpoint idf' in fixpoint Tid.Set.empty -let blk_of_tid = Term.find_exn blk_t +let blk_of_tid sub tid = match Term.find blk_t sub tid with + | Some blk -> blk + | None -> + failwithf + "Internal error. 
Broken invariant in subroutine %s: \ + A term %a is missing" (Ir_sub.name sub) Tid.pps tid + () + + let succs cfg sub tid = Cfg.Node.succs tid cfg |> Seq.map ~f:(blk_of_tid sub) @@ -132,7 +140,7 @@ let rename t = | _ -> phi)) in let pop_defs blk = let pop v = Hashtbl.change vars (Var.base v) (function - | Some (x::xs) -> Some xs + | Some (_::xs) -> Some xs | xs -> xs) in Term.enum phi_t blk |> Seq.iter ~f:(fun phi -> pop (Ir_phi.lhs phi)); diff --git a/lib/bap_types/.merlin b/lib/bap_types/.merlin index 86d75ef62..343687aeb 100644 --- a/lib/bap_types/.merlin +++ b/lib/bap_types/.merlin @@ -3,3 +3,4 @@ PKG uuidm REC B ../../_build/lib/bap_types +B ../../_build/lib/knowledge diff --git a/lib/bap_types/bap_arch.ml b/lib/bap_types/bap_arch.ml index bb0383659..b12c02148 100644 --- a/lib/bap_types/bap_arch.ml +++ b/lib/bap_types/bap_arch.ml @@ -1,4 +1,5 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_common @@ -60,6 +61,18 @@ module T = struct | #arm | #thumb | #x86 | #aarch64 | #r600 | #hexagon | #nvptx | #xcore -> LittleEndian | #ppc | #mips | #sparc | #systemz -> BigEndian + + let equal x y = compare_arch x y = 0 + + let domain = + KB.Domain.optional ~equal ~inspect:sexp_of_t "arch" + + + let slot = KB.Class.property ~package:"bap.std" + Theory.Program.cls "arch" domain + ~persistent:(KB.Persistent.of_binable (module struct + type t = arch option [@@deriving bin_io] + end)) end include T diff --git a/lib/bap_types/bap_arch.mli b/lib/bap_types/bap_arch.mli index 307d051e2..54d1f319d 100644 --- a/lib/bap_types/bap_arch.mli +++ b/lib/bap_types/bap_arch.mli @@ -1,4 +1,5 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_common @@ -8,4 +9,6 @@ val addr_size : arch -> addr_size val endian : arch -> endian +val slot : (Theory.program, arch option) KB.slot + include Regular.S with type t := arch diff --git a/lib/bap_types/bap_attributes.ml b/lib/bap_types/bap_attributes.ml index 8425dc1e5..0b2d412bd 100644 --- 
a/lib/bap_types/bap_attributes.ml +++ b/lib/bap_types/bap_attributes.ml @@ -51,42 +51,13 @@ end type color = Color.t [@@deriving bin_io, compare, sexp_poly] -let comment = register (module String) - ~name:"comment" - ~uuid:"4b974ab3-bf3b-4a83-8c62-299bca70f02a" - -let python = register (module String) - ~name:"python" - ~uuid:"831a6268-0ca8-4c1b-b4d6-076995f49d84" - -let shell = register (module String) - ~name:"shell" - ~uuid:"8c2c459d-6f3e-42f7-bce7-bf5cfa280f24" - -let mark = register (module Unit) - ~name:"mark" - ~uuid:"8e9801dc-0c64-4943-acf4-bfd02347af91" - -let color = register (module Color) - ~name:"color" - ~uuid:"1938c44a-149d-4c71-832a-7f484be800cc" - -let weight = register (module Float) - ~name:"weight" - ~uuid:"657366ea-9a28-4e5e-8341-c545d861732b" - -let address = register (module Bap_bitvector) - ~name:"address" - ~uuid:"7bcef7c0-0b37-4167-887a-eba0d68891fe" - -let filename = register (module String) - ~name:"filename" - ~uuid:"9701d189-24e3-4348-8610-0dedf780d06b" - -let foreground = register (module Foreground) - ~name:"foreground" - ~uuid:"56b29739-2df4-4e6c-9f63-15e20edf1857" - -let background = register (module Background) - ~name:"background" - ~uuid:"9a80a9cc-4106-48fc-abf3-55d7b333e734" +let comment = register (module String) ~name:"comment" ~uuid:"bap.std.attrs" +let python = register (module String) ~name:"python" ~uuid:"bap.std.attrs" +let shell = register (module String) ~name:"shell" ~uuid:"bap.std.attrs" +let mark = register (module Unit) ~name:"mark" ~uuid:"bap.std.attrs" +let color = register (module Color) ~name:"color" ~uuid:"bap.std.attrs" +let weight = register (module Float) ~name:"weight" ~uuid:"bap.std.attrs" +let address = register (module Bap_bitvector) ~name:"address" ~uuid:"bap.std.attrs" +let filename = register (module String) ~name:"filename" ~uuid:"bap.std.attrs" +let foreground = register (module Foreground) ~name:"foreground" ~uuid:"bap.std.attrs" +let background = register (module Background) ~name:"background" 
~uuid:"bap.std.attrs" diff --git a/lib/bap_types/bap_bitvector.ml b/lib/bap_types/bap_bitvector.ml index ada3b180f..650e0cb81 100644 --- a/lib/bap_types/bap_bitvector.ml +++ b/lib/bap_types/bap_bitvector.ml @@ -4,251 +4,286 @@ open Or_error open Format -(* current representation has a very big overhead, - depending on a size of a payload it is minimum five words, - For example, although Zarith stores a 32bit word on a 64 bit - machine in one word and represent it as an unboxed int, we still - take four more words on top of that, as bitvector is represented as - a pointer (1) to a boxed value that contains (3) fields, plus the - header (1), that gives us 5 words (40 bytes), to store 4 bytes of - payload. - - We have the following representation in mind, that will minimize - the overhead. We will store an extra information, attached to a - word in the word itself. Thus, bitvector will become a Z.t. We will - use several bits of the word for meta data. - To be able to store 32 bit words on a 64 bit platform we need to - leave enough space in a 63 bit word for the payload. Ideally, we - would like to have a support for an arbitrary bitwidth, but we can - limit it to 2^14=32 (2 kB), spend one bit for sign (that can be - removed later), thus we will have 48 bits for the payload. - - - small: - +-----------+------+---+ - | payload | size | s | - +-----------+------+---+ - size+15 15 14 0 - +(* The bap_bitvector module is provided as a shim for the new and much + more efficient bitvec library. The sole purpose of this library is + to provide compatibility with the legacy interface where bitvector + was bearing its width and signedness property, and operations on + bitvectors were defined by its runtime representation. - Given this scheme, all values smaller than 0x100_000_000_0000 will - have the same representation as OCaml int. 
+ Since the shift to the new semantics representation, we no longer + need to store widths and signedness inside the bitvector, as those + properties are now totally defined by the typing context. - The performance overhead is minimal, especially since no - allocations are done anymore. + The implementation of the shim uses Bitv.t as the representation, + however we borrow the least 15 bits for storing bitwidth and the + signedness flag. This is basically the same representation as we + had for BAP 1.x, except that we used Z.t directly, and now we are + using Bitvec.t instead (which is also Z.t underneath the hood). - Speaking of the sign. I would propose to remove it from the - bitvector, as sign is not a property of a bitvector, it is its - interpretation. - Removing the sign will get us extra memory and CPU efficiency. - *) + word format: + +-----------+------+---+ + | payload | size | s | + +-----------+------+---+ + size+15 15 14 0 +*) type endian = LittleEndian | BigEndian - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] -module Bignum = struct - module Repr : Stringable with type t = Z.t = struct - type t = Z.t - let to_string = Z.to_bits - let of_string = Z.of_bits - end - include Z - include Binable.Of_stringable(Repr) -end -type bignum = Bignum.t +module Packed : sig + type t [@@deriving bin_io, sexp] -module type Compare = sig - val compare: int -> int -> int -end + val create : ?signed:bool -> Z.t -> int -> t + val bitwidth : t -> int + val modulus : t -> Bitvec.modulus + val is_signed : t -> bool + val signed : t -> t + val unsigned : t -> t + val payload : t -> Bitvec.t -module Size_poly = struct - let compare x y = Int.compare x y -end + val hash : t -> int -module Size_mono = struct - let compare x y = - if Int.(x <> y) then failwith "Non monomoprhic compare" else 0 -end - -(** internal representation *) -module Make(Size : Compare) = struct - type t = Bignum.t [@@deriving bin_io] - - let module_name = Some 
"Bap.Std.Bitvector" - - let version = "2.0.0" + val lift1 : t -> (Bitvec.t -> Bitvec.t Bitvec.m) -> t + val lift2 : t -> t -> (Bitvec.t -> Bitvec.t -> Bitvec.t Bitvec.m) -> t + val lift3 : t -> t -> t -> (Bitvec.t -> Bitvec.t -> Bitvec.t -> Bitvec.t Bitvec.m) -> t +end = struct + type packed = {packed : Z.t} [@@unboxed] + type meta = {meta : Z.t} [@@unboxed] + type data = {data : Z.t} [@@unboxed] + type t = packed let metasize = 15 let lenoff = 1 let lensize = metasize - 1 let maxlen = 1 lsl lensize + let metamask = Z.(one lsl metasize - one) + + let meta {packed} = {meta=Z.(packed land metamask)} [@@inline] + let bitwidth x = + let {meta} = meta x in + Z.(to_int (meta asr 1)) [@@inline] + let modulus x = Bitvec.modulus @@ bitwidth x [@@inline] + + let data {packed} = {data=Z.(packed asr metasize)} [@@inline] + let hash x = + let {data} = data x in + Z.hash data + let payload x = + let m = modulus x in + let z = data x in + Bitvec.(bigint z.data mod m) + [@@inline] + + let is_signed {packed=x} = Z.(is_odd x) [@@inline] + + let unsigned ({packed} as x) = + if is_signed x + then {packed=Z.((packed asr 1) lsl 1)} + else x - let meta x = Z.extract x 0 (metasize - 1) - let is_signed = Z.is_odd - let bitwidth x = Z.extract x lenoff lensize |> Z.to_int - let z x = + let mk_signed {packed=x} = {packed=Z.(x lor one)} [@@inline] + let signed x = mk_signed x [@@inline] + + let pack x w = + let meta = Z.of_int (w lsl 1) in + {packed=Z.(x lsl metasize lor meta)} + [@@inline] + + let create ?(signed=false) x w = + let m = Bitvec.modulus w in + let x = Bitvec.(bigint x mod m) in + let r = pack (Bitvec.to_bigint x) w in + if signed then mk_signed r else r + [@@inline] + + let lift1 x f = let w = bitwidth x in - if is_signed x - then Z.signed_extract x metasize w - else Z.extract x metasize w + let x = payload x in + pack Bitvec.(to_bigint (f x mod modulus w)) w + [@@inline] - let signed x = Z.(x lor one) + let lift2 x y f = + let w = bitwidth x in + let x = payload x and y = 
payload y in + pack Bitvec.(to_bigint (f x y mod modulus w)) w + [@@inline] - let with_z x v = + let lift3 x y z f = let w = bitwidth x in - let v = Z.(extract v 0 w lsl metasize) in - Z.(v lor meta x) - - let create z w = - if w > maxlen - then invalid_argf - "Bitvector overflow: maximum allowed with is %d bits" - maxlen (); - if w <= 0 - then invalid_argf - "A nonpositive width is specified (%s,%d)" - (Z.to_string z) w (); - let meta = Z.(of_int w lsl 1) in - let z = Z.(extract z 0 w lsl metasize) in - Z.(z lor meta) - - let unsigned x = create (z x) (bitwidth x) - let hash x = Z.hash (z x) - let bits_of_z x = Z.to_bits (z x) - let unop op t = op (z t) - let binop op t1 t2 = op (z t1) (z t2) - let lift1 op t = create (unop op t) (bitwidth t) - let lift2 op t1 t2 = create (binop op t1 t2) (bitwidth t1) - - let lift2_triple op t1 t2 : t * t * t = - let (a, b, c) = binop op t1 t2 in - let w = bitwidth t1 in - create a w, create b w, create c w - - let pp_generic - ?(case:[`lower|`upper]=`upper) - ?(prefix:[`auto|`base|`none|`this of string]=`auto) - ?(suffix:[`full|`none|`size]=`none) - ?(format:[`hex|`dec|`oct|`bin]=`hex) ppf x = - let width = bitwidth x in - let is_signed = is_signed x in - let is_negative = Z.compare (z x) Z.zero < 0 in - let x = Z.abs (z x) in - let word = Z.of_int in - let int = Z.to_int in - let base = match format with - | `dec -> word 10 - | `hex -> word 0x10 - | `oct -> word 0o10 - | `bin -> word 0b10 in - let pp_prefix ppf = match format with - | `dec -> () - | `hex -> fprintf ppf "0x" - | `oct -> fprintf ppf "0o" - | `bin -> fprintf ppf "0b" in - if is_negative then fprintf ppf "-"; - let () = match prefix with - | `none -> () - | `this x -> fprintf ppf "%s" x - | `base -> pp_prefix ppf - | `auto -> - if Z.compare x (Z.min (word 10) base) >= 0 - then pp_prefix ppf in - let fmt = format_of_string @@ match format, case with - | `hex,`upper -> "%X" - | `hex,`lower -> "%x" - | _ -> "%d" in - let rec print x = - let d = int Z.(x mod base) in - 
if x >= base - then print Z.(x / base); - fprintf ppf fmt d in - print x; - match suffix with - | `full -> fprintf ppf ":%d%c" width (if is_signed then 's' else 'u') - | `size -> fprintf ppf ":%d" width + let x = payload x and y = payload y and z = payload z in + pack Bitvec.(to_bigint (f x y z mod modulus w)) w + [@@inline] + + module Stringable = struct + type t = packed + let to_string {packed} = Z.to_bits packed + let of_string s = {packed = Z.of_bits s} + end + + include Binable.Of_stringable(Stringable) + include Sexpable.Of_stringable(Stringable) +end + +let pack = Packed.create +let payload = Packed.payload +let lift1 = Packed.lift1 +let lift2 = Packed.lift2 +let lift3 = Packed.lift3 + + +type t = Packed.t [@@deriving bin_io] + +let create x w = Packed.create (Bitvec.to_bigint x) w [@@inline] +let to_bitvec x = Packed.payload x [@@inline] +let unsigned x = x [@@inline] +let signed x = Packed.signed x [@@inline] +let hash x = Packed.hash x [@@inline] +let bits_of_z x = Bitvec.to_binary (Packed.payload x) +let unop op t = Packed.lift1 t op [@@inline] +let binop op t1 t2 = Packed.lift2 t1 t2 op [@@inline] +let bitwidth x = Packed.bitwidth x [@@inline] +let is_signed x = Packed.is_signed x [@@inline] + +let pp_generic + ?(case:[`lower|`upper]=`upper) + ?(prefix:[`auto|`base|`none|`this of string]=`auto) + ?(suffix:[`full|`none|`size]=`none) + ?(format:[`hex|`dec|`oct|`bin]=`hex) ppf x = + let width = bitwidth x in + let m = Bitvec.modulus width in + let is_signed = is_signed x in + let x = Packed.payload x in + let is_negative = is_signed && Bitvec.(msb x mod m) in + let x = if is_negative then Bitvec.(abs x mod m) else x in + let x = Bitvec.to_bigint x in + let word x = Z.of_int x in + let int x = Z.to_int x in + let base = match format with + | `dec -> word 10 + | `hex -> word 0x10 + | `oct -> word 0o10 + | `bin -> word 0b10 in + let pp_prefix ppf = match format with + | `dec -> () + | `hex -> fprintf ppf "0x" + | `oct -> fprintf ppf "0o" + | `bin -> fprintf 
ppf "0b" in + if is_negative then fprintf ppf "-"; + let () = match prefix with | `none -> () + | `this x -> fprintf ppf "%s" x + | `base -> pp_prefix ppf + | `auto -> + if Bitvec_order.(x >= (min (word 10) base)) + then pp_prefix ppf in + let fmt = format_of_string @@ match format, case with + | `hex,`upper -> "%X" + | `hex,`lower -> "%x" + | _ -> "%d" in + let rec print x = + let d = int Z.(x mod base) in + if Z.(x >= base) + then print Z.(x / base); + fprintf ppf fmt d in + print x; + match suffix with + | `full -> fprintf ppf ":%d%c" width (if is_signed then 's' else 'u') + | `size -> fprintf ppf ":%d" width + | `none -> () + +let pp_full ppf = pp_generic ~suffix:`full ppf +let pp = pp_full + +let string_of_word x = asprintf "%a" pp_full x + +let of_suffixed stem suffix = + let z = Z.of_string stem in + let sl = String.length suffix in + if sl = 0 + then invalid_arg "Bitvector.of_string: an empty suffix"; + let chop x = String.subo ~len:(sl - 1) x in + match suffix.[sl-1] with + | 's' -> pack ~signed:true z (Int.of_string (chop suffix)) + | 'u' -> pack z (Int.of_string (chop suffix)) + | x when Char.is_digit x -> pack z (Int.of_string suffix) + | _ -> invalid_arg "Bitvector.of_string: invalid prefix format" + +let word_of_string = function + | "false" -> pack Z.zero 1 + | "true" -> pack Z.one 1 + | s -> match String.split ~on:':' s with + | [z; n] -> of_suffixed z n + | _ -> failwithf "Bitvector.of_string: '%s'" s () + +let pp_hex ppf = pp_generic ppf +let pp_dec ppf = pp_generic ~format:`dec ppf +let pp_oct ppf = pp_generic ~format:`oct ppf +let pp_bin ppf = pp_generic ~format:`bin ppf + +let pp_hex_full ppf = pp_generic ~suffix:`full ppf +let pp_dec_full ppf = pp_generic ~format:`dec ~suffix:`full ppf +let pp_oct_full ppf = pp_generic ~format:`oct ~suffix:`full ppf +let pp_bin_full ppf = pp_generic ~format:`bin ~suffix:`full ppf - let compare l r = - let s = Size.compare (bitwidth l) (bitwidth r) in - if s <> 0 then s - else match is_signed l, is_signed r 
with - | true,true | false,false -> Bignum.compare (z l) (z r) - | true,false -> Bignum.compare (z l) (z (signed r)) - | false,true -> Bignum.compare (z (signed l)) (z r) - - let pp_full ppf = pp_generic ~suffix:`full ppf - let pp = pp_full - - let to_string x = - let z = z x in - match bitwidth x with - | 1 -> if Z.equal z Z.zero then "false" else "true" - | n -> asprintf "%a" pp_full x - - let of_suffixed stem suffix = - let z = Bignum.of_string stem in - let sl = String.length suffix in - if sl = 0 - then invalid_arg "Bitvector.of_string: an empty suffix"; - let chop x = String.subo ~len:(sl - 1) x in - match suffix.[sl-1] with - | 's' -> create z (Int.of_string (chop suffix)) |> signed - | 'u' -> create z (Int.of_string (chop suffix)) - | x when Char.is_digit x -> create z (Int.of_string suffix) - | _ -> invalid_arg "Bitvector.of_string: invalid prefix format" - - let of_string = function - | "false" -> create Bignum.zero 1 - | "true" -> create Bignum.one 1 - | s -> match String.split ~on:':' s with - | [z; n] -> of_suffixed z n - | _ -> failwithf "Bitvector.of_string: '%s'" s () - - let with_validation t ~f = Or_error.map ~f (Validate.result t) - - let extract ?hi ?(lo=0) t = - let n = bitwidth t in - let z = z t in - let hi = Option.value ~default:(n-1) hi in - let len = hi-lo+1 in - if len <= 0 - then failwithf "Bitvector.extract: len %d is negative" len (); - create (Z.extract z lo len) len - - let sexp_of_t t = Sexp.Atom (to_string t) +let string_of_value ?(hex=true) x = + if hex + then asprintf "%a" (fun p -> pp_generic ~prefix:`none ~case:`lower p) x + else asprintf "%a" (fun p -> pp_generic ~format:`dec p) x + +let pp = pp_hex +let to_string = string_of_word +let of_string = word_of_string + +module Sexp_hum = struct + type t = Packed.t + let sexp_of_t x = Sexp.Atom (to_string x) let t_of_sexp = function - | Sexp.Atom s -> of_string s - | _ -> invalid_argf - "Bitvector.of_sexp: expected an atom got a list" () + | Sexp.Atom x -> of_string x + | _ -> 
invalid_arg "Bitvector.t_of_sexp: expects an atom" end -(* About monomorphic comparison. - - With monomorphic size comparison functions [hash] and [compare] - will not be coherent with each other, since we can't prohibit - someone to compare hashes from bitvectors with different sizes. For - example, it means, that we can't really guarantee that in a Table - all keys are bitvectors with the same size. So, as a consequence we - can't make bitvector a real value implementing [Identifiable] - interface. Since, monomorphic behaviour is rather specific and - unintuitive we will move it in a separate submodule and use size - polymorphic comparison by default. -*) -module T = Make(Size_poly) -include T +include (Sexp_hum : Sexpable.S with type t := Packed.t) + +let msb x = Bitvec.(msb (Packed.payload x) mod Packed.modulus x) +let lsb x = Bitvec.(lsb (Packed.payload x) mod Packed.modulus x) + +type packed = Packed.t [@@deriving bin_io] +let sexp_of_packed = Sexp_hum.sexp_of_t +let packed_of_sexp = Sexp_hum.t_of_sexp + +let compare_mono x y = + if is_signed x || is_signed y then + let x_is_neg = msb x and y_is_neg = msb y in + match x_is_neg, y_is_neg with + | true,false -> -1 + | false,true -> 1 + | _ -> Bitvec.compare (payload x) (payload y) + else Bitvec.compare (payload x) (payload y) + +let with_validation t ~f = Or_error.map ~f (Validate.result t) + +let extract ?hi ?(lo=0) t = + let n = bitwidth t in + let z = Bitvec.to_bigint (payload t) in + let hi = Option.value ~default:(n-1) hi in + let len = hi-lo+1 in + if len <= 0 + then failwithf "Bitvector.extract: len %d is negative" len (); + if is_signed t && msb t + then pack Z.((minus_one lsl n) lor Z.extract z lo len) len + else pack (Z.extract z lo len) len module Cons = struct - let b0 = create (Bignum.of_int 0) 1 - let b1 = create (Bignum.of_int 1) 1 + let b0 = pack Z.zero 1 + let b1 = pack Z.one 1 let of_bool v = if v then b1 else b0 - let of_int32 ?(width=32) n = create (Bignum.of_int32 n) width - let of_int64 
?(width=64) n = create (Bignum.of_int64 n) width - let of_int ~width v = create (Bignum.of_int v) width + let of_int32 ?(width=32) n = pack (Z.of_int32 n) width + let of_int64 ?(width=64) n = pack (Z.of_int64 n) width + let of_int ~width v = pack (Z.of_int v) width let ones n = of_int (-1) ~width:n let zeros n = of_int (0) ~width:n let zero n = of_int 0 ~width:n @@ -258,32 +293,40 @@ include Cons let safe f t = try_with (fun () -> f t) -let to_int_exn = unop Bignum.to_int -let to_int32_exn = unop Bignum.to_int32 -let to_int64_exn = unop Bignum.to_int64 +let to_int_exn x = Bitvec.to_int (payload x) +let to_int32_exn x = Bitvec.to_int32 (payload x) +let to_int64_exn x = Bitvec.to_int64 (payload x) let to_int = safe to_int_exn let to_int32 = safe to_int32_exn let to_int64 = safe to_int64_exn - let of_binary ?width endian num = let num = match endian with | LittleEndian -> num | BigEndian -> String.rev num in let w = Option.value width ~default:(String.length num * 8) in - create (Bignum.of_bits num) w + pack (Z.of_bits num) w -let nsucc t n = with_z t Bignum.(z t + of_int n) -let npred t n = with_z t Bignum.(z t - of_int n) +let nsucc t n = Packed.lift1 t @@ fun t -> Bitvec.(nsucc t n) +let npred t n = Packed.lift1 t @@ fun t -> Bitvec.(npred t n) let (++) t n = nsucc t n let (--) t n = npred t n let succ n = n ++ 1 let pred n = n -- 1 -let gcd_exn = lift2 Bignum.gcd -let lcm_exn = lift2 Bignum.lcm -let gcdext_exn = lift2_triple Bignum.gcdext +let (%:) x w = pack (Bitvec.to_bigint x) w + +let gcd_exn x y = Packed.lift2 x y Bitvec.gcd +let lcm_exn x y = Packed.lift2 x y Bitvec.lcm +let gcdext_exn x y = + let w = bitwidth x in + let m = Bitvec.modulus w in + let x = payload x and y = payload y in + let (g,a,b) = Bitvec.(gcdext x y mod m) in + g %: w, + a %: w, + b %: w let gcd a b = Or_error.try_with (fun () -> gcd_exn a b) @@ -293,46 +336,47 @@ let gcdext a b = Or_error.try_with (fun () -> gcdext_exn a b) let concat x y = - let w = bitwidth x + bitwidth y in - let x = 
Bignum.(z x lsl bitwidth y) in - let z = Bignum.(x lor z y) in - create z w + let w1 = bitwidth x and w2 = bitwidth y in + let x = payload x and y = payload y in + Bitvec.append w1 w2 x y %: (w1+w2) let (@.) = concat module Unsafe = struct module Base = struct - type t = T.t - let one = create Z.one 1 - let zero = create Z.zero 1 - let succ = lift1 Bignum.succ - let pred = lift1 Bignum.pred - let abs = lift1 Bignum.abs - let neg = lift1 Bignum.neg - let lnot = lift1 Bignum.lognot - let logand = lift2 Bignum.logand - let logor = lift2 Bignum.logor - let logxor = lift2 Bignum.logxor - let add = lift2 Bignum.add - let sub = lift2 Bignum.sub - let mul = lift2 Bignum.mul - let sdiv = lift2 Bignum.div - let udiv = lift2 Bignum.ediv - let srem = lift2 Bignum.rem - let urem = lift2 Bignum.erem - - let sign_disp ~signed ~unsigned x y = - let op = if is_signed x || is_signed y then signed else unsigned in - op x y - - let div = sign_disp ~signed:sdiv ~unsigned:udiv - let rem = sign_disp ~signed:srem ~unsigned:urem - let modulo = rem - - let shift dir x n = create (dir (z x) (Z.to_int (z n))) (bitwidth x) - let lshift = shift Bignum.shift_left - let rshift = shift Bignum.shift_right - let arshift x y = shift Bignum.shift_right (signed x) y + type t = Packed.t + let one = Bitvec.one %: 1 + let zero = Bitvec.zero %: 1 + let succ = succ + let pred = pred + let abs x = lift1 x Bitvec.abs [@@inline] + let neg x = lift1 x Bitvec.neg [@@inline] + let lnot x = lift1 x Bitvec.lnot [@@inline] + let logand x y = lift2 x y Bitvec.logand [@@inline] + let logor x y = lift2 x y Bitvec.logor [@@inline] + let logxor x y = lift2 x y Bitvec.logxor [@@inline] + let add x y = lift2 x y Bitvec.add [@@inline] + let sub x y = lift2 x y Bitvec.sub [@@inline] + let mul x y = lift2 x y Bitvec.mul [@@inline] + let sdiv x y = lift2 x y Bitvec.sdiv [@@inline] + let udiv x y = lift2 x y Bitvec.div [@@inline] + let srem x y = lift2 x y Bitvec.srem [@@inline] + let urem x y = lift2 x y Bitvec.rem [@@inline] 
+ let lshift x y = lift2 x y Bitvec.lshift [@@inline] + let rshift x y = lift2 x y Bitvec.rshift [@@inline] + let arshift x y = lift2 x y Bitvec.arshift [@@inline] + + let div x y = match is_signed x, is_signed y with + | true,_|_,true -> sdiv x y + | _ -> udiv x y + [@@inline] + + let rem x y = match is_signed x, is_signed y with + | true,_|_,true -> srem x y + | _ -> urem x y + [@@inline] + + let modulo x y = rem x y [@@inline] end include Base include (Bap_integer.Make(Base) : Bap_integer.S with type t := t) @@ -340,25 +384,29 @@ end module Safe = struct include Or_error.Monad_infix + let (!$) v = Ok v - type m = t Or_error.t - let (!$) v = Ok v + let badwidth m x = + Or_error.errorf "Word - wrong width, expects %d got %d" m x - let validate_equal (n,m) : Validate.t = - if m = n then Validate.pass - else Validate.failf "expected width %d, but got %d" m n + let lift m x = + let w = bitwidth x in + if m = w then Ok x else badwidth m w - let lift m t : m = - validate_equal (m, bitwidth t) |> Validate.result >>| - fun () -> t + let lift1 op x = match x with + | Ok x -> Ok (op x) + | Error _ as e -> e - let lift1 op x : m = x >>| lift1 op + let lift2 op x y = match x, y with + | Ok x, Ok y -> + let w1 = bitwidth x and w2 = bitwidth y in + if w1 = w2 then Ok (op x y) else badwidth w1 w2 + | (Error _ as e),_ | _, (Error _ as e) -> e - let lift2 op (x : m) (y : m) : m = - x >>= fun x -> y >>= fun y -> - let v = validate_equal (bitwidth x, bitwidth y) in - Validate.result v >>| fun () -> op x y + let lift2h op x y = match x, y with + | Ok x, Ok y -> Ok (op x y) + | (Error _ as e),_ | _, (Error _ as e) -> e let int = lift let i1 = lift 1 @@ -372,16 +420,14 @@ module Safe = struct | Word_size.W32 -> i32 module Base = struct - type t = m - let one = i1 (one 1) - let zero = i1 (zero 1) - let succ = lift1 Bignum.succ - let pred = lift1 Bignum.pred - let abs = lift1 Bignum.abs - let neg = lift1 Bignum.neg - - let lnot = lift1 Bignum.lognot - + type nonrec t = t Or_error.t + 
let one = i1 Unsafe.one + let zero = i1 Unsafe.zero + let succ = lift1 Unsafe.succ + let pred = lift1 Unsafe.pred + let abs = lift1 Unsafe.abs + let neg = lift1 Unsafe.neg + let lnot = lift1 Unsafe.lnot let logand = lift2 Unsafe.logand let logor = lift2 Unsafe.logor let logxor = lift2 Unsafe.logxor @@ -392,34 +438,20 @@ module Safe = struct let udiv = lift2 Unsafe.udiv let srem = lift2 Unsafe.rem let urem = lift2 Unsafe.urem - - let sign_disp ~signed ~unsigned x y = - x >>= fun x -> y >>= fun y -> - let op = if is_signed x || is_signed y then signed else unsigned in - op !$x !$y - - let div = sign_disp ~signed:sdiv ~unsigned:udiv - let rem = sign_disp ~signed:srem ~unsigned:urem - let modulo = rem - - let shift dir (x : m) (y : m) : m = - x >>= fun x -> y >>= fun y -> - if unop Bignum.fits_int y - then Ok (dir x y) - else Or_error.errorf - "cannot perform shift, because rhs doesn't fit int: %s" @@ - to_string y - - let lshift = shift Unsafe.lshift - let rshift = shift Unsafe.rshift - let arshift = shift Unsafe.arshift + let sdiv = lift2 Unsafe.sdiv + let div = lift2 Unsafe.div + let rem = lift2 Unsafe.rem + let lshift = lift2h Unsafe.lshift + let rshift = lift2h Unsafe.rshift + let arshift = lift2h Unsafe.arshift + let modulo = rem end include Bap_integer.Make(Base) end module Int_exn = struct module Base = struct - type t = T.t + type t = Packed.t let one = one 1 let zero = zero 1 @@ -453,11 +485,11 @@ let extract_exn = extract let extract ?hi ?lo x = Or_error.try_with (fun () -> extract_exn ?hi ?lo x) -let is_zero = unop Bignum.(equal zero) -let is_one = unop Bignum.(equal one) -let is_positive = unop Bignum.(fun z -> gt z zero) +let is_zero x = Bitvec.compare (payload x) Bitvec.zero = 0 +let is_one x = Bitvec.compare (payload x) Bitvec.one = 0 +let is_negative x = is_signed x && msb x +let is_positive x = not (is_zero x) && not (is_negative x) let is_non_positive = Fn.non is_positive -let is_negative = unop Bignum.(fun z -> lt z zero) let is_non_negative = 
Fn.non is_negative @@ -466,13 +498,13 @@ let validate check msg x = else Validate.fails msg x sexp_of_t let validate_positive = - validate is_positive "should be positive" + validate is_positive "expects a positive number" let validate_non_positive = - validate is_non_positive "should be non positive" + validate is_non_positive "expects a non positive number" let validate_negative = - validate is_negative "should be negative" + validate is_negative "expects a negative number" let validate_non_negative = - validate is_non_negative "should be non negative" + validate is_non_negative "expects a non negative number" let enum_chars t endian = let open Sequence in @@ -511,7 +543,14 @@ let bits_of_byte byte = let enum_bits bv endian = enum_chars bv endian |> Sequence.map ~f:bits_of_byte |> Sequence.concat -module Mono = Comparable.Make(Make(Size_mono)) +module Mono = Comparable.Make(struct + type t = packed [@@deriving sexp] + let compare x y = + if phys_equal x y then 0 + else match Int.compare (bitwidth x) (bitwidth y) with + | 0 -> compare_mono x y + | _ -> failwith "Non monomorphic comparison" + end) module Trie = struct module Common = struct @@ -557,28 +596,24 @@ module Trie = struct end include Or_error.Monad_infix -include Regular.Make(T) +include Regular.Make(struct + type t = packed [@@deriving bin_io, sexp] + let compare x y = + if phys_equal x y then 0 + else match Int.compare (bitwidth x) (bitwidth y) with + | 0 -> compare_mono x y + | r -> r + [@@inline] + let version = "2.0.0" + let module_name = Some "Bap.Std.Word" + let pp ppf = pp_generic ppf + let hash = Packed.hash + end) module Int_err = Safe include (Unsafe : Bap_integer.S with type t := t) let one = Cons.one let zero = Cons.zero -let pp_hex ppf = pp_generic ppf -let pp_dec ppf = pp_generic ~format:`dec ppf -let pp_oct ppf = pp_generic ~format:`oct ppf -let pp_bin ppf = pp_generic ~format:`bin ppf - -let pp_hex_full ppf = pp_generic ~suffix:`full ppf -let pp_dec_full ppf = pp_generic ~format:`dec 
~suffix:`full ppf -let pp_oct_full ppf = pp_generic ~format:`oct ~suffix:`full ppf -let pp_bin_full ppf = pp_generic ~format:`bin ~suffix:`full ppf - -let string_of_value ?(hex=true) x = - if hex - then asprintf "%a" (fun p -> pp_generic ~prefix:`none ~case:`lower p) x - else asprintf "%a" (fun p -> pp_generic ~format:`dec p) x - - (* old representation for backward compatibility. *) module V1 = struct module Bignum = struct @@ -602,39 +637,41 @@ end (* stable serialization protocol *) module Stable = struct module V1 = struct - type t = bignum + type t = Packed.t let compare = compare let of_legacy {V1.z; w; signed=s} = - let x = create z w in - if s then signed x else x + let x = pack z w in + if s then Packed.signed x else x let to_legacy x = V1.{ - z = z x; + z = Bitvec.to_bigint (payload x); w = bitwidth x; signed = is_signed x; } include Binable.Of_binable(V1)(struct - type t = bignum + type t = Packed.t let to_binable = to_legacy let of_binable = of_legacy end) include Sexpable.Of_sexpable(V1)(struct - type t = bignum + type t = Packed.t let to_sexpable = to_legacy let of_sexpable = of_legacy end) end module V2 = struct - type nonrec t = t [@@deriving bin_io, compare, sexp] + type t = Packed.t [@@deriving bin_io, sexp] + let compare = compare end end -let pp = pp_hex +let to_string = string_of_word +let of_string = word_of_string let () = add_reader ~desc:"Janestreet Binary Protocol" ~ver:"1.0.0" "bin" @@ -645,6 +682,15 @@ let () = (Data.sexp_reader (module Stable.V1)); add_writer ~desc:"Janestreet Sexp Protocol" ~ver:"1.0.0" "sexp" (Data.sexp_writer (module Stable.V1)); + add_reader ~desc:"Janestreet Binary Protocol" ~ver:"2.0.0" "bin" + (Data.bin_reader (module Packed)); + add_writer ~desc:"Janestreet Binary Protocol" ~ver:"2.0.0" "bin" + (Data.bin_writer (module Packed)); + add_reader ~desc:"Janestreet Sexp Protocol" ~ver:"2.0.0" "sexp" + (Data.sexp_reader (module Sexp_hum)); + add_writer ~desc:"Janestreet Sexp Protocol" ~ver:"2.0.0" "sexp" + 
(Data.sexp_writer (module Sexp_hum)); + let add name desc pp = add_writer ~desc ~ver:"2.0.0" name (Data.Write.create ~pp ()) in add "hex" "Hexadecimal without a suffix" pp_hex; diff --git a/lib/bap_types/bap_bitvector.mli b/lib/bap_types/bap_bitvector.mli index 9b076bf3b..e9ffe7f70 100644 --- a/lib/bap_types/bap_bitvector.mli +++ b/lib/bap_types/bap_bitvector.mli @@ -7,11 +7,15 @@ type t type endian = | LittleEndian | BigEndian - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] include Regular.S with type t := t include Bap_integer.S with type t := t module Mono : Comparable.S with type t := t + +val create : Bitvec.t -> int -> t +val to_bitvec : t -> Bitvec.t + val of_string : string -> t val of_bool : bool -> t val of_int : width:int -> int -> t @@ -33,6 +37,8 @@ val signed : t -> t val unsigned : t -> t val is_zero : t -> bool val is_one : t -> bool +val msb : t -> bool +val lsb : t -> bool val bitwidth : t -> int val extract : ?hi:int -> ?lo:int -> t -> t Or_error.t val extract_exn : ?hi:int -> ?lo:int -> t -> t diff --git a/lib/bap_types/bap_common.ml b/lib/bap_types/bap_common.ml index 67252a6e4..2ac4a4ea7 100644 --- a/lib/bap_types/bap_common.ml +++ b/lib/bap_types/bap_common.ml @@ -21,7 +21,7 @@ module type Trie = Bap_trie_intf.S type endian = Bitvector.endian = LittleEndian | BigEndian - [@@deriving sexp, bin_io, compare] +[@@deriving sexp, bin_io, compare] module Size = struct @@ -35,23 +35,25 @@ module Size = struct | `r256 ] [@@deriving bin_io, compare, sexp, variants] - type 'a p = - 'a constraint 'a = [< all] [@@deriving bin_io, compare, sexp] + type 'a p = 'a constraint 'a = [< all] + [@@deriving bin_io, compare, sexp] type t = all p - [@@deriving bin_io, compare, sexp] + [@@deriving bin_io, compare, sexp] + + end (** size of operand *) type size = Size.t - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] (** size of address *) type addr_size = [ `r32 | `r64 ] Size.p - [@@deriving bin_io, compare, sexp] 
+[@@deriving bin_io, compare, sexp] type nat1 = int - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] (** The IR type of a BIL expression *) module Type = struct @@ -60,11 +62,12 @@ module Type = struct | Imm of nat1 (** [Mem (a,t)]memory with a specified addr_size *) | Mem of addr_size * size - [@@deriving bin_io, compare, sexp, variants] + | Unk + [@@deriving bin_io, compare, sexp, variants] end type typ = Type.t - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] (** Supported architectures *) @@ -107,14 +110,14 @@ module Arch = struct | `aarch64 | `aarch64_be ] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type ppc = [ | `ppc | `ppc64 | `ppc64le ] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type mips = [ | `mips @@ -122,31 +125,31 @@ module Arch = struct | `mips64 | `mips64el ] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type sparc = [ | `sparc | `sparcv9 ] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type nvptx = [ | `nvptx | `nvptx64 ] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type hexagon = [`hexagon] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type r600 = [`r600] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type systemz = [`systemz] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type xcore = [`xcore] - [@@deriving bin_io, compare, enumerate, sexp] + [@@deriving bin_io, compare, enumerate, sexp] type t = [ | aarch64 @@ -171,10 +174,10 @@ end *) type arch = Arch.t - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] type word = Bap_bitvector.t - [@@deriving bin_io, compare, sexp] +[@@deriving 
bin_io, compare, sexp] type addr = Bap_bitvector.t - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] diff --git a/lib/bap_types/bap_eval.ml b/lib/bap_types/bap_eval.ml index 6643d3628..c02c3e743 100644 --- a/lib/bap_types/bap_eval.ml +++ b/lib/bap_types/bap_eval.ml @@ -140,7 +140,12 @@ module Make2(State : Monad.S2) = struct with exn -> None method eval_let var u body = - self#eval_exp (Exp.Let (var,u,body)) + self#eval_exp u >>= fun u -> + self#lookup var >>= fun w -> + self#update var u >>= fun () -> + self#eval_exp body >>= fun r -> + self#update var w >>= fun () -> + State.return r method eval_unknown _ _ = self#bot diff --git a/lib/bap_types/bap_exp.ml b/lib/bap_types/bap_exp.ml index 99e42f24c..2ebdd086c 100644 --- a/lib/bap_types/bap_exp.ml +++ b/lib/bap_types/bap_exp.ml @@ -1,5 +1,7 @@ open Core_kernel open Regular.Std +open Bap_knowledge +open Bap_core_theory open Bap_common open Format open Bap_bil @@ -11,6 +13,7 @@ module Size = Bap_size type binop = exp -> exp -> exp type unop = exp -> exp + module PP = struct open Bap_bil @@ -56,21 +59,25 @@ module PP = struct type precendence = int let op_prec op = Binop.(match op with - | TIMES | DIVIDE | SDIVIDE | MOD| SMOD -> 8 - | PLUS | MINUS -> 7 - | LSHIFT | RSHIFT | ARSHIFT -> 6 - | LT|LE|SLT|SLE -> 5 - | EQ|NEQ -> 4 - | AND -> 3 - | XOR -> 2 - | OR -> 1) + | TIMES | DIVIDE | SDIVIDE | MOD| SMOD -> 8 + | PLUS | MINUS -> 7 + | LSHIFT | RSHIFT | ARSHIFT -> 6 + | LT|LE|SLT|SLE -> 5 + | EQ|NEQ -> 4 + | AND -> 3 + | XOR -> 2 + | OR -> 1) let prec x = Exp.(match x with - | Var _ | Int _ | Unknown _ -> 10 - | Load _ | Cast _ | Extract _ -> 10 - | UnOp _ -> 9 - | BinOp (op,x,y) -> op_prec op - | Store _ | Let _ | Ite _ | Concat _ -> 0) + | Var _ | Int _ | Unknown _ -> 10 + | Load _ | Cast _ | Extract _ -> 10 + | UnOp _ -> 9 + | BinOp (op,_,_) -> op_prec op + | Store _ | Let _ | Ite _ | Concat _ -> 0) + + let msb x = + let m = Bitvec.modulus (Word.bitwidth x) in + Bitvec.(msb (Word.to_bitvec x) 
mod m) let rec pp fmt exp = let open Bap_bil.Exp in @@ -88,7 +95,7 @@ module PP = struct pr "%a[%a]" pp mem pp idx | Load (mem, idx, edn, s) -> pr "%a[%a, %a]:%a" pp mem pp idx pp_edn edn Bap_size.pp s - | Store (mem, idx, exp, edn, `r8) -> + | Store (mem, idx, exp, _, `r8) -> pr "@[<4>%a@;with [%a] <- %a@]" pp mem pp idx pp exp | Store (mem, idx, exp, edn, s) -> @@ -106,6 +113,8 @@ module PP = struct pr ("%a" ^^ pfmt p e) pp_unop Unop.NOT pp e | BinOp (EQ,Int x, e) as p when is_b0 x -> pr ("%a" ^^ pfmt p e) pp_unop Unop.NOT pp e + | BinOp (PLUS,le,(Int x as re)) as p when msb x -> + pr (pfmt p le ^^ " - " ^^ pfmt p re) pp le Word.pp (Word.neg x) | BinOp (op, le, re) as p -> pr (pfmt p le ^^ " %a " ^^ pfmt p re) pp le pp_binop op pp re | UnOp (op, exp) as p -> @@ -233,6 +242,21 @@ module Infix = struct end + +let equal x y = phys_equal x y || compare x y = 0 +let to_string = Format.asprintf "%a" PP.pp +let domain = Knowledge.Domain.flat "exp" ~equal + ~empty:(Unknown ("empty",Unk)) + ~inspect:(fun exp -> Sexp.Atom (to_string exp)) + +let persistent = Knowledge.Persistent.of_binable (module struct + type t = Bap_bil.exp [@@deriving bin_io] + end) + +let slot = Knowledge.Class.property ~package:"bap.std" + ~persistent Theory.Value.cls "exp" domain + ~desc:"semantics of expressions in BIL" + include Regular.Make(struct type t = Bap_bil.exp [@@deriving bin_io, compare, sexp] let hash = Hashtbl.hash diff --git a/lib/bap_types/bap_exp.mli b/lib/bap_types/bap_exp.mli index e24435ab3..714be52d7 100644 --- a/lib/bap_types/bap_exp.mli +++ b/lib/bap_types/bap_exp.mli @@ -1,6 +1,8 @@ (** Extends [exp] interface. 
*) open Core_kernel +open Bap_core_theory open Regular.Std +open Bap_knowledge open Bap_common open Bap_bil open Format @@ -99,3 +101,5 @@ module Infix : sig (** [a ^ b] contatenate [a] and [b] *) val ( ^ ) : exp -> exp -> exp end + +val slot : (Theory.Value.cls, exp) Knowledge.slot diff --git a/lib/bap_types/bap_helpers.ml b/lib/bap_types/bap_helpers.ml index 3725f56db..d9bfca2f4 100644 --- a/lib/bap_types/bap_helpers.ml +++ b/lib/bap_types/bap_helpers.ml @@ -90,10 +90,10 @@ class substitution x y = object(self) inherit bil_mapper as super method! map_let z ~exp ~body = if Bap_var.(z = x) - then super#map_let z ~exp:(self#map_exp exp) ~body - else super#map_let z ~exp:(self#map_exp exp) ~body: - (super#map_exp body) - + then Let (z,self#map_exp exp,body) + else super#map_let z + ~exp:(self#map_exp exp) + ~body:(self#map_exp body) method! map_var z = match super#map_var z with | Exp.Var z when Bap_var.(z = x) -> y @@ -209,16 +209,17 @@ module Type = struct and binop op x y = match op with | LSHIFT|RSHIFT|ARSHIFT -> shift x y | _ -> match unify x y with - | Type.Mem _ -> Type_error.expect_imm () + | Type.Mem _ | Type.Unk -> Type_error.expect_imm () | Type.Imm _ as t -> match op with | LT|LE|EQ|NEQ|SLT|SLE -> Type.Imm 1 | _ -> t and shift x y = match infer x, infer y with - | Type.Mem _,_ | _,Type.Mem _ -> Type_error.expect_imm () + | Type.Mem _,_ | _,Type.Mem _ + | Type.Unk,_ | _,Type.Unk -> Type_error.expect_imm () | t, Type.Imm _ -> t and load m a r = match infer m, infer a with - | Type.Imm _,_ -> Type_error.expect_mem () - | _,Type.Mem _ -> Type_error.expect_imm () + | (Type.Imm _|Unk),_ -> Type_error.expect_mem () + | _,(Type.Mem _|Unk) -> Type_error.expect_imm () | Type.Mem (s,_),Type.Imm s' -> let s = Size.in_bits s in if s = s' then Type.Imm (Size.in_bits r) @@ -237,12 +238,12 @@ module Type = struct and cast c s x = let t = Type.Imm s in match c,infer x with - | _,Type.Mem _ -> Type_error.expect_imm () + | _,(Type.Mem _|Unk) -> Type_error.expect_imm () | 
(UNSIGNED|SIGNED),_ -> t | (HIGH|LOW), Type.Imm s' -> if s' >= s then t else Type_error.wrong_cast () and extract hi lo x = match infer x with - | Type.Mem _ -> Type_error.expect_imm () + | Type.Mem _ | Unk -> Type_error.expect_imm () | Type.Imm _ -> (* we don't really need a type of x, as the extract operation can both narrow and widen. Though it is a question whether it is @@ -278,9 +279,9 @@ module Type = struct | Ok u -> Some (Type_error.bad_type ~exp:t ~got:u) | Error err -> Some err and jmp x = match infer x with - | Ok (Imm s) when Result.is_ok (Size.addr_of_int s) -> None + | Ok (Imm _) -> None | Ok (Mem _) -> Some Type_error.bad_imm - | Ok (Imm _) -> Some Type_error.bad_cast + | Ok Unk -> Some Type_error.unknown | Error err -> Some err and cond x = match infer x with | Ok (Imm 1) -> None @@ -325,7 +326,7 @@ module Eff = struct let width x = match Type.infer_exn x with | Type.Imm x -> x - | Type.Mem _ -> failwith "width is not for memory" + | _ -> failwith "expected an immediate type" (* approximates a number of non-zero bits in a bitvector. *) module Nz = struct @@ -419,7 +420,7 @@ module Eff = struct | UnOp (_,x) | Cast (_,_,x) | Extract (_,_,x) -> eff x - | Let (_,_,_) -> assert false (* must be let-normalized *) + | Let (_,x,y) -> all [eff x; eff y] and div y = match Nz.bits y with | Nz.Maybe | Nz.Empty -> raise | _ -> none @@ -432,6 +433,15 @@ module Eff = struct let raises t = Set.mem t Raises let of_list : t list -> t = Set.Poly.union_list end + +class rewriter x y = object + inherit bil_mapper as super + method! 
map_exp z = + let z = super#map_exp z in + if Bap_exp.(z = x) then y else z +end + + module Simpl = struct open Bap_bil open Binop @@ -444,7 +454,9 @@ module Simpl = struct let zero width = Int (Word.zero width) let ones width = Int (Word.ones width) let nothing _ = false - + let subst x y = + let r = new substitution x y in + r#map_exp (* requires: let-free, simplifications( constant-folding, @@ -466,27 +478,33 @@ module Simpl = struct | UnOp (op,x) -> unop op x | Var _ | Int _ | Unknown (_,_) as const -> const | Cast (t,s,x) -> cast t s x - | Let (v,x,y) -> Let (v, exp x, exp y) - | Ite (x,y,z) -> Ite (exp x, exp y, exp z) + | Let (v,x,y) -> let_ v x y + | Ite (x,y,z) -> ite_ x y z | Extract (h,l,x) -> extract h l x | Concat (x,y) -> concat x y + and ite_ c x y = match exp c, exp x, exp y with + | Int c,x,y -> if Bitvector.(c = b1) then x else y + | c,x,y -> Ite (c,x,y) + and let_ v x y = match exp x with + | Int _ | Unknown _ as r -> exp (subst v r y) + | r -> Let(v,r,exp y) and concat x y = match exp x, exp y with | Int x, Int y -> Int (Word.concat x y) | x,y -> Concat (x,y) and cast t s x = match exp x with | Int w -> Int (Apply.cast t s w) - | _ -> Cast (t,s,x) + | x -> Cast (t,s,x) and extract hi lo x = match exp x with | Int w -> Int (Bitvector.extract_exn ~hi ~lo w) | x -> Extract (hi,lo,x) and unop op x = match exp x with - | UnOp(op,Int x) -> Int (Apply.unop op x) + | Int x -> Int (Apply.unop op x) | UnOp(op',x) when op = op' -> exp x | x -> UnOp(op, x) and binop op x y = let width = match Type.infer_exn x with | Type.Imm s -> s - | Type.Mem _ -> failwith "binop" in + | _ -> failwith "binop" in let keep op x y = BinOp(op,x,y) in let int f = function Int x -> f x | _ -> false in let is0 = int is0 and is1 = int is1 and ism1 = int ism1 in @@ -507,7 +525,7 @@ module Simpl = struct | (MOD|SMOD),_,y when is1 y -> zero width | (LSHIFT|RSHIFT|ARSHIFT),x,y when is0 y -> x | (LSHIFT|RSHIFT|ARSHIFT),x,_ when is0 x -> x - | (LSHIFT|RSHIFT|ARSHIFT),x,_ when ism1 x -> 
x + | ARSHIFT,x,_ when ism1 x -> x | AND,x,y when is0 x && removable y -> x | AND,x,y when is0 y && removable x -> y | AND,x,y when ism1 x -> y @@ -564,12 +582,6 @@ let fix compare f x = let fixpoint = fix compare_bil -class rewriter x y = object - inherit bil_mapper as super - method! map_exp z = - let z = super#map_exp z in - if Bap_exp.(z = x) then y else z -end let substitute x y = (new rewriter x y)#run @@ -676,23 +688,19 @@ module Normalize = struct (* we don't need a full-fledged type inference here. requires: well-typed exp *) - let infer_addr_size exp = - let open Exp in - let rec infer = function + let infer_storage_type exp = + let rec infer : exp -> typ = function | Var v -> Var.typ v - | Store (m,_,_,_,_) -> infer m - | Ite (_,x,y) -> both x y | Unknown (_,t) -> t - | _ -> invalid_arg "type error" - and both x y = - match infer x, infer y with - | t1,t2 when Type.(t1 = t2) -> t1 + | Store (m,_,_,_,_) | Ite (_,m,_) | Let (_,_,m) -> infer m | _ -> invalid_arg "type error" in match infer exp with - | Type.Mem (s,_) -> s - | Type.Imm _ -> invalid_arg "type error" + | Type.Mem (ks,vs) -> ks,vs + | _ -> invalid_arg "type error" + let infer_addr_size x = fst (infer_storage_type x) + let infer_value_size x = snd (infer_storage_type x) let make_succ m = let int n = @@ -701,6 +709,7 @@ module Normalize = struct let sum a n = Exp.BinOp (Binop.PLUS, a,int n) in sum + (* rewrite_store_little Store(m,a,x,e,s) => Store(..(Store(Store(m,a,x[0],e,1),a+1,x[1],e,1))..,a+s-1,x[s-1],e,1) @@ -712,37 +721,36 @@ module Normalize = struct *) let expand_store m a x e s = + let vs = infer_value_size m in let (++) = make_succ m in let n = Size.in_bytes s in let nth i = if e = BigEndian then nth (n-i-1) else nth i in let rec expand i = if i >= 0 - then Exp.Store(expand (i-1),(a++i),nth i x,LittleEndian,`r8) + then Exp.Store(expand (i-1),(a++i),nth i x,LittleEndian,vs) else m in - if s = `r8 then Exp.Store (m,a,x,e,s) + if Size.equal vs s then Exp.Store (m,a,x,e,s) else expand 
(n-1) (* x[a,el]:n => x[a+n-1] @ ... @ x[a] x[a,be]:n => x[a] @ ... @ x[a+n-1] - This operation duplicates the address expression, this may break - semantics if this expression is non-generative. - Special care should be taken if the expression contains store operations, that has an effect that may interfere with the result of the load operation. *) let expand_load m a e s = + let vs = infer_value_size m in let (++) = make_succ m in let cat x y = if e = LittleEndian then Exp.Concat (y,x) else Exp.Concat (x,y) in - let load a = Exp.Load (m,a,e,`r8) in + let load a = Exp.Load (m,a,e,vs) in let rec expand a i = if i > 1 then cat (load a) (expand (a++1) (i-1)) else load a in - if s = `r8 then load a + if Size.equal vs s then load a else expand a (Size.in_bytes s) let expand_memory = map_exp @@ object @@ -758,19 +766,19 @@ module Normalize = struct expand_store mem addr x e s end - (* ensures: no-lets, one-byte-stores, one-byte-loads. + (* ensures: one-byte-stores, one-byte-loads. This is the first step of normalization. The full normalization, e.g., remove ite and hoisting storages can be only done on the BIL level. requires: - - generative-load-addr, - - generative-store-mem, - - generative-store-val, - - generative-let-value + - generative-load-addr, + - generative-store-mem, + - generative-store-val, + - generative-let-value *) - let normalize_exp x = expand_memory (reduce_let x) + let normalize_exp x = expand_memory x type assume = Assume of (exp * bool) @@ -1031,8 +1039,8 @@ module Normalize = struct (* ensures: all while conditions are free from: - - ite expressions; - - store operations. + - ite expressions; + - store operations. 
Note the latter sounds more strong then the implementation, but it is true, as in a well-typed program a conditional must has diff --git a/lib/bap_types/bap_ir.ml b/lib/bap_types/bap_ir.ml index fa3dc506d..5accceb4f 100644 --- a/lib/bap_types/bap_ir.ml +++ b/lib/bap_types/bap_ir.ml @@ -1,8 +1,13 @@ +let package = "bap.std" + open Core_kernel +open Bap_core_theory open Regular.Std open Bap_common open Bap_bil +open Bap_knowledge +module Toplevel = Bap_toplevel module Value = Bap_value module Dict = Value.Dict module Vec = Bap_vector @@ -30,70 +35,110 @@ type dict = Dict.t [@@deriving bin_io, compare, sexp] type 'a vector = 'a Vec.t module Tid = struct - exception Overrun - type t = Int63.t [@@deriving bin_io, compare, sexp] + open KB.Syntax + type t = Theory.Label.t [@@deriving bin_io, compare, sexp] + let last = Toplevel.var "last" + let name = Toplevel.var "name" + let repr = Toplevel.var "repr" + let ivec = Toplevel.var "ivec" + let addr = Toplevel.var "addr" - module Tid_generator = Bap_state.Make(struct - type t = Int63.t ref - let create () = ref (Int63.zero) - end) - let create = - fun () -> - let last_tid = !Tid_generator.state in - Int63.incr last_tid; - if last_tid.contents = Int63.zero - then raise Overrun; - last_tid.contents - - let nil = Int63.zero - module Tid = Regular.Make(struct - type nonrec t = Int63.t [@@deriving bin_io, compare, sexp] - let module_name = Some "Bap.Std.Tid" - let version = "1.0.0" + let generate f x = + Toplevel.put last (f x); + Toplevel.get last - let hash = Int63.hash + let for_ivec s = generate Theory.Label.for_ivec s + let for_addr s = generate Theory.Label.for_addr @@ + Bap_bitvector.to_bitvec s - let pp ppf tid = - Format.fprintf ppf "%08Lx" (Int63.to_int64 tid) - let to_string tid = Format.asprintf "%a" pp tid - end) - module Name_resolver = Bap_state.Make(struct - type t = string Tid.Table.t - let create () = Tid.Table.create () - end) + let set slot tid name = Toplevel.exec begin + KB.provide slot tid (Some name) + end 
- let names = Name_resolver.state + let set_addr = set Theory.Label.addr + let set_ivec = set Theory.Label.ivec - let rev_lookup name = - Hashtbl.to_alist !names |> List.find_map ~f:(fun (tid,x) -> - Option.some_if (x = name) tid) |> function - | None -> invalid_argf "unbound name: %s" name () - | Some name -> name - let from_string_exn str = match str.[0] with - | '%' -> Scanf.sscanf str "%%%X" (Int63.of_int) - | '@' -> Scanf.sscanf str "@%s" rev_lookup - | _ -> invalid_arg "label should start from '%' or '@'" + let get slot tid = Toplevel.eval slot (Knowledge.return tid) - let from_string str = Or_error.try_with ~backtrace:true (fun () -> - from_string_exn str) + let get_name = get Theory.Label.name + (* let get_addr = get Theory.Label.addr addr *) + let get_ivec = get Theory.Label.ivec - let set_name tid name = - Hashtbl.set !names ~key:tid ~data:name + let add_name tid name = Toplevel.exec begin + KB.provide Theory.Label.aliases tid @@ + Set.singleton (module String) name + end - let name tid = match Hashtbl.find !names tid with - | None -> Format.asprintf "%%%a" Tid.pp tid + let set_name tid name = + set Theory.Label.name tid name; + add_name tid name + + + let for_name s = + let t = generate Theory.Label.for_name s in + set_name t s; + t + + let intern n = + Toplevel.put name begin + KB.Symbol.intern n Theory.Program.cls >>= fun t -> + KB.provide Theory.Label.name t (Some n) >>| fun () -> + t + end; + Toplevel.get name + + let repr tid = + Toplevel.put repr (KB.Object.repr Theory.Program.cls tid); + Toplevel.get repr + + let parse name = + Toplevel.put last begin + KB.Object.read Theory.Program.cls name + end; + Toplevel.get last + + let create () = + Toplevel.put last begin + KB.Object.create Theory.Program.cls + end; + Toplevel.get last + + let to_string : t -> string = fun tid -> + Format.asprintf "%%%08Lx" (Int63.to_int64 (KB.Object.id tid)) + + let of_string : string -> t = fun str -> + if String.is_empty str + then intern str + else match str.[0] with + 
| '%' -> parse @@ sprintf "#<%s 0x%s>" + (KB.Class.fullname Theory.Program.Semantics.cls) + (String.subo ~pos:1 str) + | '@' -> intern (String.subo ~pos:1 str) + | _ -> intern str + + let nil = create () + + let pp ppf tid = Format.fprintf ppf "%s" (to_string tid) + + let name t = match get_name t with + | None -> to_string t | Some name -> sprintf "@%s" name - module State = struct - let set_name_resolver resolver = names := resolver - end - - let (!!) = from_string_exn - include Tid + let from_string_exn = of_string + let from_string x = Ok (from_string_exn x) + let (!!) = of_string + include Regular.Make(struct + type t = Theory.Label.t [@@deriving bin_io, compare, sexp] + let module_name = Some "Bap.Std.Tid" + let version = "2.0.0" + let hash x = Int63.hash (KB.Object.id x) + let pp = pp + let to_string tid = to_string tid + end) end type tid = Tid.t [@@deriving bin_io, compare, sexp] @@ -107,24 +152,123 @@ type 'a term = { type label = | Direct of tid | Indirect of exp - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] type call = {target : label; return : label option} - [@@deriving bin_io, compare, fields, sexp] +[@@deriving bin_io, compare, fields, sexp] type jmp_kind = | Call of call | Goto of label | Ret of label | Int of int * tid - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] type intent = In | Out | Both [@@deriving bin_io, compare, sexp] -type jmp = (exp * jmp_kind) [@@deriving bin_io, compare, sexp] -type def = (var * exp) [@@deriving bin_io, compare, sexp] -type phi = (var * exp Tid.Map.t) [@@deriving bin_io, compare, sexp] +module Rhs : sig + type top = (Theory.Value.cls, unit) KB.cls + type t = top Knowledge.value [@@deriving bin_io, compare, sexp] + + val empty : t + val of_value : 'a Theory.value -> t + val of_exp : exp -> t + val with_exp : exp -> t -> t + val exp : t -> exp + + include Base.Comparable.S with type t := t +end = struct + type top = (Theory.Value.cls, unit) KB.cls + let cls : top = 
Theory.Value.cls + + let forget v = KB.Value.refine v () + + let empty = KB.Value.empty cls + + let of_value x = forget x + + let of_exp exp = + KB.Value.put Exp.slot empty exp + + let with_exp exp x = + KB.Value.put Exp.slot x exp + + let exp x = KB.Value.get Exp.slot x [@@inline] + include (val KB.Value.derive cls) +end + + +let to_var v = Theory.Var.create (Var.sort v) (Var.ident v) + +module Def : sig + type t = { + var : unit Theory.var; + rhs : Rhs.t; + } [@@deriving bin_io, compare, sexp] + + val reify : 'a Theory.var -> 'a Theory.value -> t + + val of_bil : var -> exp -> t +end = struct + + type t = { + var : Theory.Var.Top.t; + rhs : Rhs.t; + } [@@deriving bin_io, compare, sexp] + + let reify lhs rhs = { + var = Theory.Var.forget lhs; + rhs = Rhs.of_value rhs; + } + + let of_bil var exp = { + var = to_var var; + rhs = Rhs.of_exp exp; + } +end + +module Phi = struct + type t = { + var : Theory.Var.Top.t; + map : Rhs.t Tid.Map.t; + } [@@deriving bin_io, compare, sexp] +end + +module Cnd : sig + type t = Theory.Bool.t Theory.value + [@@deriving bin_io, compare, sexp] + + val of_exp : exp -> t + val exp : t -> exp + + include Base.Comparable.S with type t := t +end = struct + + let empty = Theory.Value.empty Theory.Bool.t + let of_exp = KB.Value.put Exp.slot empty + let exp = KB.Value.get Exp.slot + let cls = KB.Class.refine Theory.Value.cls Theory.Bool.t + include (val KB.Value.derive cls) +end + +module Jmp = struct + type dst = Resolved of tid + | Indirect of { + vec : Rhs.t; + len : int; + } + [@@deriving bin_io, compare, sexp] + type t = { + cnd : Cnd.t option; + dst : dst option; + alt : dst option; + } [@@deriving bin_io, compare, sexp] +end + +type jmp = Jmp.t [@@deriving bin_io, compare, sexp] +type def = Def.t [@@deriving bin_io, compare, sexp] +type phi = Phi.t [@@deriving bin_io, compare, sexp] type blk = { phis : phi term array; @@ -132,7 +276,7 @@ type blk = { jmps : jmp term array; } [@@deriving bin_io, compare, fields, sexp] -type arg = var * 
exp * intent option +type arg = Def.t [@@deriving bin_io, compare, sexp] type sub = { @@ -171,7 +315,8 @@ end = struct let mangle_sub s = let addr = Dict.find s.dict Bap_attributes.address in let name = mangle_name addr s.tid s.self.name in - Tid.set_name s.tid name; + Tid.add_name s.tid s.self.name; + Tid.add_name s.tid name; let self = {s.self with name} in {s with self} @@ -252,23 +397,6 @@ let pp_attr ppf attr = let pp_attrs ppf dict = Dict.data dict |> Seq.iter ~f:(pp_attr ppf) -module Leaf = struct - let create ?(tid=Tid.create ()) lhs rhs = { - tid; - self = (lhs,rhs); - dict = Value.Dict.empty; - } - - let make tid exp dst = { - tid; self = (exp,dst); dict = Value.Dict.empty - } - - let lhs {self=(x,_)} = x - let rhs {self=(_,x)} = x - let with_lhs def lhs = {def with self = (lhs, snd def.self)} - let with_rhs def rhs = {def with self = (fst def.self, rhs)} -end - type nil [@@deriving bin_io, compare, sexp] type _ typ = @@ -311,7 +439,9 @@ let cls typ par nil field = { let hash_of_term t = Tid.hash (tid t) -let make_term tid self : 'a term = {tid; self; dict = Dict.empty} +let make_term tid self : 'a term = { + tid; self; dict = Dict.empty; +} let nil_top = make_term Tid.nil (Program.empty ()) @@ -323,20 +453,48 @@ let program_t = { get = (fun _ -> assert false); } -let nil_def : def term = - Leaf.make Tid.nil undefined_var undefined_exp +module Void : sig + type t + val t : t Theory.Value.sort +end = struct + let unsorted = Theory.Value.Sort.Name.declare ~package "Void" + type unsorted + type t = unsorted Theory.Value.Sort.sym + let t = Theory.Value.Sort.sym unsorted +end + +let undefined_variable = + Theory.Var.(forget @@ define Void.t "undefined") + +let undefined_semantics = + Rhs.of_value @@ Theory.Value.empty Void.t + +let empty self = { + tid = Tid.nil; + dict = Value.Dict.empty; + self +} + +let nil_def : def term = empty Def.{ + var = undefined_variable; + rhs = undefined_semantics; + } -let nil_phi : phi term = - Leaf.make Tid.nil undefined_var 
Tid.Map.empty +let nil_phi : phi term = empty Phi.{ + var = undefined_variable; + map = Map.empty (module Tid); + } -let nil_jmp : jmp term = - Leaf.make Tid.nil undefined_exp (Goto (Direct Tid.nil)) +let nil_jmp : jmp term = empty Jmp.{ + cnd = None; + dst = None; + alt = None; + } let nil_blk : blk term = - make_term Tid.nil {phis=[| |] ; defs = [| |] ; jmps = [| |] } + make_term Tid.nil {phis= [| |] ; defs = [| |] ; jmps = [| |] } -let nil_arg : arg term = - make_term Tid.nil (undefined_var,undefined_exp,None) +let nil_arg : arg term = nil_def let nil_sub : sub term = make_term Tid.nil { name = "undefined"; blks = [| |] ; args = [| |]} @@ -360,11 +518,14 @@ let term_pp pp_self ppf t = let attrs = Dict.data t.dict in Seq.iter attrs ~f:(fun attr -> pp_open_tag ppf (asprintf "%a" pp_attr attr)); - fprintf ppf "@[%a: %a@]@." Tid.pp t.tid pp_self t.self; + fprintf ppf "@[%08Lx: %a@]@." (Int63.to_int64 (KB.Object.id t.tid)) + pp_self t.self; Seq.iter attrs ~f:(fun _ -> pp_close_tag ppf ()) - - +let pp_value slots ppf x = + match slots with + | [] -> KB.Value.pp ppf x + | slots -> KB.Value.pp_slots slots ppf x module Label = struct type t = label @@ -384,7 +545,8 @@ module Label = struct let hash = Hashtbl.hash let pp ppf = function | Indirect exp -> Bap_exp.pp ppf exp - | Direct tid -> Format.fprintf ppf "%s" @@ Tid.name tid + | Direct tid -> Format.fprintf ppf "%s" @@ + Tid.name tid end) end @@ -416,23 +578,64 @@ module Call = struct end) end + + module Ir_arg = struct type t = arg term - let create ?(tid=Tid.create()) ?intent var exp : t = - make_term tid (var,exp,intent) - - let lhs {self=(r,_,_)} = r - let rhs {self=(_,r,_)} = r - let intent {self=(_,_,r)} = r - let map3 (x,y,z) ~f = (x,y, f z) - let with_intent (t : t) intent : t = { - t with self = map3 t.self ~f:(fun _ -> Some intent) - } - let with_unknown_intent t : t = { - t with self = map3 t.self ~f:(fun _ -> None) + + module Intent = struct + module T = struct + type t = intent option [@@deriving bin_io] 
+ end + + let equal_intent x y = compare_intent x y = 0 + + let domain = + KB.Domain.optional ~equal:equal_intent ~inspect:sexp_of_intent "intent" + let persistent = KB.Persistent.of_binable (module T) + let slot = KB.Class.property ~package ~persistent + Theory.Value.cls "arg-intent" domain + let set intent x = match intent with + | None -> x + | Some intent -> KB.Value.put slot x intent + let get x = KB.Value.get slot x + end + + let set_intent ({self={Def.rhs} as self} as t) intent : t = { + t with self = { + self with rhs = Intent.set (Some intent) rhs + } } + + let var {self={Def.var}} = var + let value {self={Def.var; rhs}} = + let sort = Theory.Var.sort var in + KB.Value.refine rhs sort + + let with_intent arg intent = set_intent arg (Some intent) + + let reify ?(tid=Tid.create()) ?intent lhs rhs = + set_intent (make_term tid @@ Def.reify lhs rhs) intent + + let create ?(tid=Tid.create()) ?intent var exp : t = + set_intent (make_term tid @@ Def.of_bil var exp) intent + let lhs {self={Def.var}} = Var.reify var + let rhs {self={Def.rhs}} = Rhs.exp rhs + let intent {self={Def.rhs}} = KB.Value.get Intent.slot rhs + + let with_unknown_intent t : t = set_intent t None + let name arg = Var.name (lhs arg) + let with_rhs ({self={Def.var}} as t) rhs = { + t with self = Def.{var; rhs} + } + + let string_of_intent = function + | Some In -> "in " + | Some Out -> "out " + | Some Both -> "in out " + | None -> "" let warn_unused = Bap_value.Tag.register (module Unit) ~name:"warn-unused" @@ -454,35 +657,65 @@ module Ir_arg = struct ~name:"nonnull" ~uuid:"3c0a6181-9a9c-4cf4-aa37-8ceebd773952" + let pp_sort ppf var = match Var.typ (Var.reify var) with + | Unk -> Theory.Value.Sort.pp ppf (Theory.Var.sort var) + | typ -> Bap_type.pp ppf typ - include Regular.Make(struct - type t = arg term [@@deriving bin_io, compare, sexp] - let module_name = Some "Bap.Std.Arg" - let version = "1.0.0" + let pp_self pp_rhs ppf {Def.var; rhs} = + Format.fprintf ppf "%s :: %s%a = %a" + 
(Theory.Var.name var) + (string_of_intent @@ Intent.get rhs) + pp_sort var + pp_rhs rhs - let hash = hash_of_term + let pp ppf arg = term_pp (pp_self (fun ppf rhs -> + let exp = Rhs.exp rhs in + Bap_exp.pp ppf exp)) ppf arg - let string_of_intent = function - | Some In -> "in " - | Some Out -> "out " - | Some Both -> "in out " - | None -> "" + let pp_slots slots = term_pp (pp_self (pp_value slots)) - let pp_self ppf (var,exp,intent) = - Format.fprintf ppf "%s :: %s%a = %a" - (Var.name var) - (string_of_intent intent) - Bap_type.pp (Var.typ var) - Bap_exp.pp exp - let pp = term_pp pp_self - end) -end + module V2 = struct + type t = arg term [@@deriving bin_io, compare, sexp] + let module_name = Some "Bap.Std.Arg" + let version = "2.0.0" + let hash = hash_of_term + let pp = pp + end + include Regular.Make(V2) +end module Ir_def = struct type t = def term - include Leaf + + let reify ?(tid=Tid.create()) lhs rhs = + make_term tid @@ Def.reify lhs rhs + + let create ?(tid=Tid.create ()) var exp = + make_term tid @@ Def.of_bil var exp + + let var {self={Def.var}} = var + let value {self={Def.var; rhs}} = + let sort = Theory.Var.sort var in + KB.Value.refine rhs sort + + let lhs {self={Def.var}} = Var.reify var + let rhs {self={Def.rhs}} = Rhs.exp rhs + + let with_lhs ({self={Def.rhs}} as t ) v = { + t with self = Def.{ + var = to_var v; + rhs; + } + } + + let with_rhs ({self={Def.var; rhs}} as t) exp = { + t with self = Def.{ + var; + rhs = Rhs.with_exp exp rhs + } + } let map_exp def ~f : def term = with_rhs def (f (rhs def)) @@ -491,50 +724,98 @@ module Ir_def = struct let free_vars def = Exp.free_vars (rhs def) + let pp_self ppf {Def.var; rhs} = + Format.fprintf ppf + "%s := %a" (Theory.Var.name var) Bap_exp.pp (Rhs.exp rhs) - include Regular.Make(struct - type t = def term [@@deriving bin_io, compare, sexp] - let module_name = Some "Bap.Std.Def" - let version = "1.0.0" - let hash = hash_of_term + let pp_self_slots slots ppf {Def.var; rhs} = + Format.fprintf ppf + "%s 
:= %a" (Theory.Var.name var) (pp_value slots) rhs - let pp_self ppf (lhs,rhs) = - Format.fprintf ppf "%a := %a" Var.pp lhs Bap_exp.pp rhs + let pp = term_pp pp_self + let pp_slots ds = term_pp (pp_self_slots ds) - let pp = term_pp pp_self - end) + module V2 = struct + type t = def term [@@deriving bin_io, compare, sexp] + let module_name = Some "Bap.Std.Def" + let version = "2.0.0" + let hash = hash_of_term + let pp = pp + end + include Regular.Make(V2) end module Ir_phi = struct type t = phi term - include Leaf - let of_list ?tid var bs : phi term = - create ?tid var (Tid.Map.of_alist_reduce bs ~f:(fun _ x -> x)) + let var {self={Phi.var}} = var + let lhs phi = Var.reify (var phi) + + let with_lhs ({self} as t) lhs = { + t with self = Phi.{ + self with var = to_var lhs; + } + } + + let reify ?(tid=Tid.create ()) var bs = + let bs = List.map bs ~f:(fun (t,x) -> t, Rhs.of_value x) in + make_term tid Phi.{ + var = Theory.Var.forget var; + map = Map.of_alist_exn (module Tid) bs + } + + let of_list ?(tid=Tid.create()) var bs : phi term = + let bs = List.map bs ~f:(fun (t,x) -> t, Rhs.of_exp x) in + make_term tid Phi.{ + var = to_var var; + map = Map.of_alist_exn (module Tid) bs + } + + let create ?tid var src exp : phi term = + of_list ?tid var [src,exp] - let create ?tid:_ var src exp : phi term = of_list var [src,exp] + let values {self={Phi.map}} : (tid * exp) Seq.t = + Map.to_sequence map |> + Seq.map ~f:(fun (t,x) -> t, Rhs.exp x) - let values (phi : phi term) : (tid * exp) Seq.t = - Map.to_sequence (rhs phi) + let options {self={Phi.map; var}} : (tid * _) Seq.t = + let sort = Theory.Var.sort var in + Map.to_sequence map |> + Seq.map ~f:(fun (t,x) -> t, KB.Value.refine x sort) - let update (phi : phi term) tid exp : phi term = - with_rhs phi (Map.set (rhs phi) ~key:tid ~data:exp) - let remove phi tid : phi term = - with_rhs phi (Map.remove (rhs phi) tid) + let update ({self={Phi.map; var}} as t) tid exp : phi term = { + t with self = Phi.{ + var; + map = Map.set 
map ~key:tid ~data:(Rhs.of_exp exp) + } + } - let select phi tid : exp option = - Map.find (rhs phi) tid + let remove ({self={Phi.map; var}} as t) tid : phi term = { + t with self = Phi.{ + var; + map = Map.remove map tid + } + } + + let select {self={Phi.map}} tid : exp option = + Option.map (Map.find map tid) ~f:Rhs.exp let select_or_unknown phi tid = match select phi tid with | Some thing -> thing | None -> - let name = Format.asprintf "no path from %a" Tid.pp tid in - Bap_exp.Exp.unknown name (Var.typ (lhs phi)) + let name = Format.asprintf "unresolved-tid %a" Tid.pp tid in + let typ = Var.typ (lhs phi) in + Exp.unknown name typ - let map_exp phi ~f : phi term = - with_rhs phi (Map.map (rhs phi) ~f) + let map_exp ({self={Phi.var; map}} as t) ~f : phi term = { + t with + self = { + var; + map = Map.map map ~f:(fun rhs -> Rhs.with_exp (f (Rhs.exp rhs)) rhs) + } + } let substitute phi x y = map_exp phi ~f:(Exp.substitute x y) @@ -542,47 +823,144 @@ module Ir_phi = struct values phi |> Seq.fold ~init:Bap_var.Set.empty ~f:(fun vars (_,e) -> Set.union vars (Exp.free_vars e)) - include Regular.Make(struct - type t = phi term [@@deriving bin_io, compare, sexp] - let module_name = Some "Bap.Std.Phi" - let version = "1.0.0" - - let hash = hash_of_term - - let pp_self ppf (lhs,rhs) = - Format.fprintf ppf "%a := phi(%s)" - Var.pp lhs - (String.concat ~sep:", " @@ - List.map ~f:(fun (id,exp) -> - Format.asprintf "[%a, %%%a]" Bap_exp.pp exp Tid.pp id) - (Map.to_alist rhs)) - let pp = term_pp pp_self - end) + let pp_self ppf {Phi.var; map} = + Format.fprintf ppf "%s := phi(%s)" + (Theory.Var.name var) + (String.concat ~sep:", " @@ + List.map ~f:(fun (id,exp) -> + let exp = Rhs.exp exp in + Format.asprintf "[%a, %%%a]" Bap_exp.pp exp Tid.pp id) + (Map.to_alist map)) + + let pp_self_slots ds ppf {Phi.var; map} = + Format.fprintf ppf "%s := phi(%s)" + (Theory.Var.name var) + (String.concat ~sep:", " @@ + List.map ~f:(fun (id,exp) -> + Format.asprintf "[%a, %%%a]" (pp_value ds) 
exp Tid.pp id) + (Map.to_alist map)) + + let pp = term_pp pp_self + let pp_slots ds = term_pp (pp_self_slots ds) + + module V2 = struct + type t = phi term [@@deriving bin_io, compare, sexp] + let module_name = Some "Bap.Std.Phi" + let version = "2.0.0" + let pp = pp + let hash = hash_of_term + end + include Regular.Make(V2) end module Ir_jmp = struct type t = jmp term - include Leaf + type dst = Jmp.dst - let create_call ?tid ?(cond=always) call = - create ?tid cond (Call call) + let resolved tid = Jmp.Resolved tid + let indirect dst = Jmp.Indirect { + vec = Rhs.of_value dst; + len = Theory.Bitv.size (KB.Class.sort (KB.Value.cls dst)); + } + + let reify ?(tid=Tid.create ()) ?cnd ?alt ?dst () = + make_term tid Jmp.{cnd; dst; alt} - let create_goto ?tid ?(cond=always) dest = - create ?tid cond (Goto dest) + let dst_of_lbl : label -> Jmp.dst option = function + | Direct tid -> Some (Resolved tid) + | Indirect exp -> match Bap_helpers.Type.infer_exn exp with + | Imm len -> Some (Indirect {vec = Rhs.of_exp exp; len}) + | _ -> None - let create_ret ?tid ?(cond=always) dest = - create ?tid cond (Ret dest) + let lbl_of_dst : Jmp.dst -> label = function + | Resolved tid -> Direct tid + | Indirect {vec} -> Indirect (Rhs.exp vec) + + let create ?(tid=Tid.create()) ?(cond=always) kind = + let cnd = if cond = always then None else Some (Cnd.of_exp cond) in + make_term tid @@ match kind with + | Goto lbl -> Jmp.{ + cnd; + dst = dst_of_lbl lbl; alt = None; + } + | Ret lbl -> Jmp.{ + cnd; + dst = None; alt = dst_of_lbl lbl + } + | Int (int,ret) -> + let alt = Tid.create () in + Tid.set_ivec alt int; + Jmp.{ + cnd; + dst = Some (Resolved ret); + alt = Some (Resolved alt); + } + | Call t -> { + cnd; + dst = Option.bind ~f:dst_of_lbl (Call.return t); + alt = dst_of_lbl (Call.target t); + } - let create_int ?tid ?(cond=always) n t = - create ?tid cond (Int (n,t)) + let ivec_of_dst : Jmp.dst -> int option = function + | Indirect _ -> None + | Resolved t -> Tid.get_ivec t + + let 
kind_of_jmp {Jmp.dst; alt} = + match dst, alt with + | None, None -> Goto (Indirect (Exp.unknown "unknown" Unk)) + | Some dst, None -> Goto (lbl_of_dst dst) + | None, Some alt -> Call (Call.create ~target:(lbl_of_dst alt) ()) + | Some dst, Some alt -> match dst, ivec_of_dst alt with + | Resolved dst, Some vec -> Int (vec,dst) + | _ -> Call (Call.create () + ~return:(lbl_of_dst dst) + ~target:(lbl_of_dst alt)) + + let create_call ?tid ?cond call = create ?tid ?cond (Call call) + let create_goto ?tid ?cond dest = create ?tid ?cond (Goto dest) + let create_ret ?tid ?cond dest = create ?tid ?cond (Ret dest) + let create_int ?tid ?cond n t = create ?tid ?cond (Int (n,t)) + + let guard {self={Jmp.cnd}} = cnd + let with_guard jmp cnd = {jmp with self = Jmp.{ + jmp.self with cnd + }} + + let dst {self={Jmp.dst}} = dst + let alt {self={Jmp.alt}} = alt + + let with_dst jmp dst = {jmp with self = Jmp.{ + jmp.self with dst + }} + + let with_alt jmp alt = {jmp with self = Jmp.{ + jmp.self with alt + }} + + let resolve = function + | Jmp.Resolved t -> Either.first t + | Jmp.Indirect {vec; len} -> + let s = Theory.Bitv.define len in + Either.second (KB.Value.refine vec s) + + let kind : jmp term -> jmp_kind = fun t -> + kind_of_jmp t.self + + let cond_of_jmp {Jmp.cnd} = match cnd with + | None -> always + | Some cnd -> KB.Value.get Exp.slot cnd + + + let cond : jmp term -> exp = fun t -> cond_of_jmp t.self + + let with_cond t cnd = { + t with self = Jmp.{ + t.self with cnd = Some (Cnd.of_exp cnd) + } + } - let create ?tid ?(cond=always) kind = - create ?tid cond kind - let kind = rhs - let cond = lhs - let with_cond = with_lhs - let with_kind = with_rhs + let with_kind t kind = create ~tid:t.tid ~cond:(cond t) kind let exps (jmp : jmp term) : exp Sequence.t = let open Sequence.Generator in @@ -607,7 +985,7 @@ module Ir_jmp = struct let return = Option.map (Call.return call) ~f:map_label in let target = map_label (Call.target call) in Call.create ?return ~target () in - let jmp = 
with_cond jmp (f (cond jmp)) in + let jmp : jmp term = with_cond jmp (f (cond jmp)) in let kind = match kind jmp with | Call t -> Call (map_call t) | Goto t -> Goto (map_label t) @@ -629,30 +1007,35 @@ module Ir_jmp = struct | Goto t -> eval_label t | _ -> assert false + let pp_dst ppf = function + | Goto dst -> Format.fprintf ppf "goto %a" Label.pp dst + | Call sub -> Call.pp ppf sub + | Ret dst -> Format.fprintf ppf "return %a" Label.pp dst + | Int (n,t) -> + Format.fprintf ppf "interrupt 0x%X return %%%a" n Tid.pp t - include Regular.Make(struct - type t = jmp term [@@deriving bin_io, compare, sexp] - let module_name = Some "Bap.Std.Jmp" - let version = "1.0.0" + let pp_cond ppf cond = + if Exp.(cond <> always) then + Format.fprintf ppf "when %a " Bap_exp.pp cond - let hash = hash_of_term + let pp_self ppf jmp = + Format.fprintf ppf "%a%a" + pp_cond (cond_of_jmp jmp) + pp_dst (kind_of_jmp jmp) - let pp_dst ppf = function - | Goto dst -> Format.fprintf ppf "goto %a" Label.pp dst - | Call sub -> Call.pp ppf sub - | Ret dst -> Format.fprintf ppf "return %a" Label.pp dst - | Int (n,t) -> - Format.fprintf ppf "interrupt 0x%X return %%%a" n Tid.pp t + let pp = term_pp pp_self + let pp_slots _ = pp - let pp_cond ppf cond = - if Exp.(cond <> always) then - Format.fprintf ppf "when %a " Bap_exp.pp cond - let pp_self ppf (lhs,rhs) = - Format.fprintf ppf "%a%a" pp_cond lhs pp_dst rhs + module V2 = struct + type t = jmp term [@@deriving bin_io, compare, sexp] + let module_name = Some "Bap.Std.Jmp" + let version = "2.0.0" + let hash = hash_of_term + let pp = pp + end - let pp = term_pp pp_self - end) + include Regular.Make(V2) end @@ -808,6 +1191,24 @@ module Term = struct ~name:"postcondition" ~uuid:"f248e4c1-9efc-4c70-a864-e34706e2082b" + let equal x y = + compare_term compare_blk x y = 0 + + let equal_tids x y = Tid.equal (tid x) (tid y) + + let domain = Knowledge.Domain.flat ~empty:[] "bir" + ~equal:(List.equal ~equal:equal_tids) + ~inspect:(fun blks -> Sexp.List 
(List.map blks ~f:(fun b -> + Sexp.Atom (name b)))) + + let persistent = Knowledge.Persistent.of_binable (module struct + type t = blk term list [@@deriving bin_io] + end) + + let slot = Knowledge.Class.property ~package ~persistent + Theory.Program.Semantics.cls "bir" domain + + let change t p tid f = Array.findi (t.get p.self) ~f:(fun _ x -> x.tid = tid) |> function | None -> Option.value_map (f None) ~f:(append t p) ~default:p @@ -917,21 +1318,21 @@ module Term = struct map jmp_t ~f:(self#map_term jmp_t) - - method map_arg arg = { - arg with - self = map1 ~f:self#map_sym arg.self |> - map2 ~f:self#map_exp + method private map_assn ({self=Def.{var;rhs}} as t) = { + t with + self = Def.{ + var = to_var (self#map_sym @@ Var.reify var); + rhs = Rhs.with_exp (self#map_exp (Rhs.exp rhs)) rhs; + } } + method map_def = self#map_assn + method map_arg = self#map_assn + method map_phi phi = - let phi = Ir_phi.(with_lhs phi (self#map_sym (lhs phi))) in + let phi = Ir_phi.with_lhs phi @@ self#map_sym (Ir_phi.lhs phi) in Ir_phi.map_exp phi ~f:self#map_exp - method map_def def = - let def = Ir_def.(with_lhs def (self#map_sym (lhs def))) in - Ir_def.map_exp def ~f:self#map_exp - method map_jmp jmp = Ir_jmp.map_exp jmp ~f:self#map_exp end @@ -947,16 +1348,16 @@ module Term = struct method leave_term : 't 'p. ('p,'t) cls -> 't term -> 'a -> 'a = fun _cls _t x -> x method visit_term : 't 'p. 
('p,'t) cls -> 't term -> 'a -> 'a = fun cls t x -> - let x = self#enter_term cls t x in - switch cls t - ~program:(fun t -> self#run t x) - ~sub:(fun t -> self#visit_sub t x) - ~arg:(fun t -> self#visit_arg t x) - ~blk:(fun t -> self#visit_blk t x) - ~phi:(fun t -> self#visit_phi t x) - ~def:(fun t -> self#visit_def t x) - ~jmp:(fun t -> self#visit_jmp t x) |> - self#leave_term cls t + let x = self#enter_term cls t x in + switch cls t + ~program:(fun t -> self#run t x) + ~sub:(fun t -> self#visit_sub t x) + ~arg:(fun t -> self#visit_arg t x) + ~blk:(fun t -> self#visit_blk t x) + ~phi:(fun t -> self#visit_phi t x) + ~def:(fun t -> self#visit_def t x) + ~jmp:(fun t -> self#visit_jmp t x) |> + self#leave_term cls t method enter_program _p x = x method leave_program _p x = x @@ -998,21 +1399,21 @@ module Term = struct method visit_arg arg x = self#enter_arg arg x |> - self#visit_var (fst3 arg.self) |> - self#visit_exp (snd3 arg.self) |> + self#visit_var (Ir_arg.lhs arg) |> + self#visit_exp (Ir_arg.rhs arg) |> self#leave_arg arg method visit_phi phi x = self#enter_phi phi x |> - self#visit_var (fst phi.self) |> fun x -> - Map.fold (snd phi.self) ~init:x ~f:(fun ~key:_ ~data x -> - self#visit_exp data x) |> + self#visit_var (Ir_phi.lhs phi) |> fun x -> + Seq.fold (Ir_phi.values phi) ~init:x ~f:(fun data (_,x) -> + self#visit_exp x data) |> self#leave_phi phi method visit_def def x = self#enter_def def x |> - self#visit_var (fst def.self) |> - self#visit_exp (snd def.self) |> + self#visit_var (Ir_def.lhs def) |> + self#visit_exp (Ir_def.rhs def) |> self#leave_def def method visit_jmp jmp x = @@ -1179,8 +1580,8 @@ module Ir_blk = struct let substitute ?skip blk x y = map_exp ?skip blk ~f:(Exp.substitute x y) - let map_phi_lhs p ~f = Ir_phi.(with_lhs p (f (lhs p))) - let map_def_lhs d ~f = Ir_def.(with_lhs d (f (lhs d))) + let map_phi_lhs p ~f = Ir_phi.with_lhs p (f (Ir_phi.lhs p)) + let map_def_lhs d ~f = Ir_def.with_lhs d (f (Ir_def.lhs d)) let map_lhs ?(skip=[]) blk ~f 
= { blk with self = { @@ -1190,11 +1591,13 @@ module Ir_blk = struct } } + let has_lhs cls lhs blk x = + Term.to_sequence cls blk |> + Seq.exists ~f:(fun t -> Var.(lhs t = x)) + let defines_var blk x = - let exists t = - Term.to_sequence t blk |> - Seq.exists ~f:(fun y -> Var.(Leaf.lhs y = x)) in - exists phi_t || exists def_t + has_lhs phi_t Ir_phi.lhs blk x || + has_lhs def_t Ir_def.lhs blk x let free_vars blk = let (++) = Set.union and (--) = Set.diff in @@ -1223,20 +1626,27 @@ module Ir_blk = struct dominator = id || Term.(after def_t b dominator |> Seq.exists ~f:(fun x -> x.tid = id)) + let pp_self ppf self = + Format.fprintf ppf "@[@.%a%a%a@]" + (Array.pp Ir_phi.pp) self.phis + (Array.pp Ir_def.pp) self.defs + (Array.pp Ir_jmp.pp) self.jmps + + let pp_self_slots ds ppf self = + Format.fprintf ppf "@[@.%a%a%a@]" + (Array.pp (Ir_phi.pp_slots ds)) self.phis + (Array.pp (Ir_def.pp_slots ds)) self.defs + (Array.pp (Ir_jmp.pp_slots ds)) self.jmps + + let pp_slots ds = term_pp (pp_self_slots ds) + let pp = term_pp pp_self + include Regular.Make(struct type t = blk term [@@deriving bin_io, compare, sexp] let module_name = Some "Bap.Std.Blk" let version = "1.0.0" - let hash = hash_of_term - - let pp_self ppf self = - Format.fprintf ppf "@[@.%a%a%a@]" - (Array.pp Ir_phi.pp) self.phis - (Array.pp Ir_def.pp) self.defs - (Array.pp Ir_jmp.pp) self.jmps - - let pp = term_pp pp_self + let pp = pp end) end @@ -1247,7 +1657,9 @@ module Ir_sub = struct let create ?(tid=Tid.create ()) ?name () : t = let name = match name with | Some name -> name - | None -> Tid.to_string tid in + | None -> match Tid.get_name tid with + | None -> Tid.to_string tid + | Some name -> name in make_term tid { name; args = [| |] ; @@ -1257,7 +1669,7 @@ module Ir_sub = struct let name sub = sub.self.name let with_name sub name = - Tid.set_name (Term.tid sub) name; + Tid.add_name (Term.tid sub) name; {sub with self = {sub.self with name}} module Enum(T : Bap_value.S) = struct @@ -1272,7 +1684,7 @@ module 
Ir_sub = struct module Args = struct type t = (arg term, arg term * arg term) Either.t - [@@deriving bin_io, compare, sexp] + [@@deriving bin_io, compare, sexp] let pp ppf = function | First x -> Format.fprintf ppf "(%s)" (Ir_arg.name x) | Second (x,y) -> @@ -1348,22 +1760,42 @@ module Ir_sub = struct | None -> Format.asprintf "sub_%a" Tid.pp tid in make_term tid {name; args; blks} end + let pp_self ppf self = + Format.fprintf ppf "@[sub %s(%s)@.%a%a@]" + self.name + (String.concat ~sep:", " @@ + Array.to_list @@ + Array.map self.args ~f:Ir_arg.name) + (Array.pp Ir_arg.pp) self.args + (Array.pp Ir_blk.pp) self.blks + + let pp_self ppf self = + Format.fprintf ppf "@[sub %s(%s)@.%a%a@]" + self.name + (String.concat ~sep:", " @@ + Array.to_list @@ + Array.map self.args ~f:Ir_arg.name) + (Array.pp Ir_arg.pp) self.args + (Array.pp Ir_blk.pp) self.blks + + let pp_self_slots ds ppf self = + Format.fprintf ppf "@[sub %s(%s)@.%a%a@]" + self.name + (String.concat ~sep:", " @@ + Array.to_list @@ + Array.map self.args ~f:Ir_arg.name) + (Array.pp (Ir_arg.pp_slots ds)) self.args + (Array.pp (Ir_blk.pp_slots ds)) self.blks + + let pp = term_pp pp_self + let pp_slots ds = term_pp (pp_self_slots ds) + include Regular.Make(struct type t = sub term [@@deriving bin_io, compare, sexp] let module_name = Some "Bap.Std.Sub" let version = "1.0.0" - + let pp = pp let hash = hash_of_term - let pp_self ppf self = - Format.fprintf ppf "@[sub %s(%s)@.%a%a@]" - self.name - (String.concat ~sep:", " @@ - Array.to_list @@ - Array.map self.args ~f:Ir_arg.name) - (Array.pp Ir_arg.pp) self.args - (Array.pp Ir_blk.pp) self.blks - - let pp = term_pp pp_self end) end @@ -1474,15 +1906,22 @@ module Ir_program = struct end + let pp_self ppf self = + Format.fprintf ppf "@[program@.%a@]" + (Array.pp Ir_sub.pp) self.subs + + let pp_self_slots ds ppf self = + Format.fprintf ppf "@[program@.%a@]" + (Array.pp (Ir_sub.pp_slots ds)) self.subs + + let pp_slots ds = term_pp (pp_self_slots ds) + let pp = term_pp 
pp_self + include Regular.Make(struct type t = program term [@@deriving bin_io, compare, sexp] let module_name = Some "Bap.Std.Program" let version = "1.0.0" - + let pp = pp let hash = hash_of_term - let pp_self ppf self = - Format.fprintf ppf "@[program@.%a@]" - (Array.pp Ir_sub.pp) self.subs - let pp = term_pp pp_self end) end diff --git a/lib/bap_types/bap_ir.mli b/lib/bap_types/bap_ir.mli index b2d6b5ace..a0401a0e9 100644 --- a/lib/bap_types/bap_ir.mli +++ b/lib/bap_types/bap_ir.mli @@ -1,9 +1,14 @@ open Core_kernel open Regular.Std +open Bap_core_theory open Bap_common open Bap_bil open Bap_value open Bap_visitor +open Bap_core_theory + +type tid = Theory.Label.t +[@@deriving bin_io, compare, sexp] type 'a term [@@deriving bin_io, compare, sexp] type program [@@deriving bin_io, compare, sexp] @@ -14,27 +19,25 @@ type blk [@@deriving bin_io, compare, sexp] type phi [@@deriving bin_io, compare, sexp] type def [@@deriving bin_io, compare, sexp] type jmp [@@deriving bin_io, compare, sexp] - -type tid [@@deriving bin_io, compare, sexp] type call [@@deriving bin_io, compare, sexp] type label = | Direct of tid | Indirect of exp - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] type jmp_kind = | Call of call | Goto of label | Ret of label | Int of int * tid - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] type intent = | In | Out | Both - [@@deriving bin_io, compare, sexp] +[@@deriving bin_io, compare, sexp] type ('a,'b) cls @@ -48,19 +51,25 @@ val jmp_t : (blk, jmp) cls module Tid : sig type t = tid - val create : unit -> t - val set_name : t -> string -> unit - val name : t -> string + + val for_name : string -> t Bap_toplevel.t + val for_addr : addr -> t Bap_toplevel.t + val for_ivec : int -> t Bap_toplevel.t + + val create : unit -> t Bap_toplevel.t + val set_name : t -> string -> unit Bap_toplevel.t + val name : t -> string Bap_toplevel.t val from_string : string -> tid Or_error.t val from_string_exn : string -> tid 
val (!!) : string -> tid include Regular.S with type t := t - module Tid_generator : Bap_state.S - module Name_resolver : Bap_state.S end module Term : sig type 'a t = 'a term + + val slot : (Theory.Program.Semantics.cls, blk term list) KB.slot + val clone : 'a t -> 'a t val same : 'a t -> 'a t -> bool val name : 'a t -> string @@ -103,7 +112,6 @@ module Term : sig val invariant : exp tag val postcondition : exp tag - class mapper : object inherit exp_mapper method run : program term -> program term @@ -193,7 +201,7 @@ module Ir_program : sig val add_sub : t -> sub term -> unit val result : t -> program term end - + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -221,6 +229,7 @@ module Ir_sub : sig val returns_twice : unit tag val nothrow : unit tag val entry_point : unit tag + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end @@ -276,11 +285,19 @@ module Ir_blk : sig val add_elt : t -> elt -> unit val result : t -> blk term end + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end module Ir_def : sig type t = def term + + val reify : ?tid:tid -> 'a Theory.var -> 'a Theory.value -> t + + val var : t -> unit Theory.var + val value : t -> unit Theory.value + + val create : ?tid:tid -> var -> exp -> t val lhs : t -> var val rhs : t -> exp @@ -289,11 +306,28 @@ module Ir_def : sig val map_exp : t -> f:(exp -> exp) -> t val substitute : t -> exp -> exp -> t val free_vars : t -> Bap_var.Set.t + val pp_slots : string list -> Format.formatter -> t -> unit + include Regular.S with type t := t end module Ir_jmp : sig type t = jmp term + type dst + + val reify : ?tid:tid -> + ?cnd:Theory.Bool.t Theory.value -> + ?alt:dst -> ?dst:dst -> unit -> t + + val guard : t -> Theory.Bool.t Theory.value option + val with_guard : t -> Theory.Bool.t Theory.value option -> t + val dst : t -> dst option + val alt : t -> dst option + + val 
resolved : tid -> dst + val indirect : 'a Theory.Bitv.t Theory.value -> dst + val resolve : dst -> (tid,'a Theory.Bitv.t Theory.value) Either.t + val create : ?tid:tid -> ?cond:exp -> jmp_kind -> t val create_call : ?tid:tid -> ?cond:exp -> call -> t val create_goto : ?tid:tid -> ?cond:exp -> label -> t @@ -307,11 +341,22 @@ module Ir_jmp : sig val map_exp : t -> f:(exp -> exp) -> t val substitute : t -> exp -> exp -> t val free_vars : t -> Bap_var.Set.t + val pp_slots : string list -> Format.formatter -> t -> unit + include Regular.S with type t := t end module Ir_phi : sig type t = phi term + + val reify : ?tid:tid -> + 'a Theory.var -> + (tid * 'a Theory.value) list -> + t + + val var : t -> unit Theory.var + val options : t -> (tid * unit Theory.value) seq + val create : ?tid:tid -> var -> tid -> exp -> t val of_list : ?tid:tid -> var -> (tid * exp) list -> t val lhs : t -> var @@ -324,12 +369,20 @@ module Ir_phi : sig val map_exp : t -> f:(exp -> exp) -> t val substitute : t -> exp -> exp -> t val free_vars : t -> Bap_var.Set.t + val pp_slots : string list -> Format.formatter -> t -> unit include Regular.S with type t := t end module Ir_arg : sig type t = arg term + val reify : ?tid:tid -> ?intent:intent -> + 'a Theory.var -> + 'a Theory.value -> t + + val var : t -> unit Theory.var + val value : t -> unit Theory.value + val create : ?tid:tid -> ?intent:intent -> var -> exp -> t val lhs : t -> var val rhs : t -> exp @@ -342,6 +395,11 @@ module Ir_arg : sig val warn_unused : unit tag val restricted : unit tag val nonnull : unit tag + val pp_slots : string list -> Format.formatter -> t -> unit + + module Intent : sig + val slot : (Theory.Value.cls, intent option) KB.slot + end include Regular.S with type t := t end diff --git a/lib/bap_types/bap_state.ml b/lib/bap_types/bap_state.ml deleted file mode 100644 index 855113f7a..000000000 --- a/lib/bap_types/bap_state.ml +++ /dev/null @@ -1,19 +0,0 @@ -open Core_kernel - -module type S = sig - type t - val fresh : 
unit -> t - val store : t -> unit -end - -module Make(T : sig - type t - val create : unit -> t - end) = struct - type t = T.t - let state = ref @@ T.create () - let fresh () = - state := T.create (); - !state - let store t = state := t -end diff --git a/lib/bap_types/bap_state.mli b/lib/bap_types/bap_state.mli deleted file mode 100644 index 29154cfb8..000000000 --- a/lib/bap_types/bap_state.mli +++ /dev/null @@ -1,15 +0,0 @@ -open Bap_value - -module type S = sig - type t - val fresh : unit -> t - val store : t -> unit -end - -module Make(T : sig - type t - val create : unit -> t - end) : sig - include S with type t = T.t - val state : t ref -end diff --git a/lib/bap_types/bap_stmt.ml b/lib/bap_types/bap_stmt.ml index d09a44a45..8f80afeed 100644 --- a/lib/bap_types/bap_stmt.ml +++ b/lib/bap_types/bap_stmt.ml @@ -1,4 +1,6 @@ open Core_kernel +open Bap_core_theory +open Bap_knowledge open Regular.Std open Bap_common open Format @@ -90,3 +92,22 @@ module Stmts_data = struct set_default_writer "bin"; set_default_reader "bin" end + + +let domain = Knowledge.Domain.flat "bil" + ~empty:[] + ~inspect:(function + | [] -> Sexp.List [] + | bil -> Sexp.Atom (Stmts_pp.to_string bil)) + ~equal:(fun x y -> + phys_equal x y || + Int.(compare_bil x y = 0)) + + +let persistent = Knowledge.Persistent.of_binable (module struct + type t = stmt list [@@deriving bin_io] + end) + +let slot = Knowledge.Class.property ~package:"bap.std" + ~persistent Theory.Program.Semantics.cls "bil" domain + ~desc:"semantics of statements in BIL" diff --git a/lib/bap_types/bap_stmt.mli b/lib/bap_types/bap_stmt.mli index 556a6b0f4..d5048c4c7 100644 --- a/lib/bap_types/bap_stmt.mli +++ b/lib/bap_types/bap_stmt.mli @@ -1,8 +1,11 @@ open Core_kernel +open Bap_core_theory +open Bap_knowledge open Regular.Std open Bap_common open Bap_bil + include Regular.S with type t := stmt val pp_stmts : Format.formatter -> stmt list -> unit @@ -22,3 +25,7 @@ end module Stmts_pp : Printable.S with type t = stmt list module 
Stmts_data : Data.S with type t = stmt list + +val slot : (Theory.Program.Semantics.cls, stmt list) Knowledge.slot +val domain : stmt list Knowledge.domain +val persistent : stmt list Knowledge.persistent diff --git a/lib/bap_types/bap_toplevel.ml b/lib/bap_types/bap_toplevel.ml new file mode 100644 index 000000000..fc7361845 --- /dev/null +++ b/lib/bap_types/bap_toplevel.ml @@ -0,0 +1,70 @@ +open Core_kernel +open Bap_knowledge +open Knowledge.Syntax + +let package = "bap.std-internal" +type 'a t = 'a +type env = Knowledge.state ref +type main = Main +type 'p var = (main,'p option) Knowledge.slot + +let state = ref Knowledge.empty +let slots = ref 0 +let env = state + +let set s = state := s +let reset () = state := Knowledge.empty +let current () = !state + +exception Internal_runtime_error of Knowledge.conflict [@@deriving sexp] +exception Not_found [@@deriving sexp] + +let try_eval slot exp = + let cls = Knowledge.Slot.cls slot in + match Knowledge.run cls exp !state with + | Ok (v,s) -> + state := s; + Ok (Knowledge.Value.get slot v) + | Error conflict -> Error conflict + +let eval slot exp = + try_eval slot exp |> function + | Ok v -> v + | Error conflict -> + raise (Internal_runtime_error conflict) + +let main = Knowledge.Class.declare ~package "main" Main + + +let var name = + incr slots; + let name = sprintf "%s%d" name !slots in + let order x y : Knowledge.Order.partial = match x, y with + | Some _, Some _ | None, None -> EQ + | None,Some _ -> LT + | Some _,None -> GT in + let dom = Knowledge.Domain.define ~empty:None ~order "any" in + Knowledge.Class.property ~package main name dom + +let this = + Knowledge.Symbol.intern ~package "main" main + +let try_exec stmt = + let stmt = stmt >>= fun () -> this in + match Knowledge.run main stmt !state with + | Ok (_,s) -> Ok (state := s) + | Error conflict -> Error conflict + +let exec stmt = + try_exec stmt |> function + | Ok () -> () + | Error err -> raise (Internal_runtime_error err) + +let put slot exp = exec 
@@begin + exp >>= fun v -> this >>= fun x -> + Knowledge.provide slot x (Some v) + end + +let get slot = eval slot this |> function + | None -> raise Not_found + | Some x -> x diff --git a/lib/bap_types/bap_toplevel.mli b/lib/bap_types/bap_toplevel.mli new file mode 100644 index 000000000..5f7c65720 --- /dev/null +++ b/lib/bap_types/bap_toplevel.mli @@ -0,0 +1,21 @@ +open Bap_knowledge + +type 'a t = 'a +type 'p var + +val eval : ('a,'p) Knowledge.slot -> 'a Knowledge.obj knowledge -> 'p t +val exec : unit knowledge -> unit t + +val try_eval : ('a,'p) Knowledge.slot -> 'a Knowledge.obj knowledge -> + ('p,Knowledge.conflict) result t + +val try_exec : unit knowledge -> (unit,Knowledge.conflict) result t + +val get : 'p var -> 'p t +val put : 'p var -> 'p knowledge -> unit t +val var : string -> 'p var + + +val reset : unit -> unit +val set : Knowledge.state -> unit +val current : unit -> Knowledge.state diff --git a/lib/bap_types/bap_type.ml b/lib/bap_types/bap_type.ml index 39204bd33..eaed4e631 100644 --- a/lib/bap_types/bap_type.ml +++ b/lib/bap_types/bap_type.ml @@ -12,6 +12,7 @@ module T = struct let pp fmt = function + | Unk -> fprintf fmt "unk" | Imm n -> fprintf fmt "u%u" n | Mem (idx, elm) -> fprintf fmt "%a?%a" Bap_size.pp (idx :> size) Bap_size.pp elm diff --git a/lib/bap_types/bap_type_error.ml b/lib/bap_types/bap_type_error.ml index 730e1013e..5c50caafc 100644 --- a/lib/bap_types/bap_type_error.ml +++ b/lib/bap_types/bap_type_error.ml @@ -6,6 +6,7 @@ type t = [ | `bad_kind of [`mem | `imm] | `bad_type of typ * typ | `bad_cast + | `unknown ] [@@deriving bin_io, compare, sexp] type type_error = t [@@deriving bin_io, compare, sexp] @@ -16,6 +17,7 @@ let bad_mem = `bad_kind `mem let bad_imm = `bad_kind `imm let bad_cast = `bad_cast let bad_type ~exp ~got = `bad_type (exp,got) +let unknown = `unknown let expect_mem () = raise (T (`bad_kind `mem)) let expect_imm () = raise (T (`bad_kind `imm)) @@ -23,6 +25,7 @@ let wrong_cast () = raise (T (`bad_cast)) let 
expect e ~got = raise (T (`bad_type (e,got))) let to_string : type_error -> string = function + | `unknown -> "a non-representable in BIL type" | `bad_kind `mem -> "expected storage, got immediate value" | `bad_kind `imm -> "expected immediate value, got storage" | `bad_cast -> "malformed cast arguments" diff --git a/lib/bap_types/bap_type_error.mli b/lib/bap_types/bap_type_error.mli index 916a2022f..80e9bbd22 100644 --- a/lib/bap_types/bap_type_error.mli +++ b/lib/bap_types/bap_type_error.mli @@ -7,6 +7,7 @@ type t = [ | `bad_kind of [`mem | `imm] | `bad_type of typ * typ | `bad_cast + | `unknown ] [@@deriving bin_io, compare, sexp] exception T of t [@@deriving sexp_of] @@ -16,6 +17,7 @@ val bad_mem : t val bad_imm : t val bad_cast : t val bad_type : exp:typ -> got:typ -> t +val unknown : t val expect_mem : unit -> 'a val expect_imm : unit -> 'a diff --git a/lib/bap_types/bap_types.ml b/lib/bap_types/bap_types.ml index 9db4d0032..681334154 100644 --- a/lib/bap_types/bap_types.ml +++ b/lib/bap_types/bap_types.ml @@ -8,6 +8,7 @@ open Core_kernel open Regular.Std open Bap_common +open Bap_knowledge (** This module is included into [Bap.Std], you need to open it specifically if you're developing inside BAP *) @@ -15,7 +16,7 @@ module Std = struct (** A definition for a regular type, and a handy module, that can create regular types out of thin air. 
*) module Integer = Integer - module State = Bap_state + module Toplevel = Bap_toplevel module Trie = struct include Bap_trie_intf include Bap_trie @@ -65,6 +66,7 @@ module Std = struct type typ = Type.t = | Imm of int | Mem of addr_size * size + | Unk [@@deriving bin_io, compare, sexp] include Bap_bil.Cast include Bap_bil.Binop @@ -98,6 +100,10 @@ module Std = struct include Bap_bil_pass module Pass = Bap_bil_pass.Pass_pp include Bap_bil_optimizations + + let slot = Bap_stmt.slot + let domain = Bap_stmt.domain + let persistent = Bap_stmt.persistent end (** Types of BIL expressions *) @@ -140,6 +146,7 @@ module Std = struct let eval = Bap_expi.eval let simpl = Bap_helpers.Simpl.exp let pp_adt = Bap_bil_adt.pp_exp + let slot = Bap_exp.slot end (** [Regular] interface for BIL statements *) @@ -240,5 +247,4 @@ module Std = struct module Callgraph = Bap_ir_callgraph - end diff --git a/lib/bap_types/bap_value.ml b/lib/bap_types/bap_value.ml index 396490396..1877412ea 100644 --- a/lib/bap_types/bap_value.ml +++ b/lib/bap_types/bap_value.ml @@ -1,127 +1,193 @@ +open Bap_core_theory open Core_kernel open Regular.Std open Format -module Typeid = String - module type S = sig type t [@@deriving bin_io, compare, sexp] val pp : Format.formatter -> t -> unit end +module Uid = Type_equal.Id.Uid +module Typeid = String + type void -type univ = Univ.t type literal = (void,void,void) format -type typeid = String.t [@@deriving bin_io, compare, sexp] +type uid = Uid.t +type typeid = Typeid.t [@@deriving bin_io, compare, sexp] type 'a tag = { - tid : typeid; key : 'a Type_equal.Id.t; + slot : (Theory.program,'a option) KB.slot; } -type t = { - uuid : typeid; - data : string -} [@@deriving bin_io, sexp] +module Value = struct + type t = Univ_map.Packed.t = T : 'a Type_equal.Id.t * 'a -> t +end -type value = t [@@deriving bin_io, sexp] +module Equal = struct + type ('a,'b) t = ('a,'b) Type_equal.t = T : ('a,'a) t + let proof = Type_equal.Id.same_witness_exn + let try_prove = 
Type_equal.Id.same_witness +end type type_info = { - pp : Format.formatter -> univ -> unit; - of_string : string -> univ; - to_string : univ -> string; - compare : univ -> univ -> int; + pp : Format.formatter -> Value.t -> unit; + of_string : string -> Value.t; + to_string : Value.t -> string; + of_sexp : Sexp.t -> Value.t; + to_sexp : Value.t -> Sexp.t; + collect : Theory.Label.t -> Univ_map.t -> Univ_map.t KB.t; + compare : Value.t -> Value.t -> int; } +let names : string Hash_set.t = Hash_set.create (module Typeid) () let types : (typeid, type_info) Hashtbl.t = - Typeid.Table.create ~size:128 () - -let register (type a) ~name ~uuid - (typ : (module S with type t = a)) : a tag = - let module S = (val typ) in - match Hashtbl.find types uuid with - | None -> - let uuid = match Uuidm.of_string uuid with - | None -> invalid_arg "Invalid UUID format" - | Some uuid -> Uuidm.to_bytes uuid in - let key = Type_equal.Id.create name S.sexp_of_t in - let pp ppf univ = S.pp ppf (Univ.match_exn univ key) in - let of_string str = - Univ.create key (Binable.of_string (module S) str) in - let to_string x = - Binable.to_string (module S) (Univ.match_exn x key) in - let compare x y = match Univ.match_ x key, Univ.match_ y key with - | Some x, Some y -> S.compare x y - | _,_ -> Type_equal.Id.Uid.compare - (Univ.type_id_uid x) (Univ.type_id_uid y) in - let info = { - pp; - of_string; - to_string; - compare; - } in - Hashtbl.add_exn types ~key:uuid ~data:info; - {key; tid = uuid} - | Some _ -> - invalid_argf "UUID %s is already in use" (Typeid.to_string uuid) () - -let nil = - let key = Type_equal.Id.create "nil" sexp_of_string in - let pp ppf _ = fprintf ppf "" in - let of_string s = Univ.create key s in - (* to_string is it is called only in create, where it is guaranteed - that the tag exists, in other words, to_string will be never - called *) - let to_string x = assert false in - let compare x y = Pervasives.compare x y in + Hashtbl.create ~size:128 (module Typeid) + +let uid = 
Type_equal.Id.uid + +type ('a,'b) eq = ('a,'b) Type_equal.t = T : ('a,'a) eq + +let register_slot (type a) slot + (module S : S with type t = a) : a tag = + let name = KB.Slot.name slot in + let key = Type_equal.Id.create name S.sexp_of_t in + let pp ppf (Value.T (k,x)) = + let T = Equal.proof k key in + S.pp ppf x in + let of_string str = + Value.T (key, Binable.of_string (module S) str) in + let to_string (Value.T (k,x)) = + let T = Equal.proof k key in + Binable.to_string (module S) x in + let of_sexp str = + Value.T (key, S.t_of_sexp str) in + let to_sexp (Value.T (k,x)) = + let T = Equal.proof k key in + S.sexp_of_t x in + let compare (Value.T (kx,x)) (Value.T (ky,y)) = + match Equal.try_prove kx ky with + | None -> Uid.compare (uid kx) (uid ky) + | Some T -> + let T = Equal.proof kx key in + S.compare x y in + let collect obj dict = + let open KB.Syntax in + KB.collect slot obj >>| function + | None -> dict + | Some x -> Univ_map.set dict key x in let info = { + pp; + of_sexp; + to_sexp; of_string; to_string; - pp; compare + collect; + compare; } in - Hashtbl.add_exn types ~key:"" ~data:info; - info - -let typeof v = match Hashtbl.find types v.uuid with - | Some info -> info - | None -> nil - -let compare_value x y = - match compare x.uuid y.uuid with - | 1 | -1 as r -> r - | _ -> match typeof x, typeof y with - | t, s -> t.compare (t.of_string x.data) (s.of_string y.data) + Hashtbl.add_exn types ~key:name ~data:info; + {key; slot} + +let register (type a) ~name ~uuid (module S : S with type t = a) = + let persistent = KB.Persistent.of_binable (module struct + type t = S.t option [@@deriving bin_io] + end) in + let equal x y = S.compare x y = 0 in + let domain = KB.Domain.optional ~equal name in + let slot = KB.Class.property ~persistent ~package:uuid + Theory.Program.cls name domain in + register_slot slot (module S) + +let find_separator s = + if String.is_empty s then None + else String.Escaping.index s ~escape_char:'\\' ':' + +let tagname (Value.T (k,_)) 
= + let fullname = Type_equal.Id.name k in + match find_separator fullname with + | None -> fullname + | Some len -> String.subo fullname ~pos:(len+1) + +let typeid (Value.T (k,_)) = Type_equal.Id.name k + +let info typeid = + Hashtbl.find_and_call types typeid + ~if_found:ident + ~if_not_found:(fun typeid -> + invalid_argf "Can't deserialize type %s, \ + as it is no longer known to the system" + typeid ()) + +let ops x = info (typeid x) +let compare_value x y = (ops x).compare x y let compare = compare_value +let sexp_of_value x = Sexp.List [ + Sexp.Atom (typeid x); + (ops x).to_sexp x; + ] + +let value_of_sexp = function + | Sexp.List [Atom typeid; repr] -> + (info typeid).of_sexp repr + | _ -> invalid_arg "Value.t_of_sexp: broken representation" + + +module Univ = struct + type t = Value.t + let sexp_of_t = sexp_of_value + let t_of_sexp = value_of_sexp + let compare = compare_value + module Repr = struct + type t = { + typeid : string; + data : string; + } [@@deriving bin_io] + end + + include Binable.Of_binable(Repr)(struct + type t = Value.t + let to_binable x = Repr.{ + typeid = typeid x; + data = (ops x).to_string x; + } + let of_binable {Repr.typeid; data} = + (info typeid).of_string data + end) +end -let create tag x : value = - let info = Hashtbl.find_exn types tag.tid in - { - uuid = tag.tid; - data = info.to_string (Univ.create tag.key x); - } +let create {key} x = Value.T (key,x) +let is {key} (Value.T (k,_)) = Type_equal.Id.same key k -let univ x = (typeof x).of_string x.data -let get t x = - if Typeid.(x.uuid = t.tid) - then Univ.match_ (univ x) t.key +let get + : type a. 
a tag -> Value.t -> a option = + fun {key} (Value.T (k,x)) -> + if Type_equal.Id.same key k + then + let T = Equal.proof key k in + Some x else None -let get_exn t x = match get t x with - | Some x -> x - | None -> invalid_arg "Value.get_exn: wrong tag" -let is t x = Typeid.equal x.uuid t.tid -let tagname x = Univ.type_id_name (univ x) -let typeid x = x.uuid + +let get_exn + : type a. a tag -> Value.t -> a = + fun {key} (Value.T (k,x)) -> + let T = Equal.proof key k in + x + module Tag = struct type 'a t = 'a tag let name tag = Type_equal.Id.name tag.key - let typeid tag = tag.tid + let typeid tag = name tag let key tag = tag.key + let uid tag = uid tag.key let register (type a) ~name ~uuid (typ : (module S with type t = a)) : a tag = - let uuid = Typeid.of_string (string_of_format uuid) in - register ~name:(string_of_format name) ~uuid typ + register ~name ~uuid typ + + let register_slot slot ops = register_slot slot ops + let slot tag = tag.slot let same_witness t1 t2 = Option.try_with (fun () -> @@ -137,10 +203,10 @@ module Match = struct type 's t = { default : (unit -> 's); - handlers : (value -> 's) Typeid.Map.t; + handlers : (Value.t -> 's) Map.M(Uid).t; } - let empty = Typeid.Map.empty + let empty = Map.empty (module Uid) let default default = { handlers = empty; @@ -148,66 +214,71 @@ module Match = struct } let case t f (tab : 's t) = - let h = Map.set tab.handlers t.tid (fun v -> f (get_exn t v)) in + let h = Map.set tab.handlers (Tag.uid t) (fun v -> f (get_exn t v)) in {tab with handlers = h} - let run v tab = - match Map.find tab.handlers v.uuid with + let run (Value.T (k,_) as v) tab = + match Map.find tab.handlers (uid k) with | Some f -> f v | None -> tab.default () - let switch = run - let select x y = switch y x end module Dict = struct - - type t = value Typeid.Map.t [@@deriving bin_io, compare, sexp] - - let uuid tag = tag.tid - - let empty = Typeid.Map.empty - - let set t tag data = - Map.set t ~key:(uuid tag) ~data:(create tag data) - - let 
mem t key = Map.mem t (uuid key) - let remove t key = Map.remove t (uuid key) - let is_empty = Map.is_empty - let find t tag = match Map.find t (uuid tag) with - | None -> None - | Some x -> get tag x - let add t key data = - if mem t key then `Duplicate else `Ok (set t key data) - let change t key update = - let orig = find t key in - let next = update orig in - match next with - | Some data -> set t key data - | None -> if Option.is_none orig then t else remove t key - - let to_sequence t : (typeid * value) Sequence.t = - Map.to_sequence t - let data t : value Sequence.t = - Map.to_sequence t |> Sequence.map ~f:snd - - let filter t ~f = Map.filter t ~f - + type t = Univ_map.t + let empty = Univ_map.empty + let is_empty = Univ_map.is_empty + let set dict {key} x = Univ_map.set dict key x + let remove dict {key} = Univ_map.remove dict key + let mem dict {key} = Univ_map.mem dict key + let find dict {key} = Univ_map.find dict key + let add dict {key} x = Univ_map.add dict key x + let change dict {key} f = Univ_map.change dict key ~f + let data dict = + Univ_map.to_alist dict |> + Seq.of_list + let to_sequence dict = + Seq.map (data dict) ~f:(fun v -> typeid v,v) + let filter t ~f = + data t |> + Seq.fold ~init:empty ~f:(fun dict (Value.T (k,x) as v) -> + if f v then Univ_map.set dict k x else dict) + + let compare x y = + compare_list + compare_value + (Univ_map.to_alist x) + (Univ_map.to_alist y) + + module Data = struct + type t = Univ.t list [@@deriving bin_io, sexp] + let of_dict = Univ_map.to_alist + let to_dict = + List.fold ~init:empty ~f:(fun dict (Value.T (k,x)) -> + Univ_map.set dict k x) + end + include Binable.Of_binable(Data)(struct + type t = Univ_map.t + let to_binable = Data.of_dict + let of_binable = Data.to_dict + end) + include Sexpable.Of_sexpable(Data)(struct + type t = Univ_map.t + let to_sexpable = Data.of_dict + let of_sexpable = Data.to_dict + end) end type dict = Dict.t [@@deriving bin_io, compare, sexp] - - +type t = Univ.t [@@deriving 
bin_io, compare, sexp] +include struct type value = Univ.t [@@deriving bin_io] end include Regular.Make(struct - type nonrec t = value [@@deriving bin_io, compare, sexp] - let compare = compare + type t = Univ.t [@@deriving bin_io, compare, sexp] + let compare = Univ.compare let hash = Hashtbl.hash - - let pp ppf v = - let t = typeof v in - t.pp ppf (t.of_string v.data) + let pp ppf v = (ops v).pp ppf v let module_name = Some "Bap.Std.Value" - let version = "1.0.0" + let version = "2.0.0" end) diff --git a/lib/bap_types/bap_value.mli b/lib/bap_types/bap_value.mli index 327a9b0d7..5e5353fd6 100644 --- a/lib/bap_types/bap_value.mli +++ b/lib/bap_types/bap_value.mli @@ -1,3 +1,5 @@ +open Bap_core_theory + open Core_kernel open Regular.Std @@ -38,9 +40,14 @@ end module Tag : sig type 'a t = 'a tag - val register : name:literal -> uuid:literal -> + val register : name:string -> uuid:string -> (module S with type t = 'a) -> 'a tag + val register_slot : (Theory.program,'a option) KB.slot -> (module S with type t = 'a) -> 'a tag + + val slot : 'a tag -> (Theory.program, 'a option) KB.slot + + val name : 'a tag -> string val typeid : 'a tag -> typeid val key : 'a tag -> 'a Type_equal.Id.t diff --git a/lib/bap_types/bap_var.ml b/lib/bap_types/bap_var.ml index db59f46a7..c414f5098 100644 --- a/lib/bap_types/bap_var.ml +++ b/lib/bap_types/bap_var.ml @@ -1,51 +1,133 @@ open Core_kernel +open Bap_core_theory open Regular.Std open Bap_common -module Id = struct - include Bap_state.Make(struct - type t = Int63.t ref - let create () = ref Int63.zero - end) - let create () = - let id = !state in - Int63.incr id; - !id -end -module T = struct - type t = { - var : string; - ind : int; - typ : typ; - vir : bool; - } [@@deriving sexp, bin_io, compare] - - let hash v = String.hash v.var - let module_name = Some "Bap.Std.Var" - let version = "1.0.0" - let pp fmt v = - Format.fprintf fmt "%s%s" v.var - (if v.ind <> 0 then sprintf ".%d" v.ind else "" ) -end +type var = Var : 'a Theory.Var.t 
-> var +type t = var + +let reify v = Var v +let sort (Var v) = + Theory.Value.Sort.forget (Theory.Var.sort v) + + +let ident (Var v) = Theory.Var.ident v +let name (Var v) = Theory.Var.name v +let with_index (Var v) ver = + Var (Theory.Var.versioned v ver) +let index (Var v) = Theory.Var.version v +let base v = with_index v 0 + -include T +let typ v = + let s = sort v in + Theory.Bool.refine s |> function + | Some _ -> Type.Imm 1 + | None -> Theory.Bitv.refine s |> function + | Some bits -> Type.Imm (Theory.Bitv.size bits) + | None -> Theory.Mem.refine s |> function + | None -> Type.Unk + | Some mems -> + let ks, vs = Theory.Mem.(keys mems, vals mems) in + let ks, vs = Theory.Bitv.(size ks, size vs) in + match Bap_size.addr_of_int_opt ks, Bap_size.of_int_opt vs with + | Some ks, Some vs -> Type.Mem (ks,vs) + | _ -> Type.Unk -let name v = v.var -let with_index v ind = {v with ind} -let index v = v.ind -let base v = {v with ind = 0} -let typ v = v.typ -let is_physical v = not v.vir -let is_virtual v = v.vir +let is_virtual (Var v) = Theory.Var.is_virtual v +let is_physical v = not (is_virtual v) + +let unknown = + let unknown = + Theory.Value.Sort.Name.declare ~package:"bap-std" "Unknown" in + Theory.Value.Sort.sym unknown + +let sort_of_typ t = + let ret = Theory.Value.Sort.forget in + match t with + | Type.Imm 1 -> ret Theory.Bool.t + | Type.Imm m -> ret @@ Theory.Bitv.define m + | Type.Mem (ks,vs) -> + let ks,vs = Bap_size.(in_bits ks, in_bits vs) in + let ks,vs = Theory.Bitv.(define ks, define vs) in + ret @@ Theory.Mem.define ks vs + | Type.Unk -> ret @@ unknown + +module Generator = struct + module Toplevel = Bap_toplevel + open KB.Syntax + + let empty = Theory.Var.Ident.of_string "nil" + let ident_t = KB.Domain.flat ~empty + ~inspect:Theory.Var.Ident.sexp_of_t + ~equal:Theory.Var.Ident.equal + "ident" + type generator = Gen + let generator = + KB.Class.declare ~package:"bap.std.internal" "var-generator" Gen + + let ident = Toplevel.var "ident" + + let fresh s 
= + Toplevel.put ident begin + Theory.Var.fresh s >>| Theory.Var.ident + end; + Toplevel.get ident +end + let create ?(is_virtual=false) ?(fresh=false) name typ = - let var = - if fresh then name ^ Int63.to_string (Id.create ()) - else name in - {ind = 0; var; typ; vir = is_virtual} + let sort = sort_of_typ typ in + if is_virtual || fresh + then + let iden = Generator.fresh sort in + Var (Theory.Var.create sort iden) + else + Var (Theory.Var.define sort name) let same x y = base x = base y + +module T = struct + type t = var + + module Repr = struct + type t = {name : Theory.Var.ident; sort : Theory.Value.Sort.Top.t} + [@@deriving bin_io, compare, sexp] + + let of_var v = { + name = ident v; + sort = sort v + } + let to_var {name;sort} = + Var (Theory.Var.create sort name) + end + + include Sexpable.Of_sexpable(Repr)(struct + type t = var + let to_sexpable = Repr.of_var + let of_sexpable = Repr.to_var + end) + + include Binable.Of_binable(Repr)(struct + type t = var + let to_binable = Repr.of_var + let of_binable = Repr.to_var + end) + + let compare x y = + Theory.Var.Ident.compare (ident x) (ident y) + + let hash x = Hashtbl.hash (ident x) + + let version = "2.0.0" + + let module_name = Some "Bap.Std.Var" + + let pp ppf x = + Format.fprintf ppf "%s" (Theory.Var.Ident.to_string (ident x)) +end + include Regular.Make(T) diff --git a/lib/bap_types/bap_var.mli b/lib/bap_types/bap_var.mli index 80f3455af..2a387ee3c 100644 --- a/lib/bap_types/bap_var.mli +++ b/lib/bap_types/bap_var.mli @@ -1,8 +1,16 @@ open Core_kernel +open Bap_core_theory + open Regular.Std open Bap_common type t include Regular.S with type t := t + +val reify : 'a Theory.var -> t +val ident : t -> Theory.Var.ident +val sort : t -> Theory.Value.Sort.Top.t + +(* Old interface *) val create : ?is_virtual:bool -> ?fresh:bool -> string -> typ -> t val with_index : t -> int -> t val index : t -> int @@ -12,4 +20,3 @@ val name : t -> string val typ : t -> typ val is_virtual : t -> bool val is_physical : t 
-> bool -module Id : Bap_state.S diff --git a/lib/bitvec/.merlin b/lib/bitvec/.merlin new file mode 100644 index 000000000..2575fd4b8 --- /dev/null +++ b/lib/bitvec/.merlin @@ -0,0 +1,2 @@ +REC +PKG zarith diff --git a/lib/bitvec/bitvec.ml b/lib/bitvec/bitvec.ml new file mode 100644 index 000000000..ad4c7772d --- /dev/null +++ b/lib/bitvec/bitvec.ml @@ -0,0 +1,504 @@ +type t = Z.t + +(* Invariant: a bitvector is always normed. + + A normed bitvector is always treated by Z as non-negative: + - Z.sign (norm w x) >= 0 + + We have to enforce this invariant since in Z negative numbers are + not equal to their two's complement forms, in other words, + Z.(-1 != 0xFFFFF..FF). + + This makes sense, since in Z a number has an arbitrary length and + thus the number of ones in the two's complement form is essentially + unlimited, thus any fixed length bitvector of all ones is not equal + to any negative number. Thus in order to implement canonical + comparison function we need to parametrize it with the length of + the bitvector, which in fact prevents us from having + one single type for bitvector sets, maps, and any regular or + comparable interface in general. Another problem with + non-normalized bitvectors is serialization. The [to_bits] function + doesn't preserve the sign, the default marshaling function is not + portable between OCaml runtimes with different word sizes, and last + but not least, the same problem with non-canonical representation + persists: 0xFFFF_FFFF is not equal to -1, or even more, 0xDEADBEEF + is not equal to -559038737. + + There are two ramifications of enforcing the normalized canonical + representation: + + - some extra performance cost due to occasional calls to normalize + (it is usually just one instruction - logand) + + - values with more than 62 significant bits will be stored in a + boxed notation (we lose 1 bit wrt the non-normalized + representation).
+*) + +type modulus = {m : Z.t} [@@unboxed] + + +type 'a m = modulus -> 'a + + +(* we predicate the normalization, to prevent extra allocations + in cases when the bitvector has the boxed representation, so + that in cases when `x = norm x` we do not create a fresh new + value with the same contents. + + The normalization is required for negative numbers, since Z + represents them non-canonically. + + Specialized versions may use a different normalization procedure, + this is the general case that shall work with the whole range of widths. +*) +let norm {m} x = + if Z.sign x < 0 || Z.geq x m + then Z.(x land m) + else x +[@@inline] + + +(* defining external as the apply primitive will enable + inlining (with any other operator definition the vanilla + version of OCaml will create a closure) +*) +external (mod) : 'a m -> modulus -> 'a = "%apply" + +let modulus w = { + m = Z.(one lsl w - one); +} + +let m1 = modulus 1 +let m8 = modulus 8 +let m32 = modulus 32 +let m64 = modulus 64 + + +let compare x y = + if x == y then 0 else Z.compare x y +[@@inline] + +let hash x = Z.hash x [@@inline] + +let one = Z.one +let zero = Z.zero +let ones m = norm m @@ Z.minus_one +let bool x = if x then one else zero +let int x m = norm m @@ Z.of_int x [@@inline] +let int32 x m = norm m @@ Z.of_int32 x [@@inline] +let int64 x m = norm m @@ Z.of_int64 x [@@inline] +let bigint x m = norm m @@ x [@@inline] + +let append w1 w2 x y = + let w = w1 + w2 in + let ymask = Z.(one lsl w2 - one) in + let zmask = Z.(one lsl w - one) in + let ypart = Z.(y land ymask) in + let xpart = Z.(x lsl w2) in + Z.(xpart lor ypart land zmask) +[@@inline] + +let extract ~hi ~lo x = + Z.extract x lo (hi - lo + 1) +[@@inline] + +let setbit x n = Z.(x lor (one lsl n)) [@@inline] + +let select bits x = + let rec loop n y = function + | [] -> y + | v :: vs -> + let y = if Z.testbit x v + then setbit y n + else y in + loop (n+1) y vs in + loop 0 zero bits + +let repeat m ~times:n x = + let mask = Z.(one lsl m - one) in + let
x = Z.(x land mask) in + let rec loop i y = + if i < n + then + let off = i * m in + let stamp = Z.(x lsl off) in + loop (i+1) Z.(y lor stamp) + else y in + loop 0 zero + +let concat m xs = + let mask = Z.(one lsl m - one) in + List.fold_left (fun y x -> + let x = Z.(x land mask) in + Z.(y lsl m lor x)) zero xs + +(* we can't use higher functions such as op1 and op2, + to lift Z operations to bitv since this will prevent + inlining and will introduce extra indirect calls, so + we have to make it in a verbose method. Flambda, where + are you? +*) + +let succ x m = norm m @@ Z.succ x [@@inline] +let pred x m = norm m @@ Z.pred x [@@inline] +let nsucc x n m = norm m @@ Z.(x + of_int n) [@@inline] +let npred x n m = norm m @@ Z.(x - of_int n) [@@inline] +let lnot x m = norm m @@ Z.lognot x [@@inline] +let neg x m = norm m @@ Z.neg x [@@inline] +let nth x n _ = Z.testbit x n [@@inline] +let msb x m = Z.(equal m.m (norm m x lor (m.m asr 1))) [@@inline] +let lsb x _ = Z.is_odd x [@@inline] +let abs x m = if msb x m then neg x m else x [@@inline] +let add x y m = norm m @@ Z.add x y [@@inline] +let sub x y m = norm m @@ Z.sub x y [@@inline] +let mul x y m = norm m @@ Z.mul x y [@@inline] +let div x y m = if Z.(y = zero) + then ones m + else norm m @@ Z.div x y +[@@inline] + +let sdiv x y m = match msb x m, msb y m with + | false, false -> div x y m + | true, false -> neg (div (neg x m) y m) m + | false, true -> neg (div x (neg y m) m) m + | true, true -> div (neg x m) (neg y m) m +[@@inline] + +let rem x y m = + if Z.(y = zero) + then x + else norm m @@ Z.rem x y +[@@inline] + +(* 2's complement signed remainder (sign follows dividend) *) +let srem x y m = match msb x m, msb y m with + | false,false -> rem x y m + | true,false -> neg (rem (neg x m) y m) m + | false,true -> neg (rem x (neg y m) m) m + | true,true -> neg (rem (neg x m) (neg y m) m) m +[@@inline] + +(* 2's complement signed remained (sign follows the divisor) *) +let smod s t m = + let u = rem s t m in + if 
Z.(u = zero) then u + else match msb s m, msb t m with + | false,false -> u + | true,false -> add (neg u m) t m + | false,true -> add u t m + | true,true -> neg u m +[@@inline] + +let logand x y m = norm m @@ Z.logand x y [@@inline] +let logor x y m = norm m @@ Z.logor x y [@@inline] +let logxor x y m = norm m @@ Z.logxor x y [@@inline] + + +(* extracts no more than [m] bits from [x]. + The [Z.to_int] function may allocate since it throws + an exception in case of the overflow. So we have to + pay with the GC check after each call to [Z.to_int], + so instead of no-op we have a call, and a couple of + blocks with a dozen of instructions. + + The implementation below is safe and doens't rely on + whether the FAST_PATH or NATINT options are enabled + in zarith. It first looks whether [x] is actually an + [int] and then casts it to the [int] type, otherwise + it extracts the low (min m Sys.int_size) bits and + cast them to_int using the slow Z.to_int, so that in + case if zarith is not using NATINT representation, + we will still be on the safe side. + + Note: this function is pure optimization, so in case + if you're not sure, it could be replaced with its + else branch. 
+*) +let to_int_fast x m : int = + if Obj.is_int (Obj.repr x) then Obj.magic x + else Z.to_int @@ Z.signed_extract x 0 (min m Sys.int_size) +[@@inline] + +let shift n {m} ~overshift ~in_bounds = + let w = Z.numbits m in + if Z.(lt n (of_int w)) + then in_bounds w (to_int_fast n w) + else overshift +[@@inline] + +let lshift x n m = shift n m + ~overshift:zero + ~in_bounds:(fun _ n -> norm m @@ Z.shift_left x n) +[@@inline] + +let rshift x n m = shift n m + ~overshift:zero + ~in_bounds:(fun _ n -> norm m @@ Z.shift_right x n) +[@@inline] + +let arshift x n m = + let msb = msb x m in + shift n m + ~in_bounds:(fun w n -> + let x = Z.(x asr n) in + norm m @@ if msb + then + let n = w - n in + let y = ones mod m in + Z.(y lsl n lor x) + else x) + ~overshift:(if msb then ones m else zero) +[@@inline] + +let gcd x y m = + if Z.(equal x zero) then y else + if Z.(equal y zero) then x else + norm m @@ Z.gcd x y +[@@inline] + +let lcm x y m = + if Z.(equal x zero) || Z.(equal y zero) then zero + else norm m @@ Z.lcm x y +[@@inline] + +let gcdext x y m = + if Z.(equal x zero) then (y,zero,one) else + if Z.(equal y zero) then (x,one,zero) else + let (g,a,b) = Z.gcdext x y in + (norm m g, norm m a, norm m b) +[@@inline] +(* [signed_compare x y m] orders [x] and [y] as 2's complement signed integers modulo [m] *) +let signed_compare x y m = match msb x m, msb y m with + | true, true -> compare x y (* equal sign bits: the signed order is the unsigned order *) + | false,false -> compare x y + | true,false -> -1 (* a value with the sign bit set is less than one without it *) + | false,true -> 1 +[@@inline] + +module Syntax = struct + let (!!)
x m = int x m [@@inline] +let (~-) x m = neg x m [@@inline] +let (~~) x m = lnot x m [@@inline] +let (+) x y m = add x y m [@@inline] +let (-) x y m = sub x y m [@@inline] +let ( * ) x y m = mul x y m [@@inline] +let (/) x y m = div x y m [@@inline] +let (/$) x y m = sdiv x y m [@@inline] +let (%) x y m = rem x y m [@@inline] +let (%$) x y m = smod x y m [@@inline] +let (%^) x y m = srem x y m [@@inline] +let (land) x y m = logand x y m [@@inline] +let (lor) x y m = logor x y m [@@inline] +let (lxor) x y m = logxor x y m [@@inline] +let (lsl) x y m = lshift x y m [@@inline] +let (lsr) x y m = rshift x y m [@@inline] +let (asr) x y m = arshift x y m [@@inline] +let (++) x n m = nsucc x n m [@@inline] +let (--) x n m = npred x n m [@@inline] +end + +module type S = sig + type 'a m + val bool : bool -> t + val int : int -> t m + val int32 : int32 -> t m + val int64 : int64 -> t m + val bigint : Z.t -> t m + val zero : t + val one : t + val ones : t m + val succ : t -> t m + val nsucc : t -> int -> t m + val pred : t -> t m + val npred : t -> int -> t m + val neg : t -> t m + val lnot : t -> t m + val abs : t -> t m + val add : t -> t -> t m + val sub : t -> t -> t m + val mul : t -> t -> t m + val div : t -> t -> t m + val sdiv : t -> t -> t m + val rem : t -> t -> t m + val srem : t -> t -> t m + val smod : t -> t -> t m + val nth : t -> int -> bool m + val msb : t -> bool m + val lsb : t -> bool m + val logand : t -> t -> t m + val logor : t -> t -> t m + val logxor : t -> t -> t m + val lshift : t -> t -> t m + val rshift : t -> t -> t m + val arshift : t -> t -> t m + val gcd : t -> t -> t m + val lcm : t -> t -> t m + val gcdext : t -> t -> (t * t * t) m + + val (!$) : string -> t + val (!!) 
: int -> t m + val (~-) : t -> t m + val (~~) : t -> t m + val ( + ) : t -> t -> t m + val ( - ) : t -> t -> t m + val ( * ) : t -> t -> t m + val ( / ) : t -> t -> t m + val ( /$ ) : t -> t -> t m + val (%) : t -> t -> t m + val (%$) : t -> t -> t m + val (%^) : t -> t -> t m + val (land) : t -> t -> t m + val (lor) : t -> t -> t m + val (lxor) : t -> t -> t m + val (lsl) : t -> t -> t m + val (lsr) : t -> t -> t m + val (asr) : t -> t -> t m + val (++) : t -> int -> t m + val (--) : t -> int -> t m +end + +module type Modulus = sig + val modulus : modulus +end + +let to_string = Z.format "%#x" +let of_string x = + let r = Z.of_string x in + if Z.sign r < 0 + then invalid_arg + (x ^ " - invalid string representation, sign is not expected") + else r + +let (!$) x = of_string x [@@inline] +let (!!) x m = int x m [@@inline] + +(* all bitvectors are normalized and therefore non-negative, + so we don't need to check the lower bound. +*) +let max_uint = Z.(one lsl Sys.int_size - one) +let max_sint = Z.(of_int max_int) +let max_uint32 = Z.(one lsl 32 - one) +let max_sint32 = Z.(one lsl 31 - one) +let max_uint64 = Z.(one lsl 64 - one) +let max_sint64 = Z.(one lsl 63 - one) + + +let fits_int = Z.fits_int +let fits_int32 = Z.fits_int32 +let fits_int64 = Z.fits_int64 + +let doesn't_fit r x = + failwith (to_string x ^ " doesn't fit the " ^ r ^ " type") + +let convert tname size convert max_signed max_unsigned x = + if Z.(x <= max_signed) then convert x + else if Z.(x <= max_unsigned) + then convert (Z.signed_extract x 0 size) + else doesn't_fit tname x + +let to_int x = + convert "int" Sys.int_size Z.to_int max_sint max_uint x +[@@inline] + +let to_int32 x = + convert "int32" 32 Z.to_int32 max_sint32 max_uint32 x +[@@inline] +(* [to_int64 x] converts [x] to [int64]; the first argument is the type name reported in the failure message *) +let to_int64 x = + convert "int64" 64 Z.to_int64 max_sint64 max_uint64 x +[@@inline] + +let to_bigint x = x [@@inline] + +let of_binary = Z.of_bits +let to_binary = Z.to_bits +let pp ppf x = + Format.fprintf ppf "%s" (to_string x) + +module Make(M :
Modulus) : S with type 'a m = 'a = struct + type 'a m = 'a + let m = M.modulus + + let bool x = bool x [@@inline] + let int x = int x mod m [@@inline] + let int32 x = int32 x mod m [@@inline] + let int64 x = int64 x mod m [@@inline] + let bigint x = bigint x mod m [@@inline] + let zero = zero + let one = one + let ones = ones mod m + let succ x = succ x mod m [@@inline] + let nsucc x n = nsucc x n mod m [@@inline] + let pred x = pred x mod m [@@inline] + let npred x n = npred x n mod m [@@inline] + let neg x = neg x mod m [@@inline] + let lnot x = lnot x mod m [@@inline] + let abs x = abs x mod m [@@inline] + let add x y = add x y mod m [@@inline] + let sub x y = sub x y mod m [@@inline] + let mul x y = mul x y mod m [@@inline] + let div x y = div x y mod m [@@inline] + let sdiv x y = sdiv x y mod m [@@inline] + let rem x y = rem x y mod m [@@inline] + let srem x y = srem x y mod m [@@inline] + let smod x y = smod x y mod m [@@inline] + let nth x y = nth x y mod m [@@inline] + let msb x = msb x mod m [@@inline] + let lsb x = lsb x mod m [@@inline] + let logand x y = logand x y mod m [@@inline] + let logor x y = logor x y mod m [@@inline] + let logxor x y = logxor x y mod m [@@inline] + let lshift x y = lshift x y mod m [@@inline] + let rshift x y = rshift x y mod m [@@inline] + let arshift x y = arshift x y mod m [@@inline] + let gcd x y = gcd x y mod m [@@inline] + let lcm x y = lcm x y mod m [@@inline] + let gcdext x y = gcdext x y mod m [@@inline] + + let (!$) x = of_string x [@@inline] + let (!!) 
x = int x [@@inline] + let (~-) x = neg x [@@inline] + let (~~) x = lnot x [@@inline] + let (+) x y = add x y [@@inline] + let (-) x y = sub x y [@@inline] + let ( * ) x y = mul x y [@@inline] + let (/) x y = div x y [@@inline] + let (/$) x y = sdiv x y [@@inline] + let (%) x y = rem x y [@@inline] + let (%$) x y = smod x y [@@inline] + let (%^) x y = srem x y [@@inline] + let (land) x y = logand x y [@@inline] + let (lor) x y = logor x y [@@inline] + let (lxor) x y = logxor x y [@@inline] + let (lsl) x y = lshift x y [@@inline] + let (lsr) x y = rshift x y [@@inline] + let (asr) x y = arshift x y [@@inline] + let (++) x n = nsucc x n [@@inline] + let (--) x n = npred x n [@@inline] +end [@@inline] + +module M1 = Make(struct + let modulus = m1 + end) + +module M8 = Make(struct + let modulus = m8 + end) + +module M32 = Make(struct + let modulus = m32 + end) + +module M64 = Make(struct + let modulus = m64 + end) + +include Syntax +let equal x y = compare x y = 0 [@@inline] +let (<) x y = compare x y < 0 [@@inline] +let (>) x y = compare x y > 0 [@@inline] +let (<=) x y = compare x y <= 0 [@@inline] +let (>=) x y = compare x y >= 0 [@@inline] +let (=) x y = compare x y = 0 [@@inline] +let (<>) x y = compare x y <> 0 [@@inline] diff --git a/lib/bitvec/bitvec.mli b/lib/bitvec/bitvec.mli new file mode 100644 index 000000000..20403856a --- /dev/null +++ b/lib/bitvec/bitvec.mli @@ -0,0 +1,547 @@ +(** abstract representation of a fixed size bitvector *) +type t + +(** a computation in some modulo *) +type 'a m + +(** type denoting the arithmetic modulus *) +type modulus + + +(** [modulus s] is the modulus of bitvectors with size [s]. + + This is a number $2^s-1$, also known as a Mersenne number. 
+*) +val modulus : int -> modulus + + +(** [m1 = modulus 1] = $1$ is the modulus of bitvectors with size [1] *) +val m1 : modulus + +(** [m8 = modulus 8] = $255$ is the modulus of bitvectors with size [8] *) +val m8 : modulus + +(** [m32 = modulus 32] = $2^32-1$ is the modulus of bitvectors with size [32] *) +val m32 : modulus + +(** [m64 = modulus 64] = $2^64-1$ is the modulus of bitvectors with size [64] *) +val m64 : modulus + + +(** [(x y) mod m] applies operation [] modulo [m]. + + Example: [(x + y) mod m] returns the sum of [x] and [y] modulo [m]. + + Note: the [mod] function is declared as a primitive to enable + support for inlining in non flambda versions of OCaml. Indeed, + underneath the hood the ['a m] type is defined as the reader monad + [modulus -> 'a], however we don't want to create a closure every + time we compute an operation over bitvectors. With this trick, all + versions of OCaml no matter the optimization options will inline + [(x+y) mod m] and won't create any closures, even if [m] is not + known at compile time. +*) +external (mod) : 'a m -> modulus -> 'a = "%apply" + +module type S = sig + (** an abstract representation of an operation modulo some number.*) + type 'a m + + (** [bool x] returns [one] if [x] and [zero] otherwise. *) + val bool : bool -> t + + (** [int n mod m] is [n] modulo [m]. *) + val int : int -> t m + + (** [int32 n mod m] is [n] modulo [m]. *) + val int32 : int32 -> t m + + (** [int64 n mod m] is [n] modulo [m]. *) + val int64 : int64 -> t m + + (** [bigint n mod m] is [n] modulo [m]. *) + val bigint : Z.t -> t m + + (** [zero] is [0]. *) + val zero : t + + (** [one] is [1]. 
*) + val one : t + + (** [ones mod m] is a bitvector of size [m] with all bits set *) + val ones : t m + + (** [succ x mod m] is the successor of [x] modulo [m] *) + val succ : t -> t m + + + (** [nsucc x n mod m] is the [n]th successor of [x] modulo [m] *) + val nsucc : t -> int -> t m + + + (** [pred x mod m] is the predecessor of [x] modulo [m] *) + val pred : t -> t m + + (** [npred x n mod m] is the [n]th predecessor of [x] modulo [m] *) + val npred : t -> int -> t m + + (** [neg x mod m] is the 2-complement of [x] modulo [m]. *) + val neg : t -> t m + + (** [lnot x] is the 1-complement of [x] modulo [m]. *) + val lnot : t -> t m + + (** [abs x mod m] absolute value of [x] modulo [m]. + + The absolute value of [x] is equal to [neg x] if + [msb x] and to [x] otherwise. *) + val abs : t -> t m + + (** [add x y mod m] is [x + y] modulo [m] *) + val add : t -> t -> t m + + (** [sub x y mod m] is [x - y] modulo [m] *) + val sub : t -> t -> t m + + (** [mul x y mod m] is [x * y] modulo [m] *) + val mul : t -> t -> t m + + (** [div x y mod m] is [x / y] modulo [m], + + where [/] is the truncating towards zero division, + that returns [ones m] if [y = 0]. + *) + val div : t -> t -> t m + + + (** [sdiv x y mod m] is signed division of [x] by [y] modulo [m], + + The signed division operator is defined in terms of the [div] + operator as follows: + {v + / + | div x y mod m : if not mx /\ not my + | neg (div (neg x) y) mod m if mx /\ not my + x sdiv y mod m = < + | neg (div x (neg y)) mod m if not mx /\ my + | div (neg x) (neg y) mod m if mx /\ my + \ + + where mx = msb x mod m, + and my = msb y mod m. + v} + + *) + val sdiv : t -> t -> t m + + (** [rem x y mod m] is the remainder of [x / y] modulo [m]. *) + val rem : t -> t -> t m + + (** [srem x y mod m] is the signed remainder [x / y] modulo [m]. 
+ + This is the version of the signed remainder in which the sign follows the + dividend; it is defined via the [rem] operation as follows + + {v + / + | rem x y mod m : if not mx /\ not my + | neg (rem (neg x) y) mod m if mx /\ not my + x srem y mod m = < + | neg (rem x (neg y)) mod m if not mx /\ my + | neg (rem (neg x) (neg y)) mod m if mx /\ my + \ + + where mx = msb x mod m, + and my = msb y mod m. + v} + *) + val srem : t -> t -> t m + + (** [smod x y mod m] is the signed remainder of [x / y] modulo [m]. + + This is the version of the signed remainder in which the sign follows the + divisor; it is defined in terms of the [rem] operation as + follows: + + {v + / + | u if u = 0 + x smod y mod m = < + | v if u <> 0 + \ + + / + | u if not mx /\ not my + | add (neg u) y mod m if mx /\ not my + v = < + | add u y mod m if not mx /\ my + | neg u mod m if mx /\ my + \ + + where mx = msb x mod m, + and my = msb y mod m, + and u = rem x y mod m. + v} + *) + val smod : t -> t -> t m + + + (** [nth x n mod m] is [true] if [n]th bit of [x] is [set]. + + Returns [msb x mod m] if [n >= m] + and [lsb x mod m] if [n < 0] + *) + val nth : t -> int -> bool m + + + (** [msb x mod m] returns the most significant bit of [x]. *) + val msb : t -> bool m + + + (** [lsb x mod m] returns the least significant bit of [x]. *) + val lsb : t -> bool m + + (** [logand x y mod m] is a bitwise logical and of [x] and [y] modulo [m] *) + val logand : t -> t -> t m + + (** [logor x y mod m] is a bitwise logical or of [x] and [y] modulo [m]. *) + val logor : t -> t -> t m + + (** [logxor x y mod m] is exclusive [or] between [x] and [y] modulo [m] *) + val logxor : t -> t -> t m + + (** [lshift x y mod m] shifts [x] to left by [y]. + Returns [0] if [y >= m]. + *) + val lshift : t -> t -> t m + + (** [rshift x y mod m] shifts [x] right by [y] bits. + Returns [0] if [y >= m] + *) + val rshift : t -> t -> t m + + (** [arshift x y mod m] shifts [x] right by [y] with [msb x] + filling.
+ + Returns [ones mod m] if [y >= m /\ msb x mod m] + and [zero] if [y >= m /\ msb x mod m = 0] + *) + val arshift : t -> t -> t m + + + (** [gcd x y mod m] returns the greatest common divisor modulo [m] + + [gcd x y] is the meet operation of the divisibility lattice, + with [0] being the top of the lattice and [1] being the bottom, + therefore [gcd x 0 = gcd 0 x = x]. + *) + val gcd : t -> t -> t m + + + (** [lcm x y mod m] returns the least common multiple modulo [m]. + + [lcm x y] is the join operation of the divisibility lattice, + with [0] being the top of the lattice and [1] being the bottom, + therefore [lcm x 0 = lcm 0 x = 0] + + *) + val lcm : t -> t -> t m + + + (** [(g,a,b) = gcdext x y mod m], where + - [g = gcd x y mod m], + - [g = (a * x + b * y) mod m]. + + The operation is well defined if one or both operands are equal + to [0], in particular: + - [(x,1,0) = gcdext(x,0)], + - [(x,0,1) = gcdext(0,x)]. + *) + val gcdext : t -> t -> (t * t * t) m + + + + (** [!$x] is [of_string x] *) + val (!$) : string -> t + + (** [!!x mod m] is [int x mod m] *) + val (!!)
: int -> t m + + + (** [~-x mod m] is [neg x mod m] *) + val (~-) : t -> t m + + + (** [~~x mod m] is [lnot x mod m] *) + val (~~) : t -> t m + + (** [(x + y) mod m] is [add x y mod m] *) + val ( + ) : t -> t -> t m + + (** [(x - y) mod m] is [sub x y mod m *) + val ( - ) : t -> t -> t m + + (** [(x * y) mod m] is [mul x y mod m] *) + val ( * ) : t -> t -> t m + + (** [(x / y) mod m] is [div x y mod m] *) + val ( / ) : t -> t -> t m + + (** [x /$ y mod m] is [sdiv x y mod m] *) + val ( /$ ) : t -> t -> t m + + (** [(x % y) mod m] is [rem x y mod m] *) + val (%) : t -> t -> t m + + (** [(x %$ y) mod m] is [smod x y mod m] *) + val (%$) : t -> t -> t m + + (** [(x %^ y) mod m] is [srem x y mod m] *) + val (%^) : t -> t -> t m + + (** [(x land y) mod m] is [logand x y mod m] *) + val (land) : t -> t -> t m + + (** [(x lor y) mod m] is [logor x y mod m] *) + val (lor) : t -> t -> t m + + (** [(x lxor y) mod m] is [logxor x y mod m] *) + val (lxor) : t -> t -> t m + + (** [(x lsl y) mod m] [lshift x y mod m] *) + val (lsl) : t -> t -> t m + + (** [(x lsr y) mod m] is [rshift x y mod m] *) + val (lsr) : t -> t -> t m + + (** [(x asr y) = arshift x y] *) + val (asr) : t -> t -> t m + + (** [(x ++ n) mod m] is [nsucc x n mod m] *) + val (++) : t -> int -> t m + + + (** [(x -- n) mod m]is [npred x n mod m] *) + val (--) : t -> int -> t m +end + + + + +(** [compare x y] compares [x] and [y] as unsigned integers, + i.e., + [compare x y] = [compare (to_nat x) (to_nat y)] +*) +val compare : t -> t -> int + +(** [equal x y] is true if [x] and [y] represent the same integers *) +val equal : t -> t -> bool + + +val (<) : t -> t -> bool (** [x < y] iff [compare x y = -1] *) +val (>) : t -> t -> bool (** [x > y] iff [compare x y = 1] *) +val (=) : t -> t -> bool (** [x = y] iff [compare x y = 0] *) +val (<>) : t -> t -> bool (** [x <> y] iff [compare x y <> 0] *) +val (<=) : t -> t -> bool (** [x <= y] iff [compare x y <= 0] *) +val (>=) : t -> t -> bool (** [x >= y] iff [compare x 
y >= 0] *) + +(** [hash x] returns such [z] that forall [y] s.t. [x=y], [hash y = z] *) +val hash : t -> int + + +(** [pp ppf x] is a pretty printer for the bitvectors. + + Could be used standalone or as an argument to the [%a] format + specifier, e.g., + + {[ + Format.printf "0xBEEF != %a" Bitvec.pp !$"0xBEAF" + ]} + +*) +val pp : Format.formatter -> t -> unit + +(** [to_binary x] returns a canonical binary representation of [x] *) +val to_binary : t -> string + +(** [of_binary s] returns a bitvector [x] s.t. [to_binary x = s].*) +val of_binary : string -> t + + +(** [to_string x] returns a textual (human readable) representation + of the bitvector [x]. *) +val to_string : t -> string + +(** [of_string s] returns a bitvector that corresponds to [s]. + + The set of accepted strings is defined by the following EBNF grammar: + + {v + valid-numbers ::= + | "0b", bin-digit, {bin-digit} + | "0o", oct-digit, {oct-digit} + | "0x", hex-digit, {hex-digit} + | dec-digit, {dec-digit} + + bin-digit ::= '0' | '1' + oct-digit ::= '0'-'7' + dec-digit ::= '0'-'9' + hex-digit ::= '0'-'9' |'a'-'f'|'A'-'F' + v} + + The function is not defined if [s] is not in [valid-numbers]. +*) +val of_string : string -> t + +(** [fits_int x] is [true] if [x] could be represented with the OCaml + [int] type. + + Note: it is not always true that [fits_int (int x mod m)], since + depending on [m] a negative number might not fit into the OCaml + representation. For positive numbers it is true, however. +*) +val fits_int : t -> bool + + +(** [to_int x] returns an OCaml integer that has the same + representation as [x]. + + The function is undefined if [not (fits_int x)]. +*) +val to_int : t -> int + + +(** [fits_int32 x] is [true] if [x] could be represented with the OCaml + [int32] type. + + Note: it is not always true that [fits_int32 (int32 x mod m)], + since depending on [m] the negative [x] may not fit back into the + [int32] representation. For positive numbers it is true, however.
+*) +val fits_int32 : t -> bool + +(** [to_int32 x] returns an OCaml integer that has the same + representation as [x]. + + The function is undefined if [not (fits_int32 x)]. +*) +val to_int32 : t -> int32 + +(** [fits_int64 x] is [true] if [x] could be represented with the OCaml + [int64] type. + + Note: it is not always true that [fits_int64 (int64 x mod m)], + since depending on [m] the negative [x] might not fit back into the + [int64] representation. For positive numbers it is true, however. +*) +val fits_int64 : t -> bool + +(** [to_int64 x] returns an OCaml integer that has the same + representation as [x]. + + The function is undefined if [not (fits_int64 x)]. +*) +val to_int64 : t -> int64 + + +(** [to_bigint x] returns a natural number that corresponds to [x]. + + The returned value is always non-negative. +*) +val to_bigint : t -> Z.t + + +(** [extract ~hi ~lo x] extracts bits from [lo] to [hi]. + + The operation is effectively equivalent to + [(x lsr lo) mod (hi-lo+1)] +*) +val extract : hi:int -> lo:int -> t -> t + + +(** [select bits x] builds a bitvector from [bits] of [x]. + + Returns a bitvector [y] such that [n]th bit of it is + equal to [List.nth bits n] bit of [x]. + + Returns [zero] if [bits] are empty. +*) +val select : int list -> t -> t + +(** [append m n x y] takes [m] bits of [x] and [n] bits of [y] + and returns their concatenation. The result has [m+n] bits. + + + Examples: + - [append 16 16 !$"0xdead" !$"0xbeef" = !$"0xdeadbeef"]; + - [append 12 20 !$"0xbadadd" !$"0xbadbeef" = !$"0xadddbeef"];; +*) +val append : int -> int -> t -> t -> t + +(** [repeat m ~times:n x] repeats [m] bits of [x] [n] times. + + The result has [m*n] bits. +*) +val repeat : int -> times:int -> t -> t + +(** [concat m xs] concatenates [m] bits of each [x] in [xs]. + + The operation is the reduction of the [append] operation with [m=n]. + The result has [m * List.length xs] bits and is equal to [0] + if [xs] is empty.
+*) +val concat : int -> t list -> t + +include S with type 'a m := 'a m + +module type Modulus = sig + val modulus : modulus +end + + +(** [module Mx = Make(Modulus)] produces a module [Mx] + which implements all operation in [S] modulo + [Modulus.modulus], so that all operations return a + bitvector directly. +*) +module Make(M : Modulus) : sig + include S with type 'a m = 'a +end + + +(** [M1] specializes [Make(struct let modulus = m1 end)] + + The specialization relies on a few arithmetic equalities + and on an efficient implementation of the modulo operation + as the [even x] aka [lsb x] operation. +*) +module M1 : S with type 'a m = 'a + +(** [M8] specializes [Make(struct let modulus = m8 end)] + + This specialization relies on a fact, that 8 bitvectors + always fit into OCaml integer representation, so it avoids + calls to the underlying arbitrary precision arithmetic + library. +*) +module M8 : S with type 'a m = 'a + + +(** [M32] specializes [Make(struct let modulus = m32 end)] + + This specialization relies on a fact, that 32 bitvectors + always fit into OCaml integer representation, so it avoids + calls to the underlying arbitrary precision arithmetic + library. +*) +module M32 : S with type 'a m = 'a + + +(** [M64] specializes [Make(struct let modulus = m64 end)] + + This specialization tries to minimize calls to the arbitrary + precision arithmetic library whenever, it is known that the result + will not overflow the OCaml int representation. 
+ +*) +module M64 : S with type 'a m = 'a diff --git a/lib/bitvec_binprot/bitvec_binprot.ml b/lib/bitvec_binprot/bitvec_binprot.ml new file mode 100644 index 000000000..37fa66cfb --- /dev/null +++ b/lib/bitvec_binprot/bitvec_binprot.ml @@ -0,0 +1,11 @@ +open Bin_prot.Std +type t = Bitvec.t +module Functions = Bin_prot.Utils.Make_binable(struct + module Binable = struct + type t = string [@@deriving bin_io] + end + type t = Bitvec.t + let to_binable = Bitvec.to_binary + let of_binable = Bitvec.of_binary + end) +include Functions diff --git a/lib/bitvec_binprot/bitvec_binprot.mli b/lib/bitvec_binprot/bitvec_binprot.mli new file mode 100644 index 000000000..827c8a458 --- /dev/null +++ b/lib/bitvec_binprot/bitvec_binprot.mli @@ -0,0 +1,20 @@ +open Bin_prot + +(** Provides serialization functions for the Binprot Protocol.*) + + +include Binable.S with type t = Bitvec.t + + +(** Same module, but functions only without the type. + + Useful, for extending an existing interface with the binable + interface, without hitting the same type defined twice error, + e.g. 
+ + {[include Comparable.Make_binable_using_comparator(struct + include Bitvec_order + include Bitvec_binprot.Functions + end)]} +*) +module Functions : Binable.S with type t := Bitvec.t diff --git a/lib/bitvec_order/bitvec_order.ml b/lib/bitvec_order/bitvec_order.ml new file mode 100644 index 000000000..ff7193774 --- /dev/null +++ b/lib/bitvec_order/bitvec_order.ml @@ -0,0 +1,50 @@ +module Sexp = Base.Sexp + +type ('a,'b) comparator_t = ('a,'b) Base.Comparator.t +type ('a,'b) comparator = (module Base.Comparator.S + with type t = 'a + and type comparator_witness = 'b) +include Bitvec_sexp.Functions + +module Ascending = struct + type t = Bitvec.t + include Base.Comparator.Make(struct + type t = Bitvec.t + include Bitvec_sexp.Functions + let compare x y = Bitvec.compare x y [@@inline] + let sexp_of_t = sexp_of_t + end) +end + +module Descending = struct + type t = Bitvec.t + include Base.Comparator.Make(struct + type t = Bitvec.t + include Bitvec_sexp.Functions + let compare x y = Bitvec.compare y x [@@inline] + let sexp_of_t = sexp_of_t + end) +end + +module Natural = Ascending + +type ascending = Ascending.comparator_witness +type descending = Descending.comparator_witness +type natural = ascending + +let ascending : (_,_) comparator = (module Ascending) +let descending : (_,_) comparator = (module Descending) +let natural = ascending + +module Comparators = struct + type bitvec_order = natural + let bitvec_compare = Natural.comparator.compare + let bitvec_equal x y = bitvec_compare x y = 0 + let bitvec_order = natural + let bitvec_ascending = ascending + let bitvec_descending = descending +end + +let compare = Natural.comparator.compare + +include Natural diff --git a/lib/bitvec_order/bitvec_order.mli b/lib/bitvec_order/bitvec_order.mli new file mode 100644 index 000000000..67569702f --- /dev/null +++ b/lib/bitvec_order/bitvec_order.mli @@ -0,0 +1,165 @@ +(** Provides comparators for use with Janestreet Libraries.
+ + A comparator is a comparison function paired with a type that + is unique to this function and that acts as the witness of + equality between two comparison functions. + + Comparators are extensively used in the Janestreet's suite of + libraries, e.g., Base, Core_kernel, Core, etc. Comparators are + used to create sets, maps, hashtables, to instantiate interfaces, + and algorithms that require comparison. + + This module provides two comparators, [ascending] and + [descending], which represent two corresponding orderings, as + well as the [natural] comparator, which is the default ordering + that coincides with the [ascending] order and should be used in + cases where a particular ordering doesn't matter. + + This library interface is designed from the point of view of the + library user, to minimize verbosity and maximize readability and + the ease of use. + + For example, an empty set is created as + + {[let words = Set.empty Bitvec_order.natural]}, + + and it has type + + {[(Bitvec.t,Bitvec_order.natural) Set.t]} + + (See the note on comparators below for the old style comparators) + + This module also provides (implements) the [Base.Comparator.S] + interface that is required by a few Janestreet functors, in + particular, you can construct the type of a set of bitvectors, + as + + {[type t = Set.M(Bitvec_order).t]}, or for a mapping from + bitvectors to OCaml [int] it would be + + {[type t = int Map.M(Bitvec_order).t]} + + + And to instantiate the [Comparable.S] interface, + + {[include Comparable.Make(Bitvec_order)]}. + + See also, [Bitvec_binprot] which provides support for + the binable interfaces.
+ + Finally, for even more concise and readable syntax we provide + the [Comparators] module that designed to be opened, since it + defines properly prefixed identifiers, which will unlikely clash + with any existing name, so that the above examples could be + expressed + + {[ + open Bitvec_order.Comparators + let words = Set.empty bitvec_order + type words = (bitvec, bitvec_order) Set.t + let decreasing x y = bitvec_descending.compare x y + ]} + + {2 The old style comparators} + + Historically, the comparator in the Janestreet libraries had two + representations - as a record that contains the [compare] + function, and as a module that contains this record and an + instance of the witness type. The former representation is mostly + used for internal purposes, while the latter is expected in most + of the public functions, like [Set.empty], [Map.empty] in the form + of the first-class module, or in different [Make*] functors. In + this library, when we use the word comparator to refer to the + latter representation, i.e., to the module. When the underlying + comparator record is needed, it could be obtained through the + [comparator] field of the module. + + The older versions of Core and Base libraries were accepting the + comparator as a record, so instead of writing + + {[Set.empty Bitvec_order.natural]} + + the following notation should be used + + {[Set.empty Bitvec_order.Natural.comparator]} +*) +type t = Bitvec.t + + +(** [compare x y] orders [x] and [y] in the natural order. *) +val compare : t -> t -> int + + +(** type index for the increasing order *) +type ascending + +(** type index for the decreasing order *) +type descending + + +(** we use the increasing order as the natural ordering *) +type natural = ascending + +(** a type abbreviation for a comparator packed into a module. + + See the note about the historical meaning of the word comparator. 
+*) +type ('a,'b) comparator = (module Base.Comparator.S + with type t = 'a + and type comparator_witness = 'b) + + +(** [natural] the packed comparator that sorts in the natural order *) +val natural : (t, natural) comparator + +(** [ascending] the packed comparator that sorts in the increasing order *) +val ascending : (t, ascending) comparator + +(** [descending] the packed comparator that sorts in the decreasing order *) +val descending : (t, descending) comparator + +(** [Natural] the comparator that sorts in the natural order *) +module Natural : Base.Comparator.S + with type t = t + and type comparator_witness = natural + +(** [Ascending] the comparator that sorts in the increasing order *) +module Ascending : Base.Comparator.S + with type t = t + and type comparator_witness = ascending + +(** [Descending] the comparator that sorts in the decreasing order *) +module Descending : Base.Comparator.S + with type t = t + and type comparator_witness = descending + +(** provides the natural order by default *) +include Base.Comparator.S + with type t := t + and type comparator_witness = natural + + +(** Open this module to make the following identifiers available *) +module Comparators : sig + + (** the default ordering for bitvectors *) + type bitvec_order = natural + + + (** [bitvec_compare x y] orders [x] and [y] in the natural order *) + val bitvec_compare : t -> t -> int + + + (** [bitvec_equal x y] is true if [x] is equal to [y].
*) + val bitvec_equal : t -> t -> bool + + + (** [natural] the packed comparator that sorts in the natural order *) + val bitvec_order : (t, natural) comparator + + (** [ascending] the packed comparator that sorts in the increasing order *) + val bitvec_ascending : (t, ascending) comparator + + (** [descending] the packed comparator that sorts in the decreasing order *) + val bitvec_descending : (t, descending) comparator +end diff --git a/lib/bitvec_sexp/bitvec_sexp.ml b/lib/bitvec_sexp/bitvec_sexp.ml new file mode 100644 index 000000000..7d41294b7 --- /dev/null +++ b/lib/bitvec_sexp/bitvec_sexp.ml @@ -0,0 +1,12 @@ +open Sexplib0 + +type t = Bitvec.t + +module Functions = struct + let sexp_of_t x = Sexp.Atom (Bitvec.to_string x) + let t_of_sexp = function + | Sexp.Atom x -> Bitvec.of_string x + | _ -> invalid_arg "Bitvec_sexp: expects an atom, got list" +end + +include Functions diff --git a/lib/bitvec_sexp/bitvec_sexp.mli b/lib/bitvec_sexp/bitvec_sexp.mli new file mode 100644 index 000000000..22d96cb23 --- /dev/null +++ b/lib/bitvec_sexp/bitvec_sexp.mli @@ -0,0 +1,11 @@ +open Sexplib0 + +type t = Bitvec.t + +val sexp_of_t : t -> Sexp.t +val t_of_sexp : Sexp.t -> t + +module Functions : sig + val sexp_of_t : t -> Sexp.t + val t_of_sexp : Sexp.t -> t +end diff --git a/lib/graphlib/graphlib.mli b/lib/graphlib/graphlib.mli index fd5c75e0b..d253efe32 100644 --- a/lib/graphlib/graphlib.mli +++ b/lib/graphlib/graphlib.mli @@ -648,6 +648,12 @@ module Std : sig *) val create : ('n,'d,_) Map.t -> 'd -> ('n,'d) t + (** [equal s1 s2] is [true] if [s1] and [s2] are equal solutions. + + Two solutions are equal if for all [x] in the data domain + ['d], we have that [equal s1[x] s2[x]]. + *) + val equal : equal:('d -> 'd -> bool) -> ('n,'d) t -> ('n,'d) t -> bool (** [iterations s] returns the total number of iterations that was made to obtain the current solution. 
*) @@ -660,6 +666,13 @@ module Std : sig *) val default : ('n,'d) t -> 'd + + (** [enum xs] enumerates all non-trivial values in the solution. + + A value is non-trivial if it differs from the default value. + *) + val enum : ('n,'d) t -> ('n * 'd) Sequence.t + (** [is_fixpoint s] is [true] if the solution is a fixed point solution, i.e., is a solution that stabilizes the system of equations. *) @@ -1485,10 +1498,10 @@ module Std : sig For the purpose of this function a graph can be represented with three values: - - [nodes_of_edge] returns the source and destination nodes + - [nodes_of_edge] returns the source and destination nodes of an edge; - - [nodes] is a sequence of nodes; - - [edges] is a sequence of edges; + - [nodes] is a sequence of nodes; + - [edges] is a sequence of edges; @param name the name of the graph. @param attrs graphviz attributes of the graph. diff --git a/lib/graphlib/graphlib_graph.ml b/lib/graphlib/graphlib_graph.ml index e0a4a67d4..a7ae1a9c1 100644 --- a/lib/graphlib/graphlib_graph.ml +++ b/lib/graphlib/graphlib_graph.ml @@ -41,7 +41,7 @@ let string_of_set ~sep pp_elt set = let empty_set (type a) (type cmp) map = let m : (module Comparator.S with type t = a and type comparator_witness = cmp) - = (module struct + = (module struct type t = a type comparator_witness = cmp let comparator = Map.comparator map @@ -180,9 +180,9 @@ module Partition = struct (* takes a mapping from node to its root *) let create (type a) (type c) - (comparator : (module Comparator.S with type t = a - and type comparator_witness = c)) - comps = + (comparator : (module Comparator.S with type t = a + and type comparator_witness = c)) + comps = let roots,groups = Hashtbl.fold comps ~init:(Map.empty comparator) ~f:(fun ~key:node ~data:root map -> @@ -1241,6 +1241,20 @@ module Fixpoint = struct | None -> default | Some x -> x + let enum (Solution {approx}) = + Map.to_sequence approx + + let is_subset ~equal (Solution {approx=m1}) ~of_:s2 = + Map.for_alli m1 ~f:(fun ~key 
~data -> equal (get s2 key) data) + + let equal ~equal + (Solution {approx=m1; default=d1} as s1) + (Solution {approx=m2; default=d2} as s2) = + equal d1 d2 && + Int.equal (Map.length m1) (Map.length m2) && + is_subset ~equal s1 ~of_:s2 && + is_subset ~equal s2 ~of_:s1 + let is_fixpoint (Solution {steps; iters}) = match steps with | None -> iters > 0 | Some steps -> iters < steps diff --git a/lib/graphlib/graphlib_intf.ml b/lib/graphlib/graphlib_intf.ml index 0e54eadb7..98fa71016 100644 --- a/lib/graphlib/graphlib_intf.ml +++ b/lib/graphlib/graphlib_intf.ml @@ -117,7 +117,9 @@ type graph_attr = Graph.Graphviz.DotAttributes.graph module type Solution = sig type ('n,'d) t val create : ('n,'d,_) Map.t -> 'd -> ('n,'d) t + val equal : equal:('d -> 'd -> bool) -> ('n,'d) t -> ('n,'d) t -> bool val iterations : ('n,'d) t -> int + val enum : ('n,'d) t -> ('n * 'd) Sequence.t val default : ('n,'d) t -> 'd val is_fixpoint : ('n,'d) t -> bool val get : ('n,'d) t -> 'n -> 'd diff --git a/lib/knowledge/.merlin b/lib/knowledge/.merlin new file mode 100644 index 000000000..a5138a1da --- /dev/null +++ b/lib/knowledge/.merlin @@ -0,0 +1,2 @@ +REC +B ../../_build/lib/knowledge diff --git a/lib/knowledge/bap_knowledge.ml b/lib/knowledge/bap_knowledge.ml new file mode 100644 index 000000000..57d07acae --- /dev/null +++ b/lib/knowledge/bap_knowledge.ml @@ -0,0 +1,2539 @@ +open Core_kernel +open Monads.Std + +type ('a,'b) eq = ('a,'b) Type_equal.t = T : ('a,'a) eq + +module Order = struct + type partial = LT | EQ | GT | NC + module type S = sig + type t + val order : t -> t -> partial + end +end + +type conflict = exn = .. + +module Conflict = struct + type t = conflict = .. 
+ let pp = Exn.pp + let add_printer pr = Caml.Printexc.register_printer pr + let sexp_of_t = Exn.sexp_of_t +end + +module type Id = sig + type t [@@deriving sexp, hash] + val zero : t + val pp : Format.formatter -> t -> unit + val of_string : string -> t + include Comparable.S_binable with type t := t + include Binable.S with type t := t +end + +(* static identifiers, + they should persist, so we will substitute them with uuid later +*) +module type Sid = sig + include Id + val incr : t ref -> unit +end + + +(* temporal identifiers + + Identifiers work like pointers in our runtime, and + are tagged words. We use 63 bit words, which are + represented natively as immediate values in 64-bit + OCaml or as boxed values in 32-bit OCaml. + + We add extra tags: + + Numbers: + +------------------+---+ + | payload | 1 | + +------------------+---+ + 62 1 0 + + Atoms: + +--------------+---+---+ + | payload | 1 | 0 | + +--------------+---+---+ + 62 1 0 + + + Cells: + +--------------+---+---+ + | payload | 0 | 0 | + +--------------+---+---+ + 62 1 0 + + + So numbers are tagged with the least significant + bit set to 1. Not numbers (aka pointers) always + have the lowest bit set to 0, and are either + atoms (symbols or objects) with the second bit set, + or cells, with the second bit cleared. Finally, + we have the null value, which is represented with + all zeros, which is neither number, cell, or atom. + + The same arithmetically, + numbers = {1 + 2*n} -- all odd numbers + cells = {4 + 4*n} + atoms = {6 + 4*n} + null = 0 + + Those four sets are disjoint. + + + The chosen representation allows us to represent + the following number of elements per class (since + classes partition values into disjoint sets, objects + of different classes may have the same values, basically, + each class has its own heap): + + numbers: 2^62 values (or [-2305843009213693952, 2305843009213693951]) + atoms and cells: 2^61 values.
+ +*) +module type Tid = sig + include Id + val null : t + val first_atom : t + val first_cell : t + val next : t -> t + + val is_null : t -> bool + val is_atom : t -> bool + val is_cell : t -> bool + val is_number : t -> bool + + val fits : int -> bool + val of_int : int -> t + val fits_int : t -> bool + val to_int : t -> int + + val untagged : t -> Int63.t + val atom_of_string : string -> t + val cell_of_string : string -> t + val number_of_string : string -> t +end + + +module Oid : Tid = struct + include Int63 + let null = zero + let first_atom = of_int 6 + let first_cell = of_int 4 + let next x = Int63.(x + of_int 4) [@@inline] + let is_null x = x = zero + let is_number x = x land one <> zero [@@inline] + let is_atom x = + x land of_int 0b01 = zero && + x land of_int 0b10 <> zero + let is_cell x = x land of_int 0b11 = zero + let to_int63 x = x + let number_of_string s = (of_string s lsl 1) lor of_int 1 + let cell_of_string s = (of_string s lsl 2) + let atom_of_string s = (of_string s lsl 2) lor of_int 0b10 + let min_value = min_value asr 1 + let max_value = max_value asr 1 + let fits x = + let x = of_int x in + x >= min_value && x <= max_value + [@@inline] + let fits_int x = + x >= of_int Int.min_value && + x <= of_int Int.max_value + [@@inline] + let of_int x = (of_int x lsl 1) + one [@@inline] + let to_int x = to_int_trunc (x asr 1) [@@inline] + + (* ordinal of a value in the given category (atoms, cells, numbers) *) + let untagged x = + if is_number x then x asr 1 else x asr 2 + [@@inline] + + let pp ppf x = + Format.fprintf ppf "<%#0Lx>" (Int63.to_int64 x) + +end + +module Cid : Sid = Int63 +module Pid : Sid = Int63 + +let user_package = "user" +let keyword_package = "keyword" + +type slot_status = + | Sleep + | Awoke + | Ready + + +module Agent : sig + type t + type id + type reliability + type signs + + val register : + ?desc:string -> + ?package:string -> + ?reliability:reliability -> string -> t + + val registry : unit -> id list + + val authorative : 
reliability + val reliable : reliability + val trustworthy : reliability + val doubtful : reliability + val unreliable : reliability + + val name : id -> string + val desc : id -> string + val reliability : id -> reliability + + val set_reliability : id -> reliability -> unit + + val pp : Format.formatter -> t -> unit + val pp_id : Format.formatter -> id -> unit + val pp_reliability : Format.formatter -> reliability -> unit + + (* the private interface *) + + val weight : t -> int + include Base.Comparable.S with type t := t +end = struct + module Id = String + type t = Id.t + type agent = Id.t + type id = Id.t + type reliability = A | B | C | D | E [@@deriving sexp] + type info = { + name : string; + desc : string; + rcls : reliability; + } + type signs = Set.M(String).t + + let agents : (agent,info) Hashtbl.t = Hashtbl.create (module String) + + let authorative = A + let reliable = B + let trustworthy = C + let doubtful = D + let unreliable = E + + let weight = function + | A -> 16 + | B -> 8 + | C -> 4 + | D -> 2 + | E -> 1 + + let register + ?(desc="no description provided") + ?(package="user") + ?(reliability=trustworthy) name = + let name = sprintf "%s:%s" package name in + let agent = Caml.Digest.string name in + if Hashtbl.mem agents agent then + failwithf "An agent with name `%s' already exists, \ + please choose another name" name (); + Hashtbl.add_exn agents agent { + desc; name; rcls = reliability + }; + agent + + let registry () = Hashtbl.keys agents + + let info agent = Hashtbl.find_exn agents agent + let name agent = (info agent).name + let desc agent = (info agent).desc + let reliability agent = (info agent).rcls + let weight agent = weight (reliability agent) + + let set_reliability agent rcls = + Hashtbl.update agents agent ~f:(function + | None -> assert false + | Some agent -> {agent with rcls}) + + let pp ppf agent = Format.pp_print_string ppf (name agent) + + let pp_reliability ppf r = + Sexp.pp ppf (sexp_of_reliability r) + + let pp_id ppf 
agent = + let {name; desc; rcls} = info agent in + Format.fprintf ppf "Class %a %s - %s" + pp_reliability rcls name desc + + include (String : Base.Comparable.S with type t := t) +end + +module Opinions : sig + type 'a t + + val empty : equal:('a -> 'a -> bool) -> 'a -> 'a t + + val inspect : ('a -> Sexp.t) -> 'a t -> Sexp.t + + val add : Agent.t -> 'a -> 'a t -> 'a t + val of_list : equal:('a -> 'a -> bool) -> 'a -> (Agent.t,'a) List.Assoc.t -> 'a t + val choice : 'a t -> 'a + + val compare_votes : 'a t -> 'a t -> int + val join : 'a t -> 'a t -> 'a t +end = struct + type 'a opinion = { + opinion : 'a; + votes : Set.M(Agent).t; + } + + type 'a t = { + opinions : 'a opinion list; + equal : 'a -> 'a -> bool; + empty : 'a; + } + + let empty ~equal empty = {opinions=[]; equal; empty} + + let inspect sexp_of_opinion {opinions} = + Sexp.List (List.rev_map opinions ~f:(fun {opinion} -> + sexp_of_opinion opinion)) + + + let add_opinion op ({opinions; equal} as ops) = + let casted,opinions = + List.fold opinions ~init:(false,[]) + ~f:(fun (casted,opinions) ({opinion; votes} as elt) -> + if not casted && equal opinion op.opinion + then true, { + opinion; votes = Set.union votes op.votes; + } :: opinions + else casted,elt :: opinions) in + if casted + then {ops with opinions} + else { + ops with opinions = op :: opinions + } + + let add agent opinion ({empty; equal} as ops) = + if equal opinion empty then ops + else + add_opinion { + opinion; + votes = Set.singleton (module Agent) agent; + } ops + + let join x y = + List.fold y.opinions ~init:x ~f:(fun ops op -> add_opinion op ops) + + let votes_sum = + Set.fold ~init:0 ~f:(fun sum agent -> sum + Agent.weight agent) + + let count_votes {opinions} = + List.fold opinions ~init:0 ~f:(fun sum {votes} -> + sum + votes_sum votes) + + let compare_votes x y = + compare (count_votes x) (count_votes y) + + let of_list ~equal bot = + let init = empty ~equal bot in + List.fold ~init ~f:(fun opts (agent,data) -> + add agent data opts) + 
+ + let compare x y = + let w1 = votes_sum x.votes + and w2 = votes_sum y.votes in + match Int.compare w1 w2 with + | 0 -> Set.compare_direct x.votes y.votes + | n -> n + + let choice {opinions; empty} = + List.max_elt opinions ~compare |> function + | Some {opinion} -> opinion + | None -> empty + +end + +module Domain = struct + type 'a t = { + inspect : 'a -> Sexp.t; + empty : 'a; + order : 'a -> 'a -> Order.partial; + join : 'a -> 'a -> ('a,conflict) result; + name : string; + } + + let inspect d = d.inspect + let empty d = d.empty + let order d = d.order + let join d = d.join + let name d = d.name + + let is_empty {empty; order} x = order empty x = EQ + + exception Join of string * Sexp.t * Sexp.t [@@deriving sexp_of] + + let make_join name inspect order x y = + match order x y with + | Order.GT -> Ok x + | EQ | LT -> Ok y + | NC -> Error (Join (name, inspect x, inspect y)) + + let define ?(inspect=sexp_of_opaque) ?join ~empty ~order name = { + inspect; empty; order; name; + join = match join with + | Some f -> f + | None -> (make_join name inspect order) + } + + let partial_of_total order x y : Order.partial = match order x y with (* lifts a total [compare]-style order to a partial one *) + | 0 -> EQ + | n when n > 0 -> GT (* any positive result means "greater", not only 1 *) + | _ -> LT + + let total ?inspect ?join ~empty ~order name = + define ?inspect ?join ~empty name ~order:(partial_of_total order) + + let flat ?inspect ?join ~empty ~equal name = + define ?inspect ?join ~empty name ~order:(fun x y -> + match equal empty x, equal empty y with + | true,true -> EQ + | true,false -> LT + | false,true -> GT + | false,false -> if equal x y then EQ else NC) + + let powerset (type t o) + (module S : Comparator.S with type t = t + and type comparator_witness = o) + ?(inspect=sexp_of_opaque) name = + let empty = Set.empty (module S) in + let order x y : Order.partial = + if Set.equal x y then EQ else + if Set.is_subset x y then LT else + if Set.is_subset y x then GT else NC in + let join x y = Ok (Set.union x y) in + let module Inspectable = struct + include S + let sexp_of_t = inspect +
end in + let inspect = [%sexp_of: Base.Set.M(Inspectable).t] in + define ~inspect ~empty ~order ~join name + + let opinions ?(inspect=sexp_of_opaque) ~empty ~equal name = + let empty = Opinions.empty ~equal empty in + let order = partial_of_total (Opinions.compare_votes) in + let inspect = Opinions.inspect inspect in + define ~inspect ~empty ~order name + + let mapping (type k o d) + (module K : Comparator.S with type t = k + and type comparator_witness = o) + ?(inspect=sexp_of_opaque) + ~equal name = + let empty = Map.empty (module K) in + let join x y = + let module Join = struct exception Conflict of conflict end in + try Result.return @@ Map.merge x y ~f:(fun ~key:_ -> function + | `Left v | `Right v -> Some v + | `Both (x,y) -> + if equal x y then Some y + else + let x = inspect x and y = inspect y in + let failed = Join (name,x,y) in + raise (Join.Conflict failed)) + with Join.Conflict err -> Error err in + let inspect xs = + Sexp.List (Map.keys xs |> List.map ~f:K.comparator.sexp_of_t) in + let order x y = + Map.symmetric_diff x y ~data_equal:equal |> + Sequence.fold ~init:(0,0,0) ~f:(fun (l,m,r) -> function + | (_,`Left _) -> (l+1,m,r) + | (_,`Right _) -> (l,m,r+1) + | (_, `Unequal _) -> (l,m+1,r)) |> function + | 0,0,0 -> Order.EQ + | 0,0,_ -> LT + | _,0,0 -> GT + | _,_,_ -> NC in + define ~inspect ~join ~empty ~order name + + let optional ?(inspect=sexp_of_opaque) ?join ~equal name = + let join_data = match join with + | Some join -> join + | None -> fun x y -> + if equal x y then Ok y + else Error (Join (name, inspect x, inspect y)) in + let inspect = sexp_of_option inspect in + let join x y = match x,y with + | None,x | x,None -> Ok x + | Some x, Some y -> match join_data x y with + | Ok x -> Ok (Some x) + | Error err -> Error err in + flat ~inspect ~join ~empty:None ~equal:(Option.equal equal) name + + let string = define "string" ~empty:"" + ~inspect:sexp_of_string ~order:(fun x y -> + match String.is_empty x, String.is_empty y with + | true, true -> 
EQ + | true,false -> LT (* the empty string is the least element, as in [flat] *) + | false,true -> GT + | false,false -> partial_of_total String.compare x y) + + let bool = optional ~inspect:sexp_of_bool ~equal:Bool.equal "bool" +end + +module Persistent = struct + type 'a t = + | String : string t + | Define : { + of_string : string -> 'a; + to_string : 'a -> string; + } -> 'a t + | Derive : { + of_persistent : 'b -> 'a; + to_persistent : 'a -> 'b; + persistent : 'b t; + } -> 'a t + + + let string = String + + let define ~to_string ~of_string = Define { + to_string; + of_string; + } + + let derive ~to_persistent ~of_persistent persistent = Derive { + to_persistent; + of_persistent; + persistent; + } + + let of_binable + : type a. (module Binable.S with type t = a) -> a t = + fun r -> Define { + to_string = Binable.to_string r; + of_string = Binable.of_string r + } + + let rec to_string + : type a. a t -> a -> string = + fun p x -> match p with + | String -> x + | Define {to_string} -> to_string x + | Derive {to_persistent; persistent} -> + to_string persistent (to_persistent x) + + let rec of_string + : type a.
a t -> string -> a = + fun p s -> match p with + | String -> s + | Define {of_string} -> of_string s + | Derive {of_persistent; persistent} -> + of_persistent (of_string persistent s) + + + + module Chunk = struct + (* bin_io will pack len+data, and restore it correspondingly *) + type t = {data : string} [@@deriving bin_io] + end + + module KV = struct + type t = {key : string; data : string} + [@@deriving bin_io] + end + + module Chunks = struct + type t = Chunk.t list [@@deriving bin_io] + end + module KVS = struct + type t = KV.t list [@@deriving bin_io] + end + + let chunks = of_binable (module Chunks) + let kvs = of_binable (module KVS) + + let list p = derive chunks + ~to_persistent:(List.rev_map ~f:(fun x -> + {Chunk.data = to_string p x})) + ~of_persistent:(List.rev_map ~f:(fun {Chunk.data} -> + of_string p data)) + + let array p = derive chunks + ~to_persistent:(Array.fold ~init:[] ~f:(fun xs x -> + {Chunk.data = to_string p x} :: xs)) + ~of_persistent:(Array.of_list_rev_map ~f:(fun {Chunk.data} -> + of_string p data)) + + let sequence p = derive chunks + ~to_persistent:(fun xs -> + Sequence.to_list_rev @@ + Sequence.map xs ~f:(fun x -> + {Chunk.data = to_string p x})) + ~of_persistent:(fun xs -> + Sequence.of_list @@ + List.rev_map xs ~f:(fun {Chunk.data} -> + of_string p data)) + + let set c p = derive (list p) + ~to_persistent:Set.to_list + ~of_persistent:(Set.of_list c) + + let map c pk pd = derive kvs + ~to_persistent:(Map.fold ~init:[] ~f:(fun ~key ~data xs -> { + KV.key = to_string pk key; + KV.data = to_string pd data + } :: xs)) + ~of_persistent:(List.fold ~init:(Map.empty c) + ~f: (fun xs {KV.key;data} -> + let key = of_string pk key + and data = of_string pd data in + Map.add_exn xs ~key ~data)) +end + + +type 'a obj = Oid.t + +type fullname = { + package : string; + name : string; +} [@@deriving bin_io, sexp_of] + + +module Registry = struct + let packages = Hash_set.create (module String) () + let classes = Hashtbl.create (module String) + 
let slots = Hashtbl.create (module String) + + let add_package name = Hash_set.add packages name + + let is_present ~package namespace name = + match Hashtbl.find namespace package with + | None -> false + | Some names -> Map.mem names name + + let register kind namespace ?desc ?(package=user_package) name = + add_package package; + if is_present ~package namespace name + then failwithf + "Failed to declare new %s, there is already a %s \ + named `%s' in package `%s'" kind kind name package (); + Hashtbl.update namespace package ~f:(function + | None -> Map.singleton (module String) name desc + | Some names -> Map.add_exn names ~key:name ~data:desc); + {package; name} + + let add_class = register "class" classes + let add_slot = register "property" slots +end + +let string_of_fname {package; name} = + if package = keyword_package || package = user_package + then name + else package ^ ":" ^ name + +let escaped = + Staged.unstage @@ + String.Escaping.escape ~escapeworthy:[':'] ~escape_char:'\\' + +let find_separator s = + if String.is_empty s then None + else String.Escaping.index s ~escape_char:'\\' ':' + +(* invariant, keywords are always prefixed with [:] *) +let normalize_name ~package name = + let package = escaped package in + if package = keyword_package && + not (String.is_prefix ~prefix:":" name) + then ":"^name else name + +let split_name package s = match find_separator s with + | None -> {package; name=s} + | Some 0 -> {package=keyword_package; name=s} + | Some len -> { + package = String.sub s ~pos:0 ~len; + name = String.subo s ~pos:(len+1); + } + +module Class = struct + type +'s info = { + id : Cid.t; + name : fullname; + sort : 's; + } + let id {id} = id + + type (+'a,+'s) t = 's info + + let classes = ref Cid.zero + + let names = Hashtbl.create (module Cid) + + let newclass ?desc ?package name sort = + Cid.incr classes; + let id = !classes + and name = Registry.add_class ?desc ?package name in + Hashtbl.add_exn names id name; + {id; name; sort} + + 
let declare + : ?desc:string -> ?package:string -> string -> 's -> ('k,'s) t = + fun ?desc ?package name data -> + newclass ?desc ?package name data + + let refine {id; name} sort = {id; name; sort} + + let same x y = Cid.equal x.id y.id + + let equal : type a b. (a,_) t -> (b,_) t -> (a obj,b obj) Type_equal.t option = + fun x y -> Option.some_if (same x y) Type_equal.T + + let assert_equal x y = match equal x y with + | Some t -> t + | None -> + failwithf "assert_equal: wrong assertion, classes of %s and %s \ + are different" + (string_of_fname x.name) + (string_of_fname y.name) + () + + + + let sort = fun {sort} -> sort + let name {name={name}} = name + let package {name={package}} = package + let fullname {name} = string_of_fname name + +end + +module Dict = struct + module Key = struct + module Uid = Int + let last_id = ref 0 + + type 'a witness = .. + + module type Witness = sig + type t + type _ witness += Id : t witness + end + + type 'a typeid = (module Witness with type t = 'a) + + type 'a t = { + ord : Uid.t; + key : 'a typeid; + name : string; + show : 'a -> Sexp.t; + } + + let newtype (type a) () : a typeid = + let module Type = struct + type t = a + type _ witness += Id : t witness + end in + (module Type) + + + let create ~name show = + let key = newtype () in + incr last_id; + {key; ord = !last_id; name; show} + + let uid {ord} = ord [@@inline] + let compare k1 k2 = + let k1 = uid k1 and k2 = uid k2 in + (Uid.compare [@inlined]) k1 k2 + [@@inline] + + let name x = x.name + let to_sexp x = x.show + let equal x y = Int.equal x.ord y.ord [@@inline] + + let same (type a b) x y : (a,b) Type_equal.t = + if equal x y then + let module X = (val x.key : Witness with type t = a) in + let module Y = (val y.key : Witness with type t = b) in + match X.Id with + | Y.Id -> Type_equal.T + | _ -> failwith "broken type equality" + else failwith "types are not equal" + end + type 'a key = 'a Key.t + type record = + | T0 + | T1 : 'a key * 'a -> record + | T2 : 'a key * 
'a * + 'b key * 'b -> record + | T3 : 'a key * 'a * + 'b key * 'b * + 'c key * 'c -> record + | T4 : 'a key * 'a * + 'b key * 'b * + 'c key * 'c * + 'd key * 'd -> record + | LL : record * 'a key * 'a * record -> record (* h(x) = h(y) - 1 *) + | EQ : record * 'a key * 'a * record -> record (* h(x) = h(y) *) + | LR : record * 'a key * 'a * record -> record (* h(x) = h(y) + 1 *) + + type t = record + + let empty = T0 + let is_empty = function + | T0 -> true + | _ -> false + + (* + - LL (x,y) : h(x) = h(y) - 1 + - EQ (x,y) : h(x) = h(y) + - LR (x,y) : h(x) = h(y) + 1 + *) + + let (<$) k1 k2 = + let k1 = Key.uid k1 and k2 = Key.uid k2 in + (Key.Uid.(<)[@inlined]) k1 k2 + [@@inline] + + let make0 = T0 [@@inlined] + let make1 k a = T1 (k,a) [@@inline] + let make2 ka a kb b = T2 (ka,a,kb,b) [@@inline] + let make3 ka a kb b kc c = T3 (ka,a,kb,b,kc,c) [@@inline] + let make4 ka a kb b kc c kd d = T4 (ka,a, kb,b, kc,c, kd,d) [@@inline] + let make5 ka a kb b kc c kd d ke e = + EQ (make2 ka a kb b,kc,c,make2 kd d ke e) + [@@inline] + let make6 ka a kb b kc c kd d ke e kf f = + EQ (T1 (ka,a),kb,b,T4(kc,c,kd,d,ke,e,kf,f)) + [@@inline] + let make7 ka a kb b kc c kd d ke e kf f kg g = + EQ (T2 (ka,a,kb,b),kc,c, T4 (kd,d,ke,e,kf,f,kg,g)) + [@@inline] + let make8 ka a kb b kc c kd d ke e kf f kg g kh h = + EQ (T3 (ka,a,kb,b,kc,c),kd,d, T4(ke,e,kf,f,kg,g,kh,h)) + [@@inline] + let make9 ka a kb b kc c kd d ke e kf f kg g kh h ki i = + EQ (T4 (ka,a,kb,b,kc,c,kd,d),ke,e,T4(kf,f,kg,g,kh,h,ki,i)) + [@@inline] + let make10 ka a kb b kc c kd d ke e kf f kg g kh h ki i kj j = + LL (make4 ka a kb b kc c kd d, ke, e, make5 kf f kg g kh h ki i kj j) + [@@inline] + + type 'r visitor = { + visit : 'a. 
'a key -> 'a -> 'r -> 'r; + } [@@unboxed] + + let rec foreach x ~init f = match x with + | T0 -> init + | T1 (ka,a) -> f.visit ka a init + | T2 (ka,a,kb,b) -> + f.visit ka a init |> + f.visit kb b + | T3 (ka,a,kb,b,kc,c) -> + f.visit ka a init |> + f.visit kb b |> + f.visit kc c + | T4 (ka,a,kb,b,kc,c,kd,d) -> + f.visit ka a init |> + f.visit kb b |> + f.visit kc c |> + f.visit kd d + | LL (x,k,a,y) -> + let init = f.visit k a init in + foreach y ~init:(foreach x ~init f) f + | EQ (x,k,a,y) -> + let init = f.visit k a init in + foreach y ~init:(foreach x ~init f) f + | LR (x,k,a,y) -> + let init = f.visit k a init in + foreach y ~init:(foreach x ~init f) f + + type ('b,'r) app = { + app : 'a. 'a key -> 'a -> 'b -> 'r + } [@@unboxed] + + let cmp x y = Key.compare x y [@@inline] + let eq x y = Key.compare x y = 0 [@@inline] + + let rec pop_min t {app} = match t with + | T0 -> failwith "pop_min: empty" + | T1 (ka,a) -> app ka a T0 + | T2 (ka,a,kb,b) -> app ka a (T1 (kb,b)) + | T3 (ka,a,kb,b,kc,c) -> app ka a (T2 (kb,b,kc,c)) + | T4 (ka,a,kb,b,kc,c,kd,d) -> app ka a (T3 (kb,b,kc,c,kd,d)) + | LL (x,ka,a,y) -> pop_min x {app = fun kb b x -> app kb b (LL (x,ka,a,y))} + | EQ (x,ka,a,y) -> pop_min x {app = fun kb b x -> app kb b (EQ (x,ka,a,y))} + | LR (x,ka,a,y) -> pop_min x {app = fun kb b x -> app kb b (LR (x,ka,a,y))} + [@@inline] + + let rec pop_max t {app} = match t with + | T0 -> failwith "pop_max: empty" + | T1 (ka,a) -> app ka a T0 + | T2 (ka,a,kb,b) -> app kb b (T1 (ka,a)) + | T3 (ka,a,kb,b,kc,c) -> app kc c (T2 (ka,a,kb,b)) + | T4 (ka,a,kb,b,kc,c,kd,d) -> app kd d (T3 (ka,a,kb,b,kc,c)) + | LL (x,ka,a,y) -> pop_max y {app = fun kb b y -> app kb b (LL (x,ka,a,y))} + | EQ (x,ka,a,y) -> pop_max y {app = fun kb b y -> app kb b (EQ (x,ka,a,y))} + | LR (x,ka,a,y) -> pop_max y {app = fun kb b y -> app kb b (LR (x,ka,a,y))} + [@@inline] + + let shake_left = function + | LL (T0,ka,a,EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) -> + LL (T1 (ka,a),kb,b,EQ (T3 
(kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) + | LL (T1(kx,x),ka,a,EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) -> + LL (T2 (kx,x,ka,a),kb,b,EQ (T3 (kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) + | LL (T2(kx,x,ky,y),ka,a,EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) -> + LL (T3 (kx,x,ky,y,ka,a),kb,b,EQ (T3 (kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) + | LL (T3(kx,x,ky,y,kz,z),ka,a,EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) -> + LL (T4 (kx,x,ky,y,kz,z,ka,a),kb,b, EQ (T3 (kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j))) + | _ -> assert false + [@@inline] + + exception Rol_wrong_rank of record + exception Ror_wrong_rank of record + + let rol = function + | LL (x,ka,a,LL (y,kb,b,z)) -> + (* + * h(x) = m-2 + * h(LL(y,b,z)=m + * h(y)=m-2 + * h(z)=m-1 + * ---------------- + * h(EQ(x,a,y)) = m-1 + * h(EQ(EQ(x,ka,a,y),b,z)) = m + *) + EQ (EQ(x,ka,a,y),kb,b,z) + | LL (x,ka,a,EQ (y,kb,b,z)) -> + (* + * h(x) = m-2 + * h(EQ(y,b,z))=m + * h(y)=m-1 + * h(z)=m-1 + * ---------------- + * h(LL(x,a,y)) = m + * h(LR(LL(x,a,y),b,z)) = m+1 + *) + LR (LL(x,ka,a,y),kb,b,z) + | LL (w,ka,a,LR (LL(x,kb,b,y),kc,c,z)) -> + (* + * h(w) = m-2 + * h(LR(LL(x,b,y),c,z))=m + * h(z)=m-2 + * h(LL(x,b,y))=m-1 + * h(y)=m-2 + * h(x)=m-3 + * ---------------- + * h(LR(w,a,x))=m-1, h(x) < h(w) + * h(EQ(y,kc,c,z))=m-1, h(y) = h(z) + * h(EQ (LR(w,ka,a,x),kb,b,EQ(y,kc,c,z))) = m + *) + EQ (LR(w,ka,a,x),kb,b,EQ(y,kc,c,z)) + | LL (w,ka,a,LR (EQ(x,kb,b,y),kc,c,z)) -> + (* + * h(w) = m-2 + * h(LR(EQ(x,b,y),c,z))=m + * h(z)=m-2 + * h(EQ(x,b,y))=m-1 + * h(y)=m-2 + * h(x)=m-2 + * ---------------- + * h(EQ(w,a,x))=m-1, h(x) = h(w) + * h(EQ(y,kc,c,z))=m-1, h(y) = h(z) + * h(EQ (EQ(w,ka,a,x),kb,b,EQ(y,kc,c,z))) = m + *) + EQ (EQ(w,ka,a,x),kb,b,EQ(y,kc,c,z)) + | LL (w,ka,a,LR (LR(x,kb,b,y),kc,c,z)) -> + (* + * h(w) = m-2 + * h(LR(LR(x,b,y),c,z))=m + * h(z)=m-2 + * h(LR(x,b,y))=m-1 + * h(y)=m-3 + * h(x)=m-2 + * ---------------- + * h(EQ(w,a,x))=m-1, h(x) = h(w) + * 
h(LL(y,kc,c,z))=m-1, h(y) < h(z) + * h(EQ (EQ(w,ka,a,x),kb,b,LL(y,kc,c,z))) = m + *) + EQ (EQ(w,ka,a,x),kb,b,LL(y,kc,c,z)) + | r -> raise (Rol_wrong_rank r) + [@@inline] + + + let ror = function + | LR (LR(x,ka,a,y),kb,b,z) -> + (* + * h(z) = m-2 + * h(LR(x,a,y))=m + * h(y)=m-2 + * h(x)=m-1 + * ------------------ + * h(EQ(y,b,z))=m-1, h(y) = h(z) + * h(EQ (x,a,EQ(y,kb,b,z))) = m + *) + EQ (x,ka,a,EQ(y,kb,b,z)) + | LR (EQ(x,ka,a,y),kb,b,z) -> + (* + * h(z) = m-2 + * h(EQ(x,a,y))=m + * h(y)=m-1 + * h(x)=m-1 + * ------------------ + * h(LR(y,b,z))=m, h(y) > h(z) + * h(LL (x,a,LR(y,b,z))) = m+1, h(x) < m + *) + LL (x,ka,a,LR(y,kb,b,z)) + | LR (LL (w,ka,a,LR(x,kb,b,y)),kc,c,z) -> + (* + * h(z) = m-2 + * h(LL (w,a,LR(x,b,y)))=m + * h(LR(x,b,y))=m-1 + * h(w)=m-2 + * h(x)=m-2 + * h(y)=m-3 + * ------------------------- + * h(EQ(w,a,x)) = m-1, h(x) = h(w) + * h(LL(y,c,z)) = m-1, h(y) < h(z) + *) + EQ (EQ(w,ka,a,x), kb,b, LL(y,kc,c,z)) + | LR (LL (w,ka,a,EQ(x,kb,b,y)),kc,c,z) -> + (* + * h(z) = m-2 + * h(LL (w,a,EQ(x,b,y)))=m + * h(EQ(x,b,y))=m-1 + * h(w)=m-2 + * h(x)=m-2 + * h(y)=m-2 + * ------------------------- + * h(EQ(w,a,x)) = m-1, h(x) = h(w) + * h(EQ(y,c,z)) = m-1, h(y) = h(z) + *) + EQ (EQ(w,ka,a,x), kb,b, EQ(y,kc,c,z)) + | LR (LL (w,ka,a,LL(x,kb,b,y)),kc,c,z) -> + (* + * h(z) = m-2 + * h(LL (w,a,LL(x,b,y)))=m + * h(LL(x,b,y))=m-1 + * h(w)=m-2 + * h(x)=m-3 + * h(y)=m-2 + * ------------------------- + * h(LR(w,a,x)) = m-1, h(x) < h(w) + * h(EQ(y,c,z)) = m-1, h(y) = h(z) + *) + EQ (LR(w,ka,a,x), kb,b, EQ(y,kc,c,z)) + | r -> raise (Ror_wrong_rank r) + [@@inline] + + let rank_increases was now = match was,now with + | (T0 | T1 _ | T2 _ | T3 _ | T4 _), LR _ + | (T0 | T1 _ | T2 _ | T3 _ | T4 _), EQ _ + | (T0 | T1 _ | T2 _ | T3 _ | T4 _), LL _ -> true + | EQ _, LL _ + | EQ _, LR _ -> true + | LR _, LL _ + | LL _, LR _ -> false + | _ -> false + [@@inline] + + (* [p += c] updates the right subtree of [p] with [c]. 
pre: rank p > 1 /\ rank c > 1 *) + let (+=) p c' = match p with + | LL (b,k,x,c) -> + (* the h(...) notes in [rol] put the right subtree of an LL node above the left one, so a further growth on the right is repaired with [rol] *) + if rank_increases c c' + then rol (LL (b,k,x,c')) + else LL (b,k,x,c') + | LR (b,k,x,c) -> + if rank_increases c c' + then EQ (b,k,x,c') + else LR (b,k,x,c') + | EQ (b,k,x,c) -> + if rank_increases c c' + then LL (b,k,x,c') + else EQ (b,k,x,c') + | _ -> failwith "+=: rank < 2" + [@@inline] + + (* [b =+ p] updates the left subtree of [p] with [b]. + pre: rank p > 1 /\ rank b > 1 *) + let (=+) b' p = match p with + | LL (b,k,x,c) -> + if rank_increases b b' + then EQ (b',k,x,c) + else LL (b',k,x,c) + | LR (b,k,x,c) -> + (* symmetric to the LL case of [+=]: growth on the left of an LR node is repaired with [ror] *) + if rank_increases b b' + then ror (LR (b',k,x,c)) + else LR (b',k,x,c) + | EQ (b,k,x,c) -> + if rank_increases b b' + then LR (b',k,x,c) + else EQ (b',k,x,c) + | _ -> failwith "=+: rank < 2" + [@@inline] + + (* [insert ka a t] returns [t] extended with the binding of [a] to the key [ka]. + + pre: + - a is not in t; + - for all functions except [bal] t is balanced; + - for [bal] the input t is disbalanced. + + post: + - a is in t', and len t' = len t + 1 + - h(t') >= h(t) + - t' is balanced + *) + let rec insert + : type a.
a key -> a -> record -> record = fun ka a -> function + (* T0..T4 hold up to four bindings inline; the comparisons find the slot of [ka] and the node is rebuilt one binding larger with [make1]..[make5] *) + | T0 -> make1 ka a + | T1 (kb,b) -> if ka <$ kb + then make2 ka a kb b + else make2 kb b ka a + | T2 (kb,b,kc,c) -> if ka <$ kb + then make3 ka a kb b kc c else if ka <$ kc + then make3 kb b ka a kc c + else make3 kb b kc c ka a + | T3 (kb,b,kc,c,kd,d) -> + if ka <$ kc + then if ka <$ kb + then make4 ka a kb b kc c kd d + else make4 kb b ka a kc c kd d + else if ka <$ kd + then make4 kb b kc c ka a kd d + else make4 kb b kc c kd d ka a + | T4 (kb,b,kc,c,kd,d,ke,e) -> + if ka <$ kd then + if ka <$ kc then + if ka <$ kb + then make5 ka a kb b kc c kd d ke e + else make5 kb b ka a kc c kd d ke e + else make5 kb b kc c ka a kd d ke e + else if ka <$ ke + then make5 kb b kc c kd d ka a ke e + else make5 kb b kc c kd d ke e ka a + (* an EQ node whose children are a T4 and another leaf is flattened together with the new binding and rebuilt with [make6]..[make10]; the bindings are named kb, kc, ... from left to right *) + | EQ (T0,kb,b,T4(kc,c,kd,d,ke,e,kf,f)) -> + if ka <$ kd then + if ka <$ kc then + if ka <$ kb + then make6 ka a kb b kc c kd d ke e kf f + else make6 kb b ka a kc c kd d ke e kf f + else make6 kb b kc c ka a kd d ke e kf f + else + if ka <$ ke then + make6 kb b kc c kd d ka a ke e kf f + else if ka <$ kf + then make6 kb b kc c kd d ke e ka a kf f + else make6 kb b kc c kd d ke e kf f ka a + | EQ (T4(kb,b,kc,c,kd,d,ke,e),kf,f,T0) -> + if ka <$ kd then + if ka <$ kc then + if ka <$ kb + then make6 ka a kb b kc c kd d ke e kf f + else make6 kb b ka a kc c kd d ke e kf f + else make6 kb b kc c ka a kd d ke e kf f + else + if ka <$ ke then + make6 kb b kc c kd d ka a ke e kf f + else if ka <$ kf + then make6 kb b kc c kd d ke e ka a kf f + else make6 kb b kc c kd d ke e kf f ka a + | EQ (T1 (kb,b),kc,c,T4(kd,d,ke,e,kf,f,kg,g)) -> + if ka <$ kd then + if ka <$ kc then + if ka <$ kb + then make7 ka a kb b kc c kd d ke e kf f kg g + else make7 kb b ka a kc c kd d ke e kf f kg g + else make7 kb b kc c ka a kd d ke e kf f kg g + else + if ka <$ kf then + if ka <$ ke + then make7 kb b kc c kd d ka a ke e kf f kg g + else make7 kb b kc c kd d ke e ka a kf f kg g + else if ka <$ kg + then make7 kb b kc
c kd d ke e kf f ka a kg g + else make7 kb b kc c kd d ke e kf f kg g ka a + | EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T1(kg,g)) -> + if ka <$ kd then + if ka <$ kc then + if ka <$ kb + then make7 ka a kb b kc c kd d ke e kf f kg g + else make7 kb b ka a kc c kd d ke e kf f kg g + else make7 kb b kc c ka a kd d ke e kf f kg g + else + if ka <$ kf then + if ka <$ ke + then make7 kb b kc c kd d ka a ke e kf f kg g + else make7 kb b kc c kd d ke e ka a kf f kg g + else if ka <$ kg + then make7 kb b kc c kd d ke e kf f ka a kg g + else make7 kb b kc c kd d ke e kf f kg g ka a + (* seven existing bindings plus the new one: rebuilt with [make8] *) + | EQ (T2 (kb,b,kc,c),kd,d,T4(ke,e,kf,f,kg,g,kh,h)) -> + if ka <$ ke then + if ka <$ kc then + if ka <$ kb + then make8 ka a kb b kc c kd d ke e kf f kg g kh h + else make8 kb b ka a kc c kd d ke e kf f kg g kh h + else + if ka <$ kd + then make8 kb b kc c ka a kd d ke e kf f kg g kh h + else make8 kb b kc c kd d ka a ke e kf f kg g kh h + else + if ka <$ kg then + if ka <$ kf + then make8 kb b kc c kd d ke e ka a kf f kg g kh h + else make8 kb b kc c kd d ke e kf f ka a kg g kh h + else if ka <$ kh + then make8 kb b kc c kd d ke e kf f kg g ka a kh h + else make8 kb b kc c kd d ke e kf f kg g kh h ka a + | EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T2(kg,g,kh,h)) -> + if ka <$ ke then + if ka <$ kc then + if ka <$ kb + then make8 ka a kb b kc c kd d ke e kf f kg g kh h + else make8 kb b ka a kc c kd d ke e kf f kg g kh h + else + if ka <$ kd + then make8 kb b kc c ka a kd d ke e kf f kg g kh h + else make8 kb b kc c kd d ka a ke e kf f kg g kh h + else + if ka <$ kg then + if ka <$ kf + then make8 kb b kc c kd d ke e ka a kf f kg g kh h + else make8 kb b kc c kd d ke e kf f ka a kg g kh h + else if ka <$ kh + then make8 kb b kc c kd d ke e kf f kg g ka a kh h + else make8 kb b kc c kd d ke e kf f kg g kh h ka a + (* eight existing bindings plus the new one: rebuilt with [make9] *) + | EQ (T3 (kb,b,kc,c,kd,d),ke,e,T4(kf,f,kg,g,kh,h,ki,i)) -> + if ka <$ ke then + if ka <$ kc then + if ka <$ kb + then make9 ka a kb b kc c kd d ke e kf f kg g kh h ki i + else make9 kb b ka a kc c kd
d ke e kf f kg g kh h ki i + else + if ka <$ kd + then make9 kb b kc c ka a kd d ke e kf f kg g kh h ki i + else make9 kb b kc c kd d ka a ke e kf f kg g kh h ki i + else + if ka <$ kg then + if ka <$ kf + then make9 kb b kc c kd d ke e ka a kf f kg g kh h ki i + else make9 kb b kc c kd d ke e kf f ka a kg g kh h ki i + else if ka <$ kh + then make9 kb b kc c kd d ke e kf f kg g ka a kh h ki i + else if ka <$ ki then + make9 kb b kc c kd d ke e kf f kg g kh h ka a ki i + else + make9 kb b kc c kd d ke e kf f kg g kh h ki i ka a + | EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T3(kg,g,kh,h,ki,i)) -> + if ka <$ ke then + if ka <$ kc then + if ka <$ kb + then make9 ka a kb b kc c kd d ke e kf f kg g kh h ki i + else make9 kb b ka a kc c kd d ke e kf f kg g kh h ki i + else + if ka <$ kd + then make9 kb b kc c ka a kd d ke e kf f kg g kh h ki i + else make9 kb b kc c kd d ka a ke e kf f kg g kh h ki i + else + if ka <$ kg then + if ka <$ kf + then make9 kb b kc c kd d ke e ka a kf f kg g kh h ki i + else make9 kb b kc c kd d ke e kf f ka a kg g kh h ki i + else if ka <$ kh + then make9 kb b kc c kd d ke e kf f kg g ka a kh h ki i + else if ka <$ ki then + make9 kb b kc c kd d ke e kf f kg g kh h ka a ki i + else + make9 kb b kc c kd d ke e kf f kg g kh h ki i ka a + (* both children are T4 leaves: the nine existing bindings plus the new one are rebuilt with [make10] *) + | EQ (T4 (kb,b,kc,c,kd,d,ke,e),kf,f,T4(kg,g,kh,h,ki,i,kj,j)) -> + if ka <$ kf then + if ka <$ kc then + if ka <$ kb + then make10 ka a kb b kc c kd d ke e kf f kg g kh h ki i kj j + else make10 kb b ka a kc c kd d ke e kf f kg g kh h ki i kj j + else + if ka <$ ke + then if ka <$ kd + then make10 kb b kc c ka a kd d ke e kf f kg g kh h ki i kj j + else make10 kb b kc c kd d ka a ke e kf f kg g kh h ki i kj j + else make10 kb b kc c kd d ke e ka a kf f kg g kh h ki i kj j + else + if ka <$ ki then + if ka <$ kh + then + if ka <$ kg + then make10 kb b kc c kd d ke e kf f ka a kg g kh h ki i kj j + else make10 kb b kc c kd d ke e kf f kg g ka a kh h ki i kj j + else make10 kb b kc c kd d ke e kf f kg g kh h ka a ki i kj
j + else if ka <$ kj + then make10 kb b kc c kd d ke e kf f kg g kh h ki i ka a kj j + else make10 kb b kc c kd d ke e kf f kg g kh h ki i kj j ka a + + (* the right child is a full EQ (T4,_,T4) node and the left child is a leaf with room: a smaller key goes straight into the left leaf, otherwise the insertion is retried on [shake_left t] *) + | LL ((T0| T1 _ | T2 _ | T3 _ as b),k,x, + (EQ (T4 _,_,_,T4 _) as c)) as t -> + if ka <$ k + then LL (insert ka a b,k,x,c) + else insert ka a (shake_left t) + + | LL ((T4 _ as b),k,x,c) when ka <$ k -> + EQ (insert ka a b,k,x,c) + | LR (b,k,x,(T4 _ as c)) when k <$ ka -> + EQ (b,k,x,insert ka a c) + + | EQ ((EQ (T4 _,_,_,T4 _) as b),k,_,(EQ (T4 _,_,_,T4 _) as c)) as t -> + if ka <$ k + then insert ka a b =+ t + else t += insert ka a c + (* exactly one child is a full EQ (T4,_,T4) node: when the key falls on the full side, the extreme binding of that side ([pop_min] / [pop_max]) becomes the new root and the old root binding is inserted into the other side *) + | EQ (x,kb,b,(EQ (T4 _,_,_,T4 _) as y)) -> + if ka <$ kb + then EQ (insert ka a x,kb,b,y) + else pop_min y @@ {app = fun kc c y -> + EQ (insert kb b x,kc,c,insert ka a y) + } + | EQ ((EQ (T4 _,_,_,T4 _) as x),kb,b,y) -> + if ka <$ kb + then pop_max x @@ {app = fun kc c x -> + EQ (insert ka a x,kc,c,insert kb b y) + } + else EQ (x,kb,b,insert ka a y) + + (* general case: descend into the side selected by the key and let [=+] / [+=] re-attach the updated subtree *) + | LL (b,k,_,c) as t -> + if ka <$ k + then insert ka a b =+ t + else t += insert ka a c + | LR (b,k,_,c) as t -> + if ka <$ k + then insert ka a b =+ t + else t += insert ka a c + | EQ (b,k,_,c) as t -> + if ka <$ k + then insert ka a b =+ t + else t += insert ka a c + + (* [merge k x y] *) + type merge = { + merge : 'a. 'a key -> 'a -> 'a -> 'a + } [@@unboxed] + (* we could use a GADT, but it couldn't be unboxed, + and since we could create merge functions a lot, + it is better not to allocate them.*) + + let merge + : type a b.
merge -> a key -> b key -> b -> a -> a = + fun {merge} ka kb b a -> + (* matching on [T] makes the types of the two keys equal for the type checker, so [b] and [a] can be handed to the same instance of [merge]; NOTE(review): presumably [Key.same] raises when the keys differ - confirm *) + let T = Key.same ka kb in + merge kb b a + + let app = merge + + (* [upsert ~update ~insert ka a t] searches [t] for [ka]. When the key is bound, [update] is called with a function that takes a [merge] record and rebuilds the tree with the old and new values combined through [app]; when it is not, [insert] is called with the tree extended by the new binding (re-attached with [+=] and [=+] in the inner-node cases). *) + let rec upsert ~update:ret ~insert:add ka a t = match t with + | T0 -> add (make1 ka a) + | T1 (kb,b) -> if eq ka kb + then ret (fun f -> make1 ka (app f ka kb b a)) + else add (insert ka a t) + | T2 (kb,b,kc,c) -> if eq ka kb + then ret (fun f -> make2 ka (app f ka kb b a) kc c) else if eq ka kc + then ret (fun f -> make2 kb b ka (app f ka kc c a)) + else add (insert ka a t) + | T3 (kb,b,kc,c,kd,d) -> begin match cmp ka kc with + | 0 -> ret (fun f -> make3 kb b ka (app f ka kc c a) kd d) + | 1 -> if eq ka kd + then ret (fun f -> make3 kb b kc c ka (app f ka kd d a)) + else add (insert ka a t) + | _ -> if eq ka kb + then ret (fun f -> make3 ka (app f ka kb b a) kc c kd d) + else add@@insert ka a t + end + | T4 (kb,b,kc,c,kd,d,ke,e) -> begin match cmp ka kd with + | 0 -> ret@@fun f -> + make4 kb b kc c ka (app f ka kd d a) ke e + | 1 -> if eq ka ke + then ret@@fun f -> make4 kb b kc c kd d ka (app f ka ke e a) + else add@@insert ka a t + | _ -> match cmp ka kc with + | 0 -> ret@@fun f -> + make4 kb b ka (app f ka kc c a) kd d ke e + | 1 -> add@@insert ka a t + | _ -> if eq ka kb + then ret@@fun f -> + make4 ka (app f ka kb b a) kc c kd d ke e + else add@@insert ka a t + end + | LL (x,kb,b,y) -> begin match cmp ka kb with + | 0 -> ret@@fun f -> LL (x,ka,app f ka kb b a,y) + | 1 -> upsert ka a y + ~update:(fun k -> ret@@fun f -> LL (x,kb,b,k f)) + ~insert:(fun y -> add@@ t += y) + | _ -> + upsert ka a x + ~update:(fun k -> ret@@fun f -> LL (k f,kb,b, y)) + ~insert:(fun x -> add@@ x =+ t) + end + | EQ (x,kb,b,y) -> begin match cmp ka kb with + | 0 -> ret@@fun f -> EQ (x,ka,app f ka kb b a,y) + | 1 -> upsert ka a y + ~update:(fun k -> ret@@fun f -> EQ (x,kb,b,k f)) + ~insert:(fun y -> add@@ t += y) + | _ -> upsert ka a x + ~update:(fun k -> ret@@fun f -> EQ (k f,kb,b,y)) + ~insert:(fun x -> add@@ x =+ t) + end + | LR (x,kb,b,y) -> begin
match cmp ka kb with + | 0 -> ret@@fun f -> LR (x,ka,app f ka kb b a,y) + | 1 -> upsert ka a y + ~update:(fun k -> ret@@fun f -> LR (x,kb,b,k f)) + ~insert:(fun y -> add@@ t += y) + | _ -> upsert ka a x + ~update:(fun k -> ret@@fun f -> LR (k f,kb,b,y)) + ~insert:(fun x -> add@@ x =+ t) + end + + (* [monomorphic_merge k f] wraps [f], which works on the type of [k] only, into a [merge] record; the type equality obtained from [Key.same] lets [f] be applied to the values handed in with the other key *) + let monomorphic_merge + : type t. t key -> (t -> t -> t) -> merge = + fun k f -> { + merge = fun (type a) + (kb : a key) (b : a) (a : a) : a -> + let T = Key.same k kb in + f b a + } + + (* [update f ka a x] binds [ka] to [a] in [x], or to [f old a] when [ka] is already bound to [old] *) + let update f ka a x = + let f = monomorphic_merge ka f in + upsert ka a x + ~update:(fun k -> k f) + ~insert:(fun x -> x) + + (* [set ka a x] binds [ka] to [a] in [x], replacing any previous value *) + let set ka a x = + let f = monomorphic_merge ka (fun _ x -> x) in + upsert ka a x + ~update:(fun k -> k f) + ~insert:(fun x -> x) + + exception Field_not_found + + let return (type a b) (k : a key) (ka : b key) (a : b) : a = + let T = Key.same k ka in + a + [@@inline] + + (* [get k t] is the value bound to [k] in [t]; raises [Field_not_found] when there is none *) + let rec get k = function + | T0 -> raise Field_not_found + | T1 (ka,a) -> if eq k ka then return k ka a + else raise Field_not_found + | T2 (ka,a,kb,b) -> begin match cmp k kb with + | 0 -> return k kb b + | 1 -> raise Field_not_found + | _ -> if eq k ka then return k ka a + else raise Field_not_found + end + | T3 (ka,a,kb,b,kc,c) -> begin match cmp k kb with + | 0 -> return k kb b + | 1 -> if eq k kc then return k kc c + else raise Field_not_found + | _ -> if eq k ka then return k ka a + else raise Field_not_found + end + | T4 (ka,a,kb,b,kc,c,kd,d) -> begin match cmp k kc with + | 0 -> return k kc c + | 1 -> if eq k kd then return k kd d + else raise Field_not_found + | _ -> match cmp k kb with + | 0 -> return k kb b + | 1 -> raise Field_not_found + | _ -> if eq k ka then return k ka a + else raise Field_not_found + end + | LL (x,ka,a,y) -> begin match cmp k ka with + | 0 -> return k ka a + | 1 -> get k y + | _ -> get k x + end + | EQ (x,ka,a,y) -> begin match cmp k ka with + | 0 -> return k ka a + | 1 -> get k y + | _ -> get k x + end + | LR (x,ka,a,y) -> begin match cmp k ka with + | 0 -> return
k ka a + | 1 -> get k y + | _ -> get k x + end + + + (* [find k x] is [Some v] when [k] is bound to [v] in [x] and [None] otherwise *) + let find k x = try Some (get k x) with + | Field_not_found -> None + + (* [merge m x y] adds every binding of [y] to [x]; keys bound in both are resolved with the merge record [m], which receives the value from [x] first and the value from [y] second *) + let merge (type a) m x y = + foreach y ~init:x { + visit = fun (type b c) (ka : b key) (a : b) x -> + upsert ka a x + ~insert:(fun x -> x) + ~update:(fun k -> k m) + } + + + (* [sexp_of_t dict] is a list with one (name value) pair per binding; each pair is consed onto the accumulator in the order [foreach] visits the bindings *) + let sexp_of_t dict = Sexp.List (foreach ~init:[] dict { + visit = fun k x xs -> + Sexp.List [ + Sexp.Atom (Key.name k); + (Key.to_sexp k x) + ] :: xs + }) + + + let pp_field ppf (k,v) = + Format.fprintf ppf "%s : %a" + (Key.name k) + Sexp.pp_hum (Key.to_sexp k v) + + (* NOTE(review): a T0 subtree prints nothing, yet the inner-node cases always emit the separators around it, so a node with a T0 child prints a dangling separator *) + let rec pp_fields ppf = function + | T0 -> () + | T1 (ka,a) -> + Format.fprintf ppf "%a" pp_field (ka,a) + | T2 (ka,a,kb,b) -> + Format.fprintf ppf "%a;@ %a" + pp_field (ka,a) + pp_field (kb,b) + | T3 (ka,a,kb,b,kc,c) -> + Format.fprintf ppf "%a;@ %a;@ %a" + pp_field (ka,a) + pp_field (kb,b) + pp_field (kc,c) + | T4 (ka,a,kb,b,kc,c,kd,d) -> + Format.fprintf ppf "%a;@ %a;@ %a;@ %a" + pp_field (ka,a) + pp_field (kb,b) + pp_field (kc,c) + pp_field (kd,d) + | LR (x,ka,a,y) -> + Format.fprintf ppf "%a;@ %a;@ %a" + pp_fields x pp_field (ka,a) pp_fields y + | LL (x,ka,a,y) -> + Format.fprintf ppf "%a;@ %a;@ %a" + pp_fields x pp_field (ka,a) pp_fields y + | EQ (x,ka,a,y) -> + Format.fprintf ppf "%a;@ %a;@ %a" + pp_fields x pp_field (ka,a) pp_fields y + + let pp ppf t = + Format.fprintf ppf "{@[<2>@,%a@]}" pp_fields t + + let pp_elt ppf (k,v) = + Format.fprintf ppf "%a" Sexp.pp_hum (Key.to_sexp k v) + + + (* [pp_tree] prints the shape of the tree: leaves as tuples of values and inner nodes with their LR / LL / EQ tag *) + let rec pp_tree ppf = function + | T0 -> Format.fprintf ppf "()" + | T1 (ka,a) -> + Format.fprintf ppf "(%a)" pp_elt (ka,a) + | T2 (ka,a,kb,b) -> + Format.fprintf ppf "(%a,%a)" + pp_elt (ka,a) + pp_elt (kb,b) + | T3 (ka,a,kb,b,kc,c) -> + Format.fprintf ppf "(%a,%a,%a)" + pp_elt (ka,a) + pp_elt (kb,b) + pp_elt (kc,c) + | T4 (ka,a,kb,b,kc,c,kd,d) -> + Format.fprintf ppf "(%a,%a,%a,%a)" + pp_elt (ka,a) + pp_elt (kb,b) + pp_elt (kc,c) + pp_elt (kd,d) + | LR (x,k,a,y) -> + Format.fprintf ppf "LR(%a,%a,%a)" + pp_tree x
pp_elt (k,a) pp_tree y + | LL (x,k,a,y) -> + Format.fprintf ppf "LL(%a,%a,%a)" + pp_tree x pp_elt (k,a) pp_tree y + | EQ (x,k,a,y) -> + Format.fprintf ppf "EQ(%a,%a,%a)" + pp_tree x pp_elt (k,a) pp_tree y + + let pp_key ppf {Key.name} = + Format.fprintf ppf "%s" name +end + +module Record = struct + module Key = Dict.Key + module Uid = Dict.Key.Uid + + type record = Dict.t + type t = record + type 'a key = 'a Dict.key + + module Repr = struct + type entry = { + name : string; + data : string; + } [@@deriving bin_io] + + type t = entry list [@@deriving bin_io] + end + + (* domain operations of a slot, made polymorphic in the key type so that they can be stored in one table; see [register_domain] *) + type vtable = { + order : 'a. 'a key -> 'a -> 'a -> Order.partial; + join : 'a. 'a key -> 'a -> 'a -> ('a,conflict) result; + inspect : 'a. 'a key -> 'a -> Sexp.t; + } + + type slot_io = { + reader : string -> record -> record; + writer : record -> string option; + } + + (* serializers of persistent slots, indexed by slot name *) + let io : slot_io Hashtbl.M(String).t = + Hashtbl.create (module String) + + (* domain operations of every registered slot, indexed by key uid *) + let vtables : vtable Hashtbl.M(Uid).t = + Hashtbl.create (module Uid) + + let empty = Dict.empty + + let uid = Key.uid + let domain k = Hashtbl.find_exn vtables (uid k) + + (* [x <:= y] holds when every field of [x] is also present in [y] and the order of the field domain ranks it LT or EQ against the value in [y] *) + let (<:=) x y = Dict.foreach ~init:true x { + visit = fun k x yes -> + yes && match Dict.find k y with + | None -> false + | Some y -> match (domain k).order k x y with + | LT | EQ -> true + | GT | NC -> false + } + + (* [order x y] combines [<:=] in both directions into a partial order on records *) + let order : t -> t -> Order.partial = fun x y -> + match x <:= y, y <:= x with + | true,false -> LT + | true,true -> EQ + | false,true -> GT + | false,false -> NC + + + (* [commit dom key v x] stores [x] under [key] in [v]; when the key already has a value the two are combined with the join of [dom] and a failed join is returned as [Error] *) + let commit (type p) {Domain.join} (key : p Key.t) v x = + match Dict.find key v with + | None -> Ok (Dict.insert key x v) + | Some y -> match join y x with + | Ok x -> Ok (Dict.set key x v) + | Error err -> Error err + + let put k v x = Dict.set k x v + let get + : type a.
a Key.t -> a Domain.t -> record -> a = + fun k {Domain.empty} data -> + match Dict.find k data with + | None -> empty + | Some x -> x + + exception Merge_conflict of conflict + + (* [merge_or_keep old our] joins every field of [our] into [old]; a field whose join fails keeps the value from [old] *) + let merge_or_keep old our = + Dict.foreach our ~init:old { + visit = fun kb b out -> match Dict.find kb old with + | None -> Dict.insert kb b out + | Some a -> match (domain kb).join kb a b with + | Ok b -> Dict.set kb b out + | Error _ -> out + } + + (* [try_merge ~on_conflict old our] is like [merge_or_keep] but resolves a failed join according to [on_conflict] and stops at the first error. NOTE(review): [`drop_both] reaches [assert false] here although [Value.strategy] lists it as a valid choice - confirm whether it is meant to be supported *) + let try_merge ~on_conflict old our = + Dict.foreach our ~init:(Ok old) { + visit = fun kb b out -> + match out with + | Error _ as err -> err + | Ok out -> match Dict.find kb old with + | None -> Ok (Dict.insert kb b out) + | Some a -> match (domain kb).join kb a b with + | Ok b -> Ok (Dict.set kb b out) + | Error err -> match on_conflict with + | `drop_both -> assert false + | `drop_left -> Ok (Dict.set kb b out) + | `drop_right -> Ok (Dict.set kb a out) + | `fail -> Error err + } + + + let join x y = try_merge ~on_conflict:`fail x y + + let eq = Dict.Key.same + + (* [register_persistent key p] records a reader and a writer for the slot named after [key]; [add_exn] raises when the name is already registered *) + let register_persistent (type p) + (key : p Key.t) + (p : p Persistent.t) = + let slot = Key.name key in + Hashtbl.add_exn io ~key:slot ~data:{ + reader = begin fun x dict -> + let x = Persistent.of_string p x in + Dict.insert key x dict + end; + writer = begin fun dict -> + match Dict.find key dict with + | None -> None + | Some s -> Some (Persistent.to_string p s) + end + } + + (* binary serialization goes through [Repr]: fields without a registered [io] entry are skipped when writing, and entries with an unknown name are skipped when reading *) + include Binable.Of_binable(Repr)(struct + type t = record + let to_binable s = + Dict.foreach s ~init:[] { + visit = fun k _ xs -> + let name = Key.name k in + match Hashtbl.find io name with + | None -> xs + | Some {writer} -> + match writer s with + | None -> xs + | Some data -> Repr.{name; data} :: xs + } + + let of_binable entries = + List.fold entries ~init:empty ~f:(fun s {Repr.name; data} -> + match Hashtbl.find io name with + | None -> s + | Some {reader} -> reader data s) + end) + + (* NOTE(review): repeats the definition of [eq] given above [register_persistent] *) + let eq = Dict.Key.same + + let register_domain + : type p.
p Key.t -> p Domain.t -> unit = + fun key dom -> + (* each vtable entry first obtains, via [eq], the type equality between the incoming key and [key], then forwards to the corresponding operation of [dom] *) + let vtable = { + order = begin fun (type a) (k : a key) (x : a) (y : a) -> + let T = eq k key in + dom.order x y + end; + inspect = begin fun (type a) (k : a key) (x : a) -> + let T = eq k key in + dom.inspect x; + end; + join = begin fun (type a) (k : a key) (x : a) (y : a) : + (a,conflict) result -> + let T = eq k key in + dom.join x y + end; + } in + Hashtbl.add_exn vtables ~key:(uid key) ~data:vtable + + let sexp_of_t x = Dict.sexp_of_t x + let t_of_sexp = opaque_of_sexp + let inspect = sexp_of_t + + let pp ppf x = Sexp.pp_hum ppf (inspect x) + (* [pp_slots slots ppf x] prints only the fields of [x] whose name is in [slots] *) + let pp_slots slots ppf x = + let slots = Set.of_list (module String) slots in + match (inspect x : Sexp.t) with + | Atom _ -> assert false + | List xs -> + List.iter xs ~f:(function + | Sexp.List (Atom slot :: _ ) as data when Set.mem slots slot -> + Sexp.pp_hum ppf data + | _ -> ()) +end + +module Knowledge = struct + + (* a value pairs its class with the record of slot data and a [time] stamp taken from [Value.next_second] *) + type +'a value = { + cls : 'a; + data : Record.t; + time : Int63.t; + } + type (+'a,+'s) cls = ('a,'s) Class.t + type 'a obj = Oid.t + type 'p domain = 'p Domain.t + type 'a persistent = 'a Persistent.t + type 'a ord = Oid.comparator_witness + type conflict = Conflict.t = ..
+ type pid = Pid.t + type oid = Oid.t [@@deriving bin_io, compare, sexp] + + type cell = { + car : oid; + cdr : oid; + } [@@deriving bin_io, compare, sexp] + + module Cell = struct + type t = cell + include Comparable.Make_binable(struct + type t = cell [@@deriving bin_io,compare, sexp] + end) + end + + + module Env = struct + type workers = { + waiting : Pid.Set.t; + current : Pid.Set.t; + } [@@deriving bin_io] + + type work = Done | Work of workers [@@deriving bin_io] + + (* the per-class state: [vals] maps an object to its record of slot values, [comp] tracks the computation status of each slot of each object, [syms] maps an interned object to its full name, [objs] maps a package and a name to the object, and [pubs] lists the public objects of each package *) + type objects = { + vals : Record.t Oid.Map.t; + comp : work Dict.Key.Uid.Map.t Oid.Map.t; + syms : fullname Oid.Map.t; + heap : cell Oid.Map.t; + data : Oid.t Cell.Map.t; + objs : Oid.t String.Map.t String.Map.t; + pubs : Oid.Set.t String.Map.t; + } [@@deriving bin_io] + + let empty_class = { + vals = Map.empty (module Oid); + comp = Map.empty (module Oid); + objs = Map.empty (module String); + syms = Map.empty (module Oid); + pubs = Map.empty (module String); + heap = Map.empty (module Oid); + data = Map.empty (module Cell); + } + + (* the whole knowledge base: the state of every class plus the name of the current package *) + type t = { + classes : objects Cid.Map.t; + package : string; + } [@@deriving bin_io] + end + + type state = Env.t + + let of_bigstring = + Binable.of_bigstring (module Env) + + let to_bigstring = + Binable.to_bigstring (module Env) + + let empty : Env.t = { + package = user_package; + classes = Map.empty (module Cid); + } + + module State = struct + include Monad.State.T1(Env)(Monad.Ident) + include Monad.State.Make(Env)(Monad.Ident) + end + + (* the knowledge monad: a state monad over [Env.t] whose computations may fail with a conflict *) + module Knowledge = struct + type 'a t = ('a,conflict) Result.t State.t + include Monad.Result.Make(Conflict)(State) + end + + type 'a knowledge = 'a Knowledge.t + + open Knowledge.Syntax + + module Slot = struct + (* a promise is a computation [run], identified by [pid], that is executed for an object when the slot is collected *) + type 'p promise = { + run : Oid.t -> unit Knowledge.t; + pid : pid; + } + + type (+'a,'p) t = { + cls : ('a,unit) cls; + dom : 'p Domain.t; + key : 'p Dict.Key.t; + name : string; + desc : string option; + promises : (pid, 'p promise) Hashtbl.t; + } + + type pack = Pack : ('a,'p) t -> pack + let repository =
Hashtbl.create (module Cid) + + (* adds [slot] to the list of slots recorded for its class in [repository] *) + let register slot = + Hashtbl.update repository slot.cls.id ~f:(function + | None -> [Pack slot] + | Some xs -> Pack slot :: xs) + + (* [enum cls] lists the slots registered for the class *) + let enum {Class.id} = Hashtbl.find_multi repository id + + (* [declare ?desc ?persistent ?package cls name dom] creates a slot: the name is added to the registry, a dictionary key is created for it, the domain (and the persistence, when given) is registered with [Record], and the slot is added to the repository *) + let declare ?desc ?persistent ?package cls name (dom : 'a Domain.t) = + let slot = Registry.add_slot ?desc ?package name in + let name = string_of_fname slot in + let key = Dict.Key.create ~name dom.inspect in + Option.iter persistent (Record.register_persistent key); + Record.register_domain key dom; + let promises = Hashtbl.create (module Pid) in + let cls = Class.refine cls () in + let slot = {cls; dom; key; name; desc; promises} in + register slot; + slot + + let cls x = x.cls + let domain x = x.dom + let name x = x.name + let desc x = match x.desc with + | None -> "no description" + | Some s -> s + end + + type (+'a,'p) slot = ('a,'p) Slot.t + + module Value = struct + type +'a t = 'a value + + (* we could use an extension variant or create a new OCaml object + instead of incrementing a second, but they are less reliable + and heavier *) + let next_second = + let current = ref Int63.zero in + fun () -> Int63.incr current; !current + + let empty cls = + {cls; data=Record.empty; time = next_second ()} + + let order {data=x} {data=y} = Record.order x y + + let refine {data; cls; time} s= + {data; time; cls = Class.refine cls s} + + let cls {cls} = cls + let create cls data = {cls; data; time = next_second ()} + (* [put slot v x] is [v] with the slot set to [x] and a fresh time stamp *) + let put {Slot.key} v x = { + v with data = Record.put key v.data x; + time = next_second () + } + (* [get slot v] is the value of the slot in [v], or the empty element of the slot domain when the slot is unset *) + let get {Slot.key; dom} {data} = Record.get key dom data + let strip + : type a b.
(a value, b value) Type_equal.t -> (a,b) Type_equal.t = + fun T -> T + + type strategy = [`drop_left | `drop_right | `drop_both] + + (* [merge ?on_conflict x y] merges the data of [y] into [x] and stamps the result with a fresh time. [`drop_old] and [`drop_new] are translated to a side using the time stamps: the value with the smaller [time] is the old one. NOTE(review): [`drop_both] is accepted by [strategy] but ends in [assert false] inside [Record.try_merge] *) + let merge ?(on_conflict=`drop_old) x y = + let on_conflict : strategy = match on_conflict with + | `drop_old -> if Int63.(x.time < y.time) + then `drop_left else `drop_right + | `drop_new -> if Int63.(x.time < y.time) + then `drop_right else `drop_left + | #strategy as other -> other in + match Record.try_merge ~on_conflict x.data y.data with + | Ok data -> { + x with time = next_second (); + data; + } + | Error _ -> + (* try_merge fails only if `fail is passed *) + assert false + + (* [join x y] joins the records of [x] and [y]; any conflict is returned as [Error] *) + let join x y = match Record.join x.data y.data with + | Ok data -> Ok {x with data; time = next_second ()} + | Error c -> Error c + + module type S = sig + type t [@@deriving sexp] + val empty : t + val domain : t domain + include Base.Comparable.S with type t := t + include Binable.S with type t := t + end + + module Comparator = Base.Comparator.Make1(struct + type 'a t = 'a value + let sexp_of_t = sexp_of_opaque + (* total order for containers: follows [Record.order] and falls back to the time stamps when the records are incomparable *) + let compare x y = match Record.order x.data y.data with + | LT -> -1 + | EQ -> 0 + | GT -> 1 + | NC -> Int63.compare x.time y.time + end) + + include Comparator + + type 'a ord = comparator_witness + + let derive + : type a b.
(a,b) cls -> + (module S with type t = (a,b) cls t + and type comparator_witness = (a,b) cls ord) = + fun cls -> + (* builds a first-class module with the comparison, sexp and binary interfaces for values of [cls]; deserialized values get a fresh time stamp *) + let module R = struct + type t = (a,b) cls value + let sexp_of_t x = Record.sexp_of_t x.data + let t_of_sexp = opaque_of_sexp + let empty = empty cls + + include Binable.Of_binable(Record)(struct + type t = (a,b) cls value + let to_binable : 'a value -> Record.t = + fun {data} -> data + let of_binable : Record.t -> 'a value = + fun data -> {cls; data; time = next_second ()} + end) + type comparator_witness = Comparator.comparator_witness + include Base.Comparable.Make_using_comparator(struct + type t = (a,b) cls value + let sexp_of_t = sexp_of_t + include Comparator + end) + let domain = Domain.define ~empty ~order ~join + ~inspect:sexp_of_t + (Class.name cls) + end in + (module R) + + let pp ppf x = Record.pp ppf x.data + let pp_slots slots ppf x = Record.pp_slots slots ppf x.data + end + + module Class = struct + include Class + let property = Slot.declare + module Abstract = struct + let property = Slot.declare + end + end + + (* accessors for the state, lifted into the knowledge monad *) + let get () = Knowledge.lift (State.get ()) + let put s = Knowledge.lift (State.put s) + let gets f = Knowledge.lift (State.gets f) + let update f = Knowledge.lift (State.update f) + + (* failure reported by [provide] when the new value cannot be joined with the stored one; carries the slot name and the conflict *) + exception Non_monotonic_update of string * Conflict.t [@@deriving sexp] + + (* [provide slot obj x] stores [x] in the slot of [obj]. An empty [x] is ignored; otherwise [x] is committed to the record of [obj] with [Record.commit], and a failed join fails the computation with [Non_monotonic_update]. *) + let provide : type a p.
(a,p) slot -> a obj -> p -> unit Knowledge.t = + fun slot obj x -> + if Domain.is_empty slot.dom x + then Knowledge.return () + else + get () >>= function {classes} as s -> + let {Env.vals} as objs = + match Map.find classes slot.cls.id with + | None -> Env.empty_class + | Some objs -> objs in + (* the conflict is raised as an OCaml exception from inside the [Map.update] callback and turned into a monadic failure by the handler below *) + try put { + s with classes = Map.set classes ~key:slot.cls.id ~data:{ + objs with vals = Map.update vals obj ~f:(function + | None -> Record.(put slot.key empty x) + | Some v -> match Record.commit slot.dom slot.key v x with + | Ok r -> r + | Error err -> raise (Record.Merge_conflict err))}} + with Record.Merge_conflict err -> + Knowledge.fail (Non_monotonic_update (Slot.name slot, err)) + + (* counter used to hand out promise identifiers *) + let pids = ref Pid.zero + + (* [register_promise s run] stores [run] in the promise table of [s] under a fresh pid *) + let register_promise (s : _ slot) run = + Pid.incr pids; + let pid = !pids in + Hashtbl.add_exn s.promises pid {run; pid} + + (* [promise s get] registers a promise that computes [get obj] and provides the result to [s] unless it is empty *) + let promise s get = + register_promise s @@ fun obj -> + get obj >>= fun x -> + if Domain.is_empty s.dom x + then Knowledge.return () + else provide s obj x + + (* [objects cls] is the state recorded for the class, or [Env.empty_class] when there is none *) + let objects {Class.id} = + get () >>| fun {classes} -> + match Map.find classes id with + | None -> Env.empty_class + | Some objs -> objs + + let uid {Slot.key} = Dict.Key.uid key + + (* [status slot obj] is [Sleep] when nothing is recorded for the slot of [obj], [Awoke] while work is recorded, and [Ready] once it is marked done *) + let status + : ('a,_) slot -> 'a obj -> slot_status knowledge = + fun slot obj -> + objects slot.cls >>| fun {comp} -> + match Map.find comp obj with + | None -> Sleep + | Some slots -> match Map.find slots (uid slot) with + | None -> Sleep + | Some Work _ -> Awoke + | Some Done -> Ready + + (* [update_slot slot obj f] applies [f] to the work entry of the slot of [obj] and writes the result back to the state *) + let update_slot + : ('a,_) slot -> 'a obj -> _ -> unit knowledge = + fun slot obj f -> + objects slot.cls >>= fun ({comp} as objs) -> + let comp = Map.update comp obj ~f:(fun slots -> + let slots = match slots with + | None -> Map.empty (module Dict.Key.Uid) + | Some slots -> slots in + Map.update slots (uid slot) ~f) in + get () >>= fun s -> + let classes = Map.set s.classes slot.cls.id {objs with comp} in + put {s with classes} + + (* [enter_slot s x] records an empty work entry for the slot of [x]; the slot must not have an entry yet *) + let enter_slot : ('a,_) slot -> 'a obj -> unit knowledge = fun s x ->
update_slot s x @@ function + | Some _ -> assert false + | None -> Work { + waiting = Set.empty (module Pid); + current = Set.empty (module Pid) + } + + (* marks the slot of [x] as done; it must currently be in the [Work] state *) + let leave_slot : ('a,'p) slot -> 'a obj -> unit Knowledge.t = fun s x -> + update_slot s x @@ function + | Some (Work _) -> Done + | _ -> assert false + + (* [update_work s x f] replaces the work entry of the slot with [f workers]; the slot must be in the [Work] state *) + let update_work s x f = + update_slot s x @@ function + | Some (Work w) -> f w + | _ -> assert false + + (* adds the promise [p] to the set of currently running promises *) + let enter_promise s x p = + update_work s x @@ fun {waiting; current} -> + Work {waiting; current = Set.add current p} + + (* removes the promise [p] from the set of currently running promises *) + let leave_promise s x p = + update_work s x @@ fun {waiting; current} -> + Work {waiting; current = Set.remove current p} + + (* adds the promises that are currently running to the waiting set *) + let enqueue_promises s x = + update_work s x @@ fun {waiting; current} -> + Work {waiting = Set.union current waiting; current} + + (* looks up the promises recorded as waiting for the slot of [x]; the slot must be in the [Work] state *) + let collect_waiting + : ('a,'p) slot -> 'a obj -> _ Knowledge.t = fun s x -> + objects s.cls >>| fun {comp} -> + Map.find_exn (Map.find_exn comp x) (uid s) |> function + | Env.Done -> assert false + | Env.Work {waiting} -> + Set.fold waiting ~init:[] ~f:(fun ps p -> + Hashtbl.find_exn s.Slot.promises p :: ps) + + (* resets both the waiting and the current set to empty *) + let dequeue_waiting s x = update_work s x @@ fun _ -> + Work { + waiting = Set.empty (module Pid); + current = Set.empty (module Pid) + } + + let initial_promises {Slot.promises} = Hashtbl.data promises + + (* [current slot id] is the value stored for the slot of [id], or the empty element of the slot domain; it does not run any promises *) + let current : type a p.
(a,p) slot -> a obj -> p Knowledge.t = + fun slot id -> + objects slot.cls >>| fun {Env.vals} -> + match Map.find vals id with + | None -> slot.dom.empty + | Some v -> Record.get slot.key slot.dom v + + (* [collect_inner slot obj promises] runs [promises] for [obj]; promises that were enqueued as waiting meanwhile are run again as long as the slot value is not EQ or LT compared to the value before the round *) + let rec collect_inner + : ('a,'p) slot -> 'a obj -> _ -> _ = + fun slot obj promises -> + current slot obj >>= fun was -> + Knowledge.List.iter promises ~f:(fun {Slot.run; pid} -> + enter_promise slot obj pid >>= fun () -> + run obj >>= fun () -> + leave_promise slot obj pid) >>= fun () -> + collect_waiting slot obj >>= fun waiting -> + dequeue_waiting slot obj >>= fun () -> + match waiting with + | [] -> Knowledge.return () + | promises -> + current slot obj >>= fun now -> + match slot.dom.order now was with + | EQ | LT -> Knowledge.return () + | GT | NC -> collect_inner slot obj promises + + + (* [collect slot id] returns the value of the slot, first running the registered promises when the slot is still asleep; when the slot is already being computed the running promises are enqueued for another round and the current value is returned *) + let collect : type a p. (a,p) slot -> a obj -> p Knowledge.t = + fun slot id -> + status slot id >>= function + | Ready -> + current slot id + | Awoke -> + enqueue_promises slot id >>= fun () -> + current slot id + | Sleep -> + enter_slot slot id >>= fun () -> + collect_inner slot id (initial_promises slot) >>= fun () -> + leave_slot slot id >>= fun () -> + current slot id + + (* [resolve slot obj] collects the opinions stored in the slot and reduces them with [Opinions.choice] *) + let resolve slot obj = + collect slot obj >>| Opinions.choice + + (* [suggest agent slot obj x] adds the opinion [x] of [agent] to the opinions currently stored in the slot *) + let suggest agent slot obj x = + current slot obj >>= fun opinions -> + provide slot obj (Opinions.add agent x opinions) + + (* [propose agent s get] registers a promise that suggests the result of [get obj] on behalf of [agent] *) + let propose agent s get = + register_promise s @@ fun obj -> + get obj >>= suggest agent s obj + + module Object = struct + type +'a t = 'a obj + type 'a ord = Oid.comparator_witness + + (* [with_new_object objs f] calls [f] with a fresh object id (the successor of the largest key in [vals], or [Oid.first_atom] when [vals] is empty) and with [objs] extended by an empty record for it *) + let with_new_object objs f = match Map.max_elt objs.Env.vals with + | None -> f Oid.first_atom { + objs + with vals = Map.singleton (module Oid) Oid.first_atom Record.empty + } + | Some (key,_) -> + let key = Oid.next key in + f key { + objs + with vals = Map.add_exn objs.vals ~key ~data:Record.empty + } + + (* [create cls] allocates a fresh object in the class and stores the updated class state *) + let create : ('a,_) cls -> 'a obj Knowledge.t = fun cls -> + objects cls >>= fun objs -> + with_new_object objs @@ fun obj
objs -> + update @@begin function {classes} as s -> { + s with classes = Map.set classes ~key:cls.id ~data:objs + } + end >>| fun () -> + obj + + (* an interesting question, what we shall do if + 1) a symbol is deleted + 2) a data object is deleted? + + So far we ignore both deletes. + *) + let delete {Class.id} obj = + update @@ function {classes} as s -> { + s with + classes = Map.change classes id ~f:(function + | None -> None + | Some objs -> Some { + objs with + vals = Map.remove objs.vals obj; + comp = Map.remove objs.comp obj; + }) + } + + (* [scoped cls scope] creates an object, runs [scope] on it and deletes the object afterwards; NOTE(review): the delete step is sequenced with [>>=], so it looks like it is skipped when [scope] fails - confirm *) + let scoped cls scope = + create cls >>= fun obj -> + scope obj >>= fun r -> + delete cls obj >>| fun () -> + r + + (* [do_intern ?public ?desc ?package name cls] returns the object bound to [name] in [package] (default: the current package), creating it when missing; with [public] set the object is also recorded in [pubs]. The [desc] argument is ignored. *) + let do_intern = + let is_public ~package name {Env.pubs} = + match Map.find pubs package with + | None -> false + | Some pubs -> Set.mem pubs name in + let unchanged id = Knowledge.return id in + let publicize ~package obj: Env.objects -> Env.objects = + fun objects -> { + objects with pubs = Map.update objects.pubs package ~f:(function + | None -> Set.singleton (module Oid) obj + | Some pubs -> Set.add pubs obj) + } in + let createsym ~public ~package name classes clsid objects s = + with_new_object objects @@ fun obj objects -> + let syms = Map.set objects.syms obj {package; name} in + let objs = Map.update objects.objs package ~f:(function + | None -> Map.singleton (module String) name obj + | Some names -> Map.set names name obj) in + let objects = {objects with objs; syms} in + let objects = if public + then publicize ~package obj objects else objects in + put {s with classes = Map.set classes clsid objects} >>| fun () -> + obj in + + fun ?(public=false) ?desc:_ ?package name {Class.id} -> + get () >>= fun ({classes} as s) -> + let package = Option.value package ~default:s.package in + let name = normalize_name ~package name in + let objects = match Map.find classes id with + | None -> Env.empty_class + | Some objs -> objs in + match Map.find objects.objs package with + | None -> createsym ~public
~package name classes id objects s + | Some names -> match Map.find names name with + | None -> createsym ~public ~package name classes id objects s + | Some obj when not public -> unchanged obj + | Some obj -> + if is_public ~package obj objects then unchanged obj + else + let objects = publicize ~package obj objects in + put {s with classes = Map.set classes id objects} >>| fun () -> + obj + + (* any [:] in names here are never treated as separators, + contrary to [read], where they are, and [do_intern] where + a leading [:] in a name will be left for keywords *) + let intern ?public ?desc ?package name cls = + let name = escaped name in + do_intern ?public ?desc ?package name cls + + let uninterned_repr cls obj = + Format.asprintf "#<%s %a>" cls Oid.pp obj + + let to_string + {Class.name=cls as fname; id=cid} {Env.package; classes} obj = + let cls = if package = cls.package then cls.name + else string_of_fname fname in + match Map.find classes cid with + | None -> uninterned_repr cls obj + | Some {Env.syms} -> match Map.find syms obj with + | Some fname -> if fname.package = package + then fname.name + else string_of_fname fname + | None -> uninterned_repr cls obj + + let repr cls obj = + get () >>| fun env -> + to_string cls env obj + + let read cls input = + try + Scanf.sscanf input "#<%s %s@>" @@ fun _ obj -> + Knowledge.return (Oid.atom_of_string obj) + with _ -> + get () >>= fun {Env.package} -> + let {package; name} = split_name package input in + do_intern ~package name cls + + let cast : type a b. (a obj, b obj) Type_equal.t -> a obj -> b obj = + fun Type_equal.T x -> x + + let id x = Oid.untagged x + + module type S = sig + type t [@@deriving sexp] + include Base.Comparable.S with type t := t + include Binable.S with type t := t + end + + let derive : type a. 
(a,_) cls -> + (module S + with type t = a obj + and type comparator_witness = a ord) = fun _ -> + let module Comparator = struct + type t = a obj + let sexp_of_t = Oid.sexp_of_t + let t_of_sexp = Oid.t_of_sexp + type comparator_witness = a ord + let comparator = Oid.comparator + end in + let module R = struct + include Comparator + include Binable.Of_binable(Oid)(struct + type t = a obj + let to_binable = ident + let of_binable = ident + end) + include Base.Comparable.Make_using_comparator(Comparator) + end in + (module R) + end + + module Domain = struct + include Domain + + let inspect_obj name x = + Sexp.Atom (Format.asprintf "#<%s %a>" name Oid.pp x) + + let obj {Class.name} = + let name = string_of_fname name in + total ~inspect:(inspect_obj name) ~empty:Oid.zero + ~order:Oid.compare name + end + module Order = Order + module Persistent = Persistent + + module Symbol = struct + let intern = Object.intern + let keyword = keyword_package + + let in_package package f = + get () >>= function {Env.package=old_package} as s -> + put {s with package} >>= fun () -> + f () >>= fun r -> + update (fun s -> {s with package = old_package}) >>| fun () -> + r + + + exception Import of fullname * fullname [@@deriving sexp_of] + + let intern_symbol ~package name obj cls = + Knowledge.return Env.{ + cls + with objs = Map.update cls.objs package ~f:(function + | None -> Map.singleton (module String) name obj + | Some names -> Map.set names name obj)} + + + + (* imports names inside a class. + + All names that [needs_import] will be imported + into the [package]. If the [package] already had + the same name but with different value, then a + [strict] import will raise an error, otherwise it + will be overwritten with the new value. 
+ *) + let import_class ~strict ~package ~needs_import + : Env.objects -> Env.objects knowledge + = fun cls -> + Map.to_sequence cls.syms |> + Knowledge.Seq.fold ~init:cls ~f:(fun cls (obj,sym) -> + if not (needs_import cls sym obj) + then Knowledge.return cls + else + let obj' = match Map.find cls.objs package with + | None -> Oid.zero + | Some names -> match Map.find names sym.name with + | None -> Oid.zero + | Some obj' -> obj' in + if not strict || Oid.(obj' = zero || obj' = obj) + then intern_symbol ~package sym.name obj cls + else + let sym' = Map.find_exn cls.syms obj' in + Knowledge.fail (Import (sym,sym'))) + + let package_exists package = Map.exists ~f:(fun {Env.objs} -> + Map.mem objs package) + + let name_exists {package; name} = Map.exists ~f:(fun {Env.objs} -> + match Map.find objs package with + | None -> false + | Some names -> Map.mem names name) + + + exception Not_a_package of string [@@deriving sexp_of] + exception Not_a_symbol of fullname [@@deriving sexp_of] + + let check_name classes = function + | `Pkg name -> if package_exists name classes + then Knowledge.return () + else Knowledge.fail (Not_a_package name) + | `Sym sym -> if name_exists sym classes + then Knowledge.return () + else Knowledge.fail (Not_a_symbol sym) + + let current = function + | Some p -> Knowledge.return p + | None -> gets (fun s -> s.package) + + let import ?(strict=false) ?package imports : unit knowledge = + current package >>| escaped >>= fun package -> + get () >>= fun s -> + Knowledge.List.fold ~init:s.classes imports ~f:(fun classes name -> + let name = match find_separator name with + | None -> `Pkg name + | Some _ -> `Sym (split_name package name) in + let needs_import {Env.pubs} sym obj = match name with + | `Sym s -> sym = s + | `Pkg p -> match Map.find pubs p with + | None -> false + | Some pubs -> Set.mem pubs obj in + check_name classes name >>= fun () -> + Map.to_sequence classes |> + Knowledge.Seq.fold ~init:classes + ~f:(fun classes (clsid,objects) -> + 
import_class ~strict ~package ~needs_import objects + >>| fun objects -> + Map.set classes clsid objects)) + >>= fun classes -> put {s with classes} + end + + + module Data = struct + type +'a t = 'a obj + type 'a ord = Oid.comparator_witness + + let atom _ x = Knowledge.return x + + let add_cell {Class.id} objects oid cell = + let {Env.data; heap} = objects in + let data = Map.add_exn data ~key:cell ~data:oid in + let heap = Map.add_exn heap ~key:oid ~data:cell in + update (fun s -> { + s with classes = Map.set s.classes id { + objects with data; heap + }}) >>| fun () -> + oid + + let cons cls car cdr = + let cell = {car; cdr} in + objects cls >>= function {data; heap} as s -> + match Map.find data cell with + | Some id -> Knowledge.return id + | None -> match Map.max_elt heap with + | None -> + add_cell cls s Oid.first_cell cell + | Some (id,_) -> + add_cell cls s (Oid.next id) cell + + let case cls x ~null ~atom ~cons = + if Oid.is_null x then null else + if Oid.is_atom x || Oid.is_number x then atom x + else objects cls >>= fun {Env.heap} -> + let cell = Map.find_exn heap x in + cons cell.car cell.cdr + + let id = Object.id + + module type S = Object.S + let derive = Object.derive + end + + module Syntax = struct + include Knowledge.Syntax + let (-->) x p = collect p x + let (<--) p f = promise p f + let (//) c s = Object.read c s + end + + module type S = sig + include Monad.S with type 'a t = 'a knowledge + and module Syntax := Syntax + include Monad.Fail.S with type 'a t := 'a knowledge + and type 'a error = conflict + end + include (Knowledge : S) + + + let compute_value + : type a p . 
(a,p) cls -> p obj -> unit knowledge + = fun cls obj -> + Slot.enum cls |> List.iter ~f:(fun (Slot.Pack s) -> + collect s obj >>= fun v -> + provide s obj v) + + + let get_value cls obj = + compute_value cls obj >>= fun () -> + objects cls >>| fun {Env.vals} -> + match Map.find vals obj with + | None -> Value.empty cls + | Some x -> Value.create cls x + + let run cls obj s = + match State.run (obj >>= get_value cls) s with + | Ok x,s -> Ok (x,s) + | Error err,_ -> Error err + + + let pp_fullname ~package ppf {package=p; name} = + if package = p + then Format.fprintf ppf "%s" name + else Format.fprintf ppf "%s:%s" p name + + let pp_state ppf {Env.classes; package} = + Format.fprintf ppf "(in-package %s)@\n" package; + Map.iteri classes ~f:(fun ~key:cid ~data:{vals;syms} -> + let name = Hashtbl.find_exn Class.names cid in + Format.fprintf ppf "(in-class %a)@\n" + (pp_fullname ~package) name; + Map.iteri vals ~f:(fun ~key:oid ~data -> + if not (Dict.is_empty data) then + let () = match Map.find syms oid with + | None -> + Format.fprintf ppf "@[<2>(%a@ " Oid.pp oid + | Some name -> + Format.fprintf ppf "@[<2>(%a@ " + (pp_fullname ~package) name in + Format.fprintf ppf "@,%a)@]@\n" + (Sexp.pp_hum_indent 2) (Dict.sexp_of_t data))) + + module Conflict = Conflict + module Agent = Agent + type 'a opinions = 'a Opinions.t + type agent = Agent.t + let sexp_of_conflict = Conflict.sexp_of_t +end + +type 'a knowledge = 'a Knowledge.t diff --git a/lib/knowledge/bap_knowledge.mli b/lib/knowledge/bap_knowledge.mli new file mode 100644 index 000000000..64e79dd99 --- /dev/null +++ b/lib/knowledge/bap_knowledge.mli @@ -0,0 +1,621 @@ +open Core_kernel +open Monads.Std + +type 'a knowledge +module Knowledge : sig + type 'a t = 'a knowledge + type (+'k,+'s) cls + type +'a obj + type +'a value + type (+'a,'p) slot + type 'p domain + type 'a persistent + type state + type conflict = .. 
+ + type agent + type 'a opinions + + (** state with no knowledge *) + val empty : state + + val of_bigstring : Bigstring.t -> state + + val to_bigstring : state -> Bigstring.t + + val pp_state : Format.formatter -> state -> unit + + (** [collect p x] collects the value of the property [p]. + + If the object [x] doesn't have a value for the property [p] and + there are promises registered in the knowledge system, to compute + the property [p] then they will be invoked, otherwise the empty + value of the property domain is returned as the result. *) + val collect : ('a,'p) slot -> 'a obj -> 'p t + + + (** [resolve p x] resolves the multi-opinion property [p]. + + Finds a common resolution for the property [p] using + the current resolution strategy. + + This function is the same as [collect] except it collects + a value from the opinions domain and computes the current + consensus. + *) + val resolve : ('a,'p opinions) slot -> 'a obj -> 'p t + + (** [provide p x v] provides the value [v] for the property [p]. + + If the object [x] already had a value [v'], then the provided + value [v] is joined with it and the resulting value of [p] is [join v v'], provided + such a join exists, where [join] is [Domain.join (Slot.domain p)]. + + If [join v v'] doesn't exist (i.e., it is [Error conflict]) + then [provide p x v] diverges into a conflict. + *) + val provide : ('a,'p) slot -> 'a obj -> 'p -> unit t + + + (** [suggest a p x v] suggests [v] as the value for the property [p]. + + The same as [provide] except the provided value is predicated by + the agent identity. + *) + val suggest : agent -> ('a,'p opinions) slot -> 'a obj -> 'p -> unit t + + (** [promise p f] promises to compute the property [p]. + + If no knowledge exists about the property [p] of + an object [x], then [f x] is invoked to provide an + initial value. + + If there is more than one promise, then they all must + provide a consistent answer. The function [f] may refer + to the property [p] directly or indirectly. 
In that case + the least fixed point solution of all functions [g] involved + in the property computation is computed. + *) + val promise : ('a,'p) slot -> ('a obj -> 'p t) -> unit + + (** [propose p f] proposes the opinion computation. + + The same as [promise] except that it promises a value for + an opinion based property. + *) + val propose : agent -> ('a, 'p opinions) slot -> ('a obj -> 'p t) -> unit + + + + val run : ('k,'s) cls -> 'k obj t -> state -> (('k,'s) cls value * state, conflict) result + + module Syntax : sig + include Monad.Syntax.S with type 'a t := 'a t + + + (** [x-->p] is [collect p x] *) + val (-->) : 'a obj -> ('a,'p) slot -> 'p t + + + (** [p <-- f] is [promise p f] *) + val (<--) : ('a,'p) slot -> ('a obj -> 'p t) -> unit + + + (** [c // s] is [Object.read c s] *) + val (//) : ('a,_) cls -> string -> 'a obj t + end + + + include Monad.S with type 'a t := 'a t + and module Syntax := Syntax + + include Monad.Fail.S with type 'a t := 'a t + and type 'a error = conflict + + (** Orders knowledge by information content. + + The [Order.partial] is a generalization of the total order, + which is used to compare the amount of information in two + specifications of knowledge. + + *) + module Order : sig + + (** partial ordering for two way comparison. + + The semantics of constructors: + - [LT] - strictly less information + - [GT] - strictly more information + - [EQ] - equal informational content + - [NC] - non-comparable entities + *) + type partial = LT | EQ | GT | NC + + + module type S = sig + + (** a partially ordered type *) + type t + + (** defines a partial order relationship between two entities. + + Given a partial ordering relation [<=] + - [order x y = LT iff x <= y && not (y <= x)] + - [order x y = GT iff y <= x && not (x <= y)] + - [order x y = EQ iff x <= y && y <= x] + - [order x y = NC iff not (x <= y) && not (y <= x)] + *) + val order : t -> t -> partial + end + end + + + (** Class is a collection of sorts. 
+ + A class [k] is denoted by an indexed type [(k,s) cls], where + [s] is a sort. + *) + module Class : sig + type (+'k,'s) t = ('k,'s) cls + + + (** [declare ?desc ?package name sort] declares a new + class with the given [name] and [sort] index. *) + val declare : ?desc:string -> ?package:string -> string -> + 's -> ('k,'s) cls + + + (** [refine cls s] refines the [sort] of class ['k] to [s]. *) + val refine : ('k,_) cls -> 's -> ('k,'s) cls + + (** [same x y] is true if [x] and [y] denote the same class [k] *) + val same : ('a,_) cls -> ('b,_) cls -> bool + + (** [equal x y] constructs a type witness of classes equality. + + The witness could be used to cast objects of the same class, + e.g., + + {[ + match equal bitv abs with + | Some t -> Object.cast t x + | _ -> ... + ]} + + Note that the equality is symmetric, so the obtained witness + could be used in both directions, for upcasting and downcasting. + *) + val equal : ('a,_) cls -> ('b,_) cls -> ('a obj, 'b obj) Type_equal.t option + + + (** [assert_equal x y] asserts the equality of two classes. + + Useful in a context where the class is known for sure + (e.g., constrained by the module signature), but has to be + recreated. The [let T = assert_equal x y] expression + establishes a type equality between objects in the typing + context, so there is no need to invoke [Object.cast]. + + {[ + let add : value obj -> value obj -> value obj = fun x y -> + let T = assert_equal bitv value in + x + y (* where (+) has type [bitv obj -> bitv obj -> bitv obj] *) + ]} + *) + val assert_equal : ('a,_) cls -> ('b,_) cls -> ('a obj, 'b obj) Type_equal.t + + + (** [property ?desc ?persistent ?package cls name dom] declares + a new property of all instances of class [k]. + + Returns a slot, that is used to access this property. 
+ *) + val property : + ?desc:string -> + ?persistent:'p persistent -> + ?package:string -> + ('k,_) cls -> string -> 'p domain -> ('k,'p) slot + + val name : ('a,_) cls -> string + val package : ('a,_) cls -> string + val fullname : ('a,_) cls -> string + + + (** [sort cls] returns the sort index of the class [k]. *) + val sort : ('k,'s) cls -> 's + end + + module Object : sig + type 'a t = 'a obj + type 'a ord + + (** [create cls] is a fresh new object with an indefinite extent. *) + val create : ('a,_) cls -> 'a obj knowledge + + (** [scoped cls scope] passes a fresh new object to [scope]. + + The extent of the created object is limited by the extent + of the function [scope].*) + val scoped : ('a,_) cls -> ('a obj -> 'b knowledge) -> 'b knowledge + + (** [repr cls x] returns a textual representation of the object [x] *) + val repr : ('a,_) cls -> 'a t -> string knowledge + + (** [read cls s] returns an object [x] such that [repr cls x = s]. *) + val read : ('a,_) cls -> string -> 'a t knowledge + + + (** [cast class_equality x] changes the type of an object. + + Provided with an equality of two object types, returns + the same object [x] with a new type. + + The type equality of two object types could be obtained + through [Class.equal] or [Class.assert_equal]. Note, this + function doesn't do any magic, this is just the + [Type_equal.conv], lifted into the [Object] module for + convenience. 
+ *) + val cast : ('a obj, 'b obj) Type_equal.t -> 'a obj -> 'b obj + + + val id : 'a obj -> Int63.t + + module type S = sig + type t [@@deriving sexp] + include Base.Comparable.S with type t := t + include Binable.S with type t := t + end + + val derive : ('a,'d) cls -> (module S + with type t = 'a obj + and type comparator_witness = 'a ord) + end + + module Value : sig + type 'a t = 'a value + type 'a ord + include Type_equal.Injective with type 'a t := 'a t + + val empty : ('a,'b) cls -> ('a,'b) cls value + val order : 'a value -> 'a value -> Order.partial + val join : 'a value -> 'a value -> ('a value,conflict) result + val merge : ?on_conflict:[ + | `drop_old + | `drop_new + | `drop_right + | `drop_left + ] -> 'a value -> 'a value -> 'a value + + + (** [cls x] is the class of [x] *) + val cls : ('k,'s) cls value -> ('k,'s) cls + val get : ('k,'p) slot -> ('k,_) cls value -> 'p + val put : ('k,'p) slot -> ('k,'s) cls value -> 'p -> ('k,'s) cls value + + (** [refine v s] refines the sort of [v] to [s]. *) + val refine : ('k,_) cls value -> 's -> ('k,'s) cls value + + module type S = sig + type t [@@deriving sexp] + val empty : t + val domain : t domain + include Base.Comparable.S with type t := t + include Binable.S with type t := t + end + + val derive : ('a,'s) cls -> + (module S + with type t = ('a,'s) cls t + and type comparator_witness = ('a,'s) cls ord) + + val pp : Format.formatter -> 'a value -> unit + + val pp_slots : string list -> Format.formatter -> 'a value -> unit + end + + module Slot : sig + type ('a,'p) t = ('a,'p) slot + + val domain : ('a,'p) slot -> 'p domain + val cls : ('a,_) slot -> ('a, unit) cls + val name : ('a,'p) slot -> string + val desc : ('a,'p) slot -> string + end + + + (** A symbol is an object with unique name. + + Sometimes it is necessary to refer to an object by name, so that + a chosen name will always identify the same object. Finding or + creating an object by name is called "interning" it. 
A symbol + that has a name is called an "interned symbol". However we + stretch the boundaries of the symbol idea, by treating all other + objects as "uninterned symbols". So that any object could be + treated as a symbol. + + To prevent name clashing, that introduces unwanted equalities, + we employ the system of packages, where each symbol belongs + to a package, called its home package. The large system design + is leveraged due to the mechanism of symbol importing, where the + same symbol could be referenced from different packages (see + [import] and [in_package] functions, for more information). + + {3 Symbol syntax} + + The [read] function enables translation of the symbol textual + representation to an object. The symbol syntax is designed to + be versatile so it can allow arbitrary sets of characters, to + enable support for modeling different knowledge domains. Only + two characters have a special meaning for the symbol reader: + the [:] character acts as a separator between the package and + the name constituent, and the [\\] symbol escapes any special + treatment of a symbol that follows it (including the [\\] + itself). When a symbol is read, an absence of the package is + treated the same as if the [package] parameter of the [intern] + function wasn't set, e.g., + [read c "x"] is the same as [intern "x" c], while an empty package + denotes the [keyword] package, e.g., + [read c ":x"] is the same as [intern ~package:keyword "x" c]. + + + {3 Name equality} + + The equality of two names is defined by equality of their + byte representation. Hence, symbols which differ in letter case + will be treated as different, e.g., [Foo <> foo]. + *) + module Symbol : sig + + (** [intern ?public ?desc ?package name cls] interns a symbol in + a package. + + If a symbol with the given name is already interned in a + package, then returns its value, otherwise creates a new + object. 
+ + If a symbol is [public] then it might be advertised and be + accessible during the introspection. It is recommended to + provide a description string if a symbol is public. Note, a + non-public symbol still could be obtained by anyone who knows + the name. + + If the function is called in the scope of one or more + [in_package pkg], then the [package] parameter defaults to + [pkg], otherwise it defaults to ["user"]. See also the + [keyword] package for the special package that holds constants + and keywords. + + The [desc], [package], and [name] parameters could be + arbitrary strings (including empty). Any occurrence of the + package separator symbol ([:]) will be escaped and won't be + treated as a package/name separator. + *) + val intern : ?public:bool -> ?desc:string -> ?package:string -> string -> + ('a,_) cls -> 'a obj knowledge + + (** [keyword = "keyword"] is the special name for the package + that contains keywords. Basically, keywords are special kinds + of symbols whose meaning is defined by their names and nothing + else. *) + val keyword : string + + (** [in_package pkg f] makes [pkg] the default package in [f]. + + Every reference to an unqualified symbol in the scope of the + [f] function will be treated as if it is qualified with the + package [pkg]. This function will affect both the reader and + the pretty printer, thus [in_package "pkg" @@ Obj.repr buf] will + yield something like [#<buf 123>], instead of [#<pkg:buf 123>]. + *) + val in_package : string -> (unit -> 'a knowledge) -> 'a knowledge + + + (** [import ?strict ?package:p names] imports all [names] into [p]. + + The [names] elements could be either package names or + qualified names. If an element is a package, then all public + names from this package are imported into the package [p]. If + an element is a qualified symbol then it is imported into [p], + even if it is not public in the package from which it is being + imported. 
+ + If any of the elements of the [names] list doesn't represent a + known package or known symbol, then a conflict is raised, + either [Not_a_package] or [Not_a_symbol]. + + If [strict] is [true] then no name can change its value during + the import. Otherwise, if the name is already present in + package [p] with a different value, then it will be + overwritten with the new value, i.e., shadowed. + + All names are processed in order, so names imported from + packages that are at the beginning of the list could be + shadowed by the names that are at the end of the list (unless + [strict] is [true], of course). Thus, + {[ + import [x] >>= fun () -> + import [y] + ]} + + is the same as [import [x;y]]. + + Note, all imported names are added as not public. + + If the [package] parameter is not specified, then names are + imported into the current package, as set by the [in_package] + function. + *) + val import : ?strict:bool -> ?package:string -> string list -> unit knowledge + end + + module Agent : sig + type t = agent + type id + type reliability + + val register : + ?desc:string -> + ?package:string -> + ?reliability:reliability -> string -> agent + + val registry : unit -> id list + + val name : id -> string + val desc : id -> string + val reliability : id -> reliability + val set_reliability : id -> reliability -> unit + + val authorative : reliability + val reliable : reliability + val trustworthy : reliability + val doubtful : reliability + val unreliable : reliability + + val pp : Format.formatter -> t -> unit + val pp_id : Format.formatter -> id -> unit + val pp_reliability : Format.formatter -> reliability -> unit + end + + module Domain : sig + type 'a t = 'a domain + + + (** [define ~inspect ~empty ~order name] defines a domain for the + type ['a]. + + The [empty] value denotes the representation of an absence of + information, or an undefined value, or the default value, or + the least possible value in the chain, etc. 
It's only required + that for all possible values [x], [empty <= x], where [<=] + is the partial order defined by the [order] parameter. + + The [order] function defines the partial order for the given + domain, such that + - [order x y = LT] iff [x <= y && not (y <= x)] + - [order x y = EQ] iff [x <= y && y <= x] + - [order x y = GT] iff [not (x <= y) && (y <= x)] + - [order x y = NC] iff [not (x <= y) && not (y <= x)]. + + The optional [inspect] function enables introspection, and may + return any representation of the domain value. + *) + val define : + ?inspect:('a -> Base.Sexp.t) -> + ?join:('a -> 'a -> ('a,conflict) result) -> + empty:'a -> + order:('a -> 'a -> Order.partial) -> string -> 'a domain + + val total : + ?inspect:('a -> Base.Sexp.t) -> + ?join:('a -> 'a -> ('a,conflict) result) -> + empty:'a -> + order:('a -> 'a -> int) -> + string -> 'a domain + + val flat : + ?inspect:('a -> Base.Sexp.t) -> + ?join:('a -> 'a -> ('a,conflict) result) -> + empty:'a -> + equal:('a -> 'a -> bool) -> + string -> 'a domain + + val optional : + ?inspect:('a -> Base.Sexp.t) -> + ?join:('a -> 'a -> ('a,conflict) result) -> + equal:('a -> 'a -> bool) -> + string -> 'a option domain + + val mapping : + ('a,'e) Map.comparator -> + ?inspect:('d -> Base.Sexp.t) -> + equal:('d -> 'd -> bool) -> + string -> + ('a,'d,'e) Map.t domain + + val powerset : ('a,'e) Set.comparator -> + ?inspect:('a -> Sexp.t) -> + string -> + ('a,'e) Set.t domain + + val opinions : + ?inspect:('a -> Sexp.t) -> + empty:'a -> + equal:('a -> 'a -> bool) -> + string -> + 'a opinions domain + + val string : string domain + val bool : bool option domain + + val obj : ('a,_) cls -> 'a obj domain + + + val empty : 'a t -> 'a + val is_empty : 'a t -> 'a -> bool + val order : 'a t -> 'a -> 'a -> Order.partial + val join : 'a t -> 'a -> 'a -> ('a,conflict) result + val inspect : 'a t -> 'a -> Base.Sexp.t + val name : 'a t -> string + end + + module Persistent : sig + type 'a t = 'a persistent + + val define 
: + to_string:('a -> string) -> + of_string:(string -> 'a) -> + 'a persistent + + val derive : + to_persistent:('a -> 'b) -> + of_persistent:('b -> 'a) -> + 'b persistent -> 'a persistent + + val of_binable : (module Binable.S with type t = 'a) -> 'a persistent + + val string : string persistent + + val list : 'a persistent -> 'a list persistent + val sequence : 'a persistent -> 'a Sequence.t persistent + val array : 'a persistent -> 'a array persistent + + val set : ('a,'c) Set.comparator -> 'a t -> ('a,'c) Set.t persistent + val map : ('k,'c) Map.comparator -> 'k t -> 'd t -> ('k,'d,'c) Map.t persistent + end + + module Data : sig + type +'a t + type 'a ord + + val atom : ('a,_) cls -> 'a obj -> 'a t knowledge + val cons : ('a,_) cls -> 'a t -> 'a t -> 'a t knowledge + + val case : ('a,_) cls -> 'a t -> + null:'r knowledge -> + atom:('a obj -> 'r knowledge) -> + cons:('a t -> 'a t -> 'r knowledge) -> 'r knowledge + + + val id : 'a obj -> Int63.t + + + module type S = sig + type t [@@deriving sexp] + include Base.Comparable.S with type t := t + include Binable.S with type t := t + end + + val derive : ('a,_) cls -> (module S + with type t = 'a t + and type comparator_witness = 'a ord) + end + + module Conflict : sig + type t = conflict = .. 
+ val pp : Format.formatter -> conflict -> unit + val sexp_of_t : t -> Sexp.t + end + + val sexp_of_conflict : conflict -> Sexp.t +end diff --git a/lib/monads/monads_monad.ml b/lib/monads/monads_monad.ml index 7d064e825..e754dc6e9 100644 --- a/lib/monads/monads_monad.ml +++ b/lib/monads/monads_monad.ml @@ -40,47 +40,47 @@ module Monad = struct module Lift = struct - let nullary = return - let unary f a = a >>| f - let binary f a b = a >>= fun a -> b >>| fun b -> f a b - let ternary f a b c = a >>= fun a -> b >>= fun b -> c >>| fun c -> f a b c - let quaternary f a b c d = - a >>= fun a -> b >>= fun b -> c >>= fun c -> d >>| fun d -> - f a b c d - let quinary f a b c d e = - a >>= fun a -> b >>= fun b -> c >>= fun c -> d >>= fun d -> e >>| fun e -> - f a b c d e - - module Syntax = struct - let (!!) = nullary - let (!$) = unary - let (!$$) = binary - let (!$$$) = ternary - let (!$$$$) = quaternary - let (!$$$$$) = quinary - end + let nullary x = return x [@@inline] + let unary f a = a >>| f [@@inline] + let binary f a b = a >>= fun a -> b >>| fun b -> f a b [@@inline] + let ternary f a b c = a >>= fun a -> b >>= fun b -> c >>| fun c -> f a b c + let quaternary f a b c d = + a >>= fun a -> b >>= fun b -> c >>= fun c -> d >>| fun d -> + f a b c d + let quinary f a b c d e = + a >>= fun a -> b >>= fun b -> c >>= fun c -> d >>= fun d -> e >>| fun e -> + f a b c d e + + module Syntax = struct + let (!!) 
x = nullary x [@@inline] + let (!$) = unary + let (!$$) = binary + let (!$$$) = ternary + let (!$$$$) = quaternary + let (!$$$$$) = quinary + end end open Lift.Syntax module Fn = struct - let id = return - let nothing = return - let ignore m = m >>| ignore - let non f x = f x >>| not - let apply_n_times ~n f x = - let rec loop n x = - if n <= 0 then return x - else f x >>= loop (n-1) in - loop n x - - let compose f g x = g x >>= f + let id x = return x [@@inline] + let nothing x = return x [@@inline] + let ignore m = m >>| ignore [@@inline] + let non f x = f x >>| not [@@inline] + let apply_n_times ~n f x = + let rec loop n x = + if n <= 0 then return x + else f x >>= loop (n-1) in + loop n x + + let compose f g x = g x >>= f [@@inline] end module Syntax = struct include Monad_infix include Lift.Syntax - let (>=>) g f = Fn.compose f g + let (>=>) g f = Fn.compose f g [@@inline] end open Syntax @@ -224,10 +224,8 @@ module Monad = struct module Base = Eager_base(B) include Make(Base) end - end - module List = Collection.Delay(struct type 'a t = 'a list let fold xs ~init ~f = @@ -244,9 +242,9 @@ module Monad = struct type 'a t = 'a Sequence.t let fold xs ~init ~f finish = Sequence.delayed_fold xs ~init ~f:(fun a x ~k -> - f a x k) ~finish - let zero () = Sequence.empty - let return = Sequence.return + f a x k) ~finish + let zero () = Sequence.empty [@@inline] + let return x = Sequence.return x [@@inline] let plus = Sequence.append end) @@ -288,7 +286,7 @@ module Monad = struct Since, the code doesn't contain any implementation (just renaming) it can be considered OK. - *) + *) module Core(M : Core) = struct type 'a t = 'a M.t include Make(struct @@ -307,7 +305,7 @@ module Monad = struct monad representation to our maximal. We will not erase types from the resulting structure, as this functor is expected to be used as a type caster, e.g. 
[Monad.State.Make(Monad.Minimal(M)] - *) + *) module Minimal( M : Minimal) = struct type 'a t = 'a M.t include Make(struct @@ -411,7 +409,7 @@ module Ident type 'a t = 'a Sequence.t let fold xs ~init ~f finish = Sequence.delayed_fold xs ~init ~f:(fun a x ~k -> - f a x k) ~finish + f a x k) ~finish let zero () = Sequence.empty let return = Sequence.return let plus = Sequence.append @@ -424,15 +422,15 @@ module Ident end module Syntax = struct - let (>>=) x f = x |> f - let (>>|) x f = x |> f - let (>=>) f g x = g (f x) - let (!!) = ident - let (!$) = ident - let (!$$) = ident - let (!$$$) = ident - let (!$$$$) = ident - let (!$$$$$) = ident + let (>>=) x f = x |> f [@@inline] + let (>>|) x f = x |> f [@@inline] + let (>=>) f g x = g (f x) [@@inline] + let (!!) = ident + let (!$) = ident + let (!$$) = ident + let (!$$$) = ident + let (!$$$$) = ident + let (!$$$$$) = ident end module Let_syntax = struct @@ -488,7 +486,14 @@ module OptionT = struct let bind m f = M.bind m (function | Some r -> f r | None -> M.return None) - let map = `Define_using_bind + [@@inline] + + let map m ~f = M.bind m (function + | Some r -> M.return (Some (f r)) + | None -> M.return None) + [@@inline] + + let map = `Custom map end type 'a error = unit let fail () = M.return None @@ -516,11 +521,11 @@ module OptionT = struct : S with type 'a m := 'a T1(M).m and type 'a t := 'a T1(M).t and type 'a e := 'a T1(M).e - = Make2(struct - type ('a,'e) t = 'a M.t - type 'a error = unit - include (M : Monad.S with type 'a t := 'a M.t) - end) + = Make2(struct + type ('a,'e) t = 'a M.t + type 'a error = unit + include (M : Monad.S with type 'a t := 'a M.t) + end) include T1(Ident) include Make(Ident) @@ -550,13 +555,19 @@ module ResultT = struct and type ('a,'e) e := ('a,'e) Tp(T)(M).e and type 'a error = 'a Tp(T)(M).error = struct - open M.Syntax + + include struct + let (>>=) m f = (M.bind [@inlined]) m f [@@inline] + let (>>|) m f = (M.map [@inlined]) m f [@@inline] + end + module Base = struct include 
Tp(T)(M) let return x = M.return (Ok x) let bind m f : ('a,'e) t = m >>= function | Ok r -> f r | Error err -> M.return (Error err) + [@@inline] let fail err = M.return (Error err) let run = ident @@ -565,8 +576,11 @@ module ResultT = struct | other -> M.return other let lift m = m >>| fun x -> Ok x - let map = `Define_using_bind - + let map' m ~f = m >>= function + | Ok r -> return (f r) + | Error err -> M.return (Error err) + [@@inline] + let map = `Custom map' end include Base include Monad.Make2(Base) @@ -590,10 +604,10 @@ module ResultT = struct and type 'a m := 'a T1(T)(M).m and type 'a e := 'a T1(T)(M).e and type err := T.t - = struct - type err = T.t - include Makep(struct type 'a t = T.t end)(M) - end + = struct + type err = T.t + include Makep(struct type 'a t = T.t end)(M) + end module Make2(M : Monad.S) : S2 with type ('a,'e) t := ('a,'e) T2(M).t @@ -617,20 +631,20 @@ module ResultT = struct and type 'a m := 'a T(M).m and type 'a e := 'a T(M).e and type err := Error.t - = struct - include Make(struct type t = Error.t end)(M) - - let failf fmt = - let open Caml.Format in - let buf = Buffer.create 512 in - let ppf = formatter_of_buffer buf in - let kon ppf () = - pp_print_flush ppf (); - let err = Or_error.error_string (Buffer.contents buf) in - M.return err in - kfprintf kon ppf fmt + = struct + include Make(struct type t = Error.t end)(M) + + let failf fmt = + let open Caml.Format in + let buf = Buffer.create 512 in + let ppf = formatter_of_buffer buf in + let kon ppf () = + pp_print_flush ppf (); + let err = Or_error.error_string (Buffer.contents buf) in + M.return err in + kfprintf kon ppf fmt - end + end type 'a t = 'a Or_error.t type 'a m = 'a type 'a e = 'a Or_error.t @@ -659,11 +673,11 @@ module ResultT = struct with type ('a,'e) t = ('a,'e) result and type 'a m = 'a and type ('a,'e) e = ('a,'e) result - = struct - include T2(Ident) - include Make2(Ident) - end - include Self + = struct + include T2(Ident) + include Make2(Ident) + end + include 
Self end module ListT = struct @@ -758,7 +772,7 @@ module Seq = struct let bind xsm f = xsm >>= fun xs -> Sequence.fold xs ~init:(!!Sequence.empty) ~f:(fun ysm x -> ysm >>= fun ys -> f x >>| fun xs -> - Sequence.append xs ys) + Sequence.append xs ys) let map xsm ~f = xsm >>| Sequence.map ~f let map = `Custom map end @@ -780,9 +794,9 @@ module Seq = struct end module Make(M : Monad.S) - : S with type 'a m := 'a T1(M).m - and type 'a t := 'a T1(M).t - and type 'a e := 'a T1(M).e + : S with type 'a m := 'a T1(M).m + and type 'a t := 'a T1(M).t + and type 'a e := 'a T1(M).e = Make2(struct type ('a,'e) t = 'a M.t include (M : Monad.S with type 'a t := 'a M.t) @@ -840,7 +854,7 @@ module Writer = struct let returnw x = M.return @@ writer x let write x = M.return @@ writer ((), x) - let read m = m >>= fun (Writer (x,e)) -> returnw (e,e) + let read m = m >>= fun (Writer (_,e)) -> returnw (e,e) let listen m = m >>= fun (Writer (x,e)) -> returnw ((x,e),e) let run m = m >>| fun (Writer (x,e)) -> (x,e) let exec m = m >>| fun (Writer ((),e)) -> e @@ -898,7 +912,7 @@ module Reader = struct let bind m f = reader @@ fun s -> m => s >>= fun x -> f x => s let map m ~f = reader @@ fun s -> m => s >>| f let read () = reader @@ fun s -> M.return s - let lift m = reader @@ fun s -> m + let lift m = reader @@ fun _ -> m let run m s = m => s let map = `Custom map end @@ -966,7 +980,7 @@ module State = struct s : 'b; } - type ('a,'e) state = State of ('e -> 'a) + type ('a,'e) state = State of ('e -> 'a) [@@unboxed] module Tp(T : T1)(M : Monad.S) = struct @@ -982,28 +996,31 @@ module State = struct and type ('a,'e) e := ('a,'e) Tp(T)(M).e and type 'a env := 'a Tp(T)(M).env = struct - open M.Monad_infix - let make run = State run - let (=>) (State run) x = run x + include struct + let (>>=) m f = (M.bind [@inlined]) m f [@@inline] + let (>>|) m f = (M.map [@inlined]) m f [@@inline] + end + + let make run = State run [@@inline] + let (=>) (State run) x = run x [@@inline] type 'a result = 'a 
M.t module Basic = struct include Tp(T)(M) - let return x = make @@ fun s -> M.return {x;s} - let bind m f = make @@ fun s -> m=>s >>= fun {x;s} -> f x => s - let map m ~f = make @@ fun s -> m=>s >>| fun {x;s} -> {x=f x;s} + let return x = (make [@inlined]) @@ fun s -> (M.return [@inlined]) {x;s} [@@inline] + let bind m f = make @@ fun s -> m=>s >>= fun {x;s} -> f x => s [@@inline] + let map m ~f = make @@ fun s -> m=>s >>| fun {x;s} -> {x=f x;s} [@@inline] let map = `Custom map end - let put s = make @@ fun _ -> M.return {x=();s} - let get () = make @@ fun s -> M.return {x=s;s} - let gets f = make @@ fun s -> M.return {x=f s;s} - let update f = make @@ fun s -> M.return {x=();s = f s} + let put s = make @@ fun _ -> M.return {x=();s} [@@inline] + let get () = make @@ fun s -> M.return {x=s;s} [@@inline] + let gets f = make @@ fun s -> M.return {x=f s;s} [@@inline] + let update f = make @@ fun s -> M.return {x=();s = f s} [@@inline] let modify m f = - make @@ fun s -> m=>s >>= fun {x;s} -> M.return {x; s = f s} + make @@ fun s -> m=>s >>= fun {x;s} -> M.return {x; s = f s} [@@inline] let run m s = M.(m => s >>| fun {x;s} -> (x,s)) let eval m s = M.(run m s >>| fst) let exec m s = M.(run m s >>| snd) - let lift m = make @@ fun s -> - M.bind m (fun x -> M.return {x;s}) + let lift m = make @@ fun s -> M.bind m (fun x -> M.return {x;s}) [@@inline] include Basic include Monad.Make2(Basic) end @@ -1214,7 +1231,7 @@ module State = struct let run m = fun ctxt -> M.bind (SM.run m (init ctxt)) ~f:(fun (x,cs) -> - M.return (x,cs.init)) + M.return (x,cs.init)) include Monad.Make2(struct type nonrec ('a,'e) t = ('a,'e) t @@ -1357,13 +1374,13 @@ module LazyT = struct with type 'a t = 'a Lazy.t and type 'a m = 'a and type 'a e = 'a - = struct - type 'a t = 'a Lazy.t - type 'a m = 'a - type 'a e = 'a - include Make(Ident) - end - include Self + = struct + type 'a t = 'a Lazy.t + type 'a m = 'a + type 'a e = 'a + include Make(Ident) + end + include Self end module Cont = struct @@ 
-1446,90 +1463,13 @@ module Cont = struct S2 with type ('a,'e) t = ('a,'e) cont and type 'a m = 'a and type ('a,'e) e = (('a -> 'e) -> 'e) - = struct - type ('a,'e) t = ('a,'e) T(Ident).t - type 'a m = 'a - type ('a,'e) e = ('a -> 'e) -> 'e - include Make2(Ident) - end - include Self -end - -module T = struct - module Option = struct - module Make(M : Monad.S) = struct - type 'a t = 'a option M.t - include Monad.Make(struct - type nonrec 'a t = 'a t - let return x : 'a t = Option.return x |> M.return - let bind m f : 'b t = M.bind m (function - | Some r -> f r - | None -> M.return None) - let map = `Define_using_bind - end) - let lift (m : 'a option) : 'a t = M.return m - end - module Make2(M : Monad.S2) = struct - type ('a,'b) t = ('a option,'b) M.t - include Monad.Make2(struct - type nonrec ('a,'b) t = ('a,'b) t - let return x = Option.return x |> M.return - let bind m f = M.bind m (function - | Some r -> f r - | None -> M.return None) - let map = `Define_using_bind - end) - let lift (m : 'a option) : ('a,'b) t = M.return m - end - end - - - module Or_error = struct - module Make(M : Monad.S) = struct - type 'a t = 'a Or_error.t M.t - include Monad.Make(struct - type nonrec 'a t = 'a t - let return x = Or_error.return x |> M.return - let bind m f = M.bind m (function - | Ok r -> f r - | Error err -> M.return (Error err)) - let map = `Define_using_bind - end) - let lift m = M.return m - end - module Make2(M : Monad.S2) = struct - type ('a,'b) t = ('a Or_error.t,'b) M.t - include Monad.Make2(struct - type nonrec ('a,'b) t = ('a,'b) t - let return x = Or_error.return x |> M.return - let bind m f = M.bind m (function - | Ok r -> f r - | Error err -> M.return (Error err)) - let map = `Define_using_bind - end) - let lift m = M.return m - end - end - - - module Result = struct - module Make(M : Monad.S) = struct - type ('a,'e) t = ('a,'e) Result.t M.t - include Monad.Make2(struct - type nonrec ('a,'e) t = ('a,'e) t - let return x = Result.return x |> M.return - let bind 
m f = M.bind m (function - | Ok r -> f r - | Error err -> M.return (Error err)) - let map = `Define_using_bind - end) - let lift m : ('a,'e) t = M.return m - end - end - - module State = struct - module Make = State + = struct + type ('a,'e) t = ('a,'e) T(Ident).t + type 'a m = 'a + type ('a,'e) e = ('a -> 'e) -> 'e + include Make2(Ident) end + include Self end module Lazy = LazyT diff --git a/lib/x86_cpu/x86_cpu.ml b/lib/x86_cpu/x86_cpu.ml index 1008905d5..274e4b34f 100644 --- a/lib/x86_cpu/x86_cpu.ml +++ b/lib/x86_cpu/x86_cpu.ml @@ -11,6 +11,7 @@ module Make_CPU(Env : ModeVars) = struct rax; rcx; rdx; rsi; rdi; rbx; rbp; rsp; ] @ Array.to_list r + @ Array.to_list ymms let flags = Var.Set.of_list [ cf; pf; af; zf; sf; oF; df diff --git a/lib_test/bap_disasm/test_disasm.ml b/lib_test/bap_disasm/test_disasm.ml index 09ecf8c96..ed0b9cd0c 100644 --- a/lib_test/bap_disasm/test_disasm.ml +++ b/lib_test/bap_disasm/test_disasm.ml @@ -60,7 +60,8 @@ let strings_of_insn insn = List.map ~f:(Op.to_string) in (name :: ops) -let insn_of_mem arch data ctxt = +let insn_of_mem arch data _ctxt = + Toplevel.reset (); let mem = memory_of_string data in Dis.with_disasm ~backend:"llvm" arch ~f:(fun dis -> Dis.insn_of_mem dis mem >>= function @@ -93,6 +94,7 @@ let test_insn_of_mem (arch,samples) ctxt = List.iter samples ~f:test let test_run_all (arch,samples) ctxt = + Toplevel.reset (); let mem = samples |> List.map ~f:fst3 |> String.concat |> memory_of_string in Dis.with_disasm ~backend:"llvm" arch ~f:(fun dis -> @@ -107,7 +109,7 @@ let test_run_all (arch,samples) ctxt = ~f:(fun (data,exp,kinds) -> function | (_,None) -> assert_string "bad instruction" | (mem, Some r) -> - assert_strings_equal ctxt exp (strings_of_insn r); + assert_strings_equal ctxt exp (strings_of_insn r); assert_equal ~ctxt ~printer:Int.to_string (String.length data) (Memory.length mem); List.iter kinds ~f:(fun expected -> @@ -224,7 +226,7 @@ let strlen = List.concat [ type dest_kind = [`Jump | `Cond | `Fall ] 
[@@deriving sexp] type graph = (int * int list * (int * dest_kind) list) list - [@@deriving sexp_of] +[@@deriving sexp_of] let graph : graph = [ 1, [], [3, `Jump]; @@ -301,6 +303,7 @@ let structure cfg ctxt = (* test one instruction cfg *) let test_micro_cfg insn ctxt = + Toplevel.reset (); let open Or_error in let mem = Bigstring.of_string insn |> Memory.create LittleEndian (Addr.of_int64 0L) |> @@ -339,7 +342,9 @@ let test_micro_cfg insn ctxt = |6: ret +<-----+ +-------------------+ - With the third ret unreachable. + +-------------------+ + |7: ret + + +-------------------+ *) let has_dest cfg src dst kind = @@ -347,21 +352,29 @@ let has_dest cfg src dst kind = Cfg.Edge.mem e cfg -let call1_3ret ctxt = +let sort_by_addr = + List.sort ~compare:(fun x y -> + Addr.compare (Block.addr x) (Block.addr y)) + + +let call1_3ret _ctxt = + Toplevel.reset (); let mem = String.concat [call1; ret; ret; ret] |> memory_of_string in let dis = Rec.run `x86_64 mem |> Or_error.ok_exn in assert_bool "No errors" (Rec.errors dis = []); - assert_bool "Three block" (Rec.cfg dis |> Cfg.number_of_nodes = 3); + assert_bool "Four block" (Rec.cfg dis |> Cfg.number_of_nodes = 4); let cfg = Rec.cfg dis in - match Cfg.nodes cfg |> Seq.to_list with - | [b1;b2;b3] -> + match Cfg.nodes cfg |> Seq.to_list |> sort_by_addr with + | [b1;b2;b3;b4] -> let call = memory_of_string ~width:64 call1 in let ret1 = memory_of_string ret ~start:5 ~width:64 in let ret2 = memory_of_string ret ~start:6 ~width:64 in + let ret3 = memory_of_string ret ~start:7 ~width:64 in assert_memory call (Block.memory b1); assert_memory ret1 (Block.memory b2); assert_memory ret2 (Block.memory b3); + assert_memory ret3 (Block.memory b4); assert_bool "b1 -> jump b3" @@ has_dest cfg b1 b3 `Jump; assert_bool "b1 -> fall b2" @@ has_dest cfg b1 b2 `Fall; assert_bool "b2 has no succs" @@ @@ -377,7 +390,6 @@ let suite () = "Disasm.Basic" >::: [ "addresses" >:: test_cfg addresses; "structure" >:: test_cfg structure; "ret" >:: 
test_micro_cfg ret; - "sub" >:: test_micro_cfg sub; + "call" >:: test_micro_cfg call; "call1_3ret" >:: call1_3ret; ] - diff --git a/lib_test/bap_project/test_project.ml b/lib_test/bap_project/test_project.ml index 2f3de6026..65a5b1d7c 100644 --- a/lib_test/bap_project/test_project.ml +++ b/lib_test/bap_project/test_project.ml @@ -1,3 +1,5 @@ +open Bap_core_theory + open Core_kernel open Bap_future.Std open OUnit2 @@ -6,6 +8,7 @@ open Word_size open Bap.Std type case = { + name : string; arch : arch; addr : int; code : string; @@ -14,19 +17,21 @@ type case = { } let arm = { + name = "arm-project"; arch = `armv7; addr = 16; - code = "\x01\x20\xA0\xE1"; - bil = "R2 := R1"; - asm = "mov r2, r1"; + code = "\x1e\xff\x2f\xe1"; + bil = "jmp LR"; + asm = "bx lr"; } let x86 = { + name = "x86-project"; arch = `x86; addr = 10; - code = "\x89\x34\x24"; - asm = "movl %esi, (%esp)"; - bil = "mem := mem with [ESP,el]:u32 <- ESI"; + code = "\xeb\xfe"; + asm = "jmp -0x2"; + bil = "jmp 0xA"; } let normalize = String.filter ~f:(function @@ -43,16 +48,18 @@ let tag = Value.Tag.register (module String) let addr_width case = Arch.addr_size case.arch |> Size.in_bits let test_substitute case = + Toplevel.reset (); let sub_name = Format.asprintf "test_%a" Arch.pp case.arch in let addr = sprintf "%#x" in let min_addr = addr case.addr in let max_addr = addr (case.addr + String.length case.code - 1) in let base = Addr.of_int case.addr ~width:(addr_width case) in - let name addr = Option.some_if Addr.(base = addr) sub_name in - let symbolizer = Stream.map Project.Info.arch (fun _ -> - Ok (Symbolizer.create name)) in - let new_rooter _ = Ok ([base] |> Seq.of_list |> Rooter.create) in - let rooter = Stream.map Project.Info.arch ~f:new_rooter in + let symbolizer = Symbolizer.create @@ fun addr -> + Option.some_if Addr.(base = addr) sub_name in + let agent = + let name = sprintf "test-project-symbolizer-for-%s" case.name in + KB.Agent.register name in + Symbolizer.provide agent symbolizer; let 
input = let file = "/dev/null" in let mem = @@ -63,7 +70,7 @@ let test_substitute case = let data = code in Project.Input.create case.arch file ~code ~data in - let p = Project.create ~rooter ~symbolizer input |> ok_exn in + let p = Project.create input |> ok_exn in let mem,_ = Memmap.lookup (Project.memory p) base |> Seq.hd_exn in let test expect s = let p = Project.substitute p mem tag s in diff --git a/lib_test/powerpc/powerpc_tests_helpers.ml b/lib_test/powerpc/powerpc_tests_helpers.ml index ef6057e1f..831f978c8 100644 --- a/lib_test/powerpc/powerpc_tests_helpers.ml +++ b/lib_test/powerpc/powerpc_tests_helpers.ml @@ -203,7 +203,11 @@ let get_insn ?addr arch bytes = | Ok (mem, Some insn, _) -> let insn_name = Insn.(name @@ of_basic insn) in mem, insn, insn_name - | _ -> failwith "disasm failed" + | Ok (mem, None,_) -> + failwithf "Failed to find an instruction in: %a" + Memory.pps mem () + | Error err -> + failwithf "Diassembler failed with: %s" (Error.to_string_hum err) () let lookup_var c var = match c#lookup var with | None -> None @@ -247,7 +251,7 @@ let check_gpr ?addr init bytes var expected arch _ctxt = | None -> assert_bool "var not found OR it's result not Imm" false | Some w -> if not (Word.equal w expected) || - (Word.bitwidth w <> Word.bitwidth expected) then + (Word.bitwidth w <> Word.bitwidth expected) then printf "\n%s: check failed for %s: expected %s <> %s\n" insn_name (Var.name var) @@ -279,7 +283,7 @@ let check_mem init bytes mem ~addr ~size expected ?(endian=BigEndian) arch _ctxt let bil = Or_error.ok_exn @@ to_bil arch memory insn in check_bil (init @ bil); let c = Stmt.eval (init @ bil) (new Bili.context) in - match load_word c mem addr endian size with + match load_word c mem addr endian size with | None -> assert_bool "word not found OR it's result not Imm" false | Some w -> if not (Word.equal w expected) then diff --git a/oasis/bap-std b/oasis/bap-std index c765d50ea..88fda6ae3 100644 --- a/oasis/bap-std +++ b/oasis/bap-std @@ -16,7 
+16,12 @@ Library bap bap.types, bap-future, cmdliner, - regular + bap-knowledge, + bap-core-theory, + graphlib, + regular, + core_kernel, + ppx_jane Modules: Bap InternalModules: Bap_event, Bap_log, Bap_project, Bap_self @@ -27,7 +32,15 @@ Library types FindlibParent: bap FindlibName: types CompiledObject: best - BuildDepends: monads, zarith, uuidm, bap.config, regular, graphlib, ogre + BuildDepends: monads, zarith, + bitvec, + uuidm, + bap.config, + regular, + graphlib, + ogre, + bap-knowledge, + bap-core-theory InternalModules: Bap_addr, @@ -63,8 +76,8 @@ Library types Bap_ogre, Bap_result, Bap_size, - Bap_state, Bap_stmt, + Bap_toplevel, Bap_trie, Bap_trie_intf, Bap_type, @@ -109,6 +122,8 @@ Library disasm Bap_disasm_basic, Bap_disasm_block, Bap_disasm_brancher, + Bap_disasm_calls, + Bap_disasm_driver, Bap_disasm_insn, Bap_disasm_linear_sweep, Bap_disasm_prim, @@ -190,7 +205,7 @@ Library bundle Path: lib/bap_bundle FindlibParent: bap FindlibName: bundle - BuildDepends: uri, camlzip, unix, bap.config + BuildDepends: uri, camlzip, unix, bap.config, core_kernel, ppx_jane Modules: Bap_bundle @@ -200,7 +215,7 @@ Executable "bapbuild" MainIs: bapbuild.ml Install: true CompiledObject: best - BuildDepends: core_kernel, ocamlbuild, bap-build, compiler-libs + BuildDepends: core_kernel, ocamlbuild, bap-build, compiler-libs, ppx_jane Executable "bapbundle" Build$: flag(everything) || flag(bap_std) diff --git a/oasis/bil b/oasis/bil index 74b0b5cdc..d92677d69 100644 --- a/oasis/bil +++ b/oasis/bil @@ -7,6 +7,6 @@ Library bil_plugin Path: plugins/bil Build$: flag(everything) || flag(bil) FindlibName: bap-plugin-bil - BuildDepends: bap - XMETAExtraLines: tags="bil,analysis" - InternalModules: Bil_main + BuildDepends: bap, bap-core-theory, bap-knowledge + XMETAExtraLines: tags="bil,analysis,semantics" + InternalModules: Bil_main, Bil_lifter, Bil_semantics, Bil_ir, Bil_float diff --git a/oasis/bitvec b/oasis/bitvec new file mode 100644 index 000000000..ee067baf7 --- /dev/null 
+++ b/oasis/bitvec @@ -0,0 +1,11 @@ +Flag bitvec + Description: Build the bitvec library + Default: false + +Library bitvec + Build$: flag(everything) || flag(bitvec) + Path: lib/bitvec + FindlibName: bitvec + CompiledObject: best + BuildDepends: zarith + Modules: Bitvec diff --git a/oasis/bitvec-binprot b/oasis/bitvec-binprot new file mode 100644 index 000000000..a94b43094 --- /dev/null +++ b/oasis/bitvec-binprot @@ -0,0 +1,11 @@ +Flag bitvec_binprot + Description: Enables support for Core's Binprot protocol + Default: false + +Library bitvec_binprot + Build$: flag(everything) || flag(bitvec_binprot) + Path: lib/bitvec_binprot + FindlibName: bitvec-binprot + CompiledObject: best + BuildDepends: bitvec, bin_prot, ppx_jane + Modules: Bitvec_binprot diff --git a/oasis/bitvec-order b/oasis/bitvec-order new file mode 100644 index 000000000..fefd9d85a --- /dev/null +++ b/oasis/bitvec-order @@ -0,0 +1,11 @@ +Flag bitvec_order + Description: Provides comparators for use with JS Core + Default: false + +Library bitvec_order + Build$: flag(everything) || flag(bitvec_order) + Path: lib/bitvec_order + FindlibName: bitvec-order + CompiledObject: best + BuildDepends: base, bitvec, bitvec-sexp + Modules: Bitvec_order diff --git a/oasis/bitvec-sexp b/oasis/bitvec-sexp new file mode 100644 index 000000000..aa1f3fca4 --- /dev/null +++ b/oasis/bitvec-sexp @@ -0,0 +1,11 @@ +Flag bitvec_sexp + Description: Provides sexp converters for bitvectors + Default: false + +Library bitvec_sexp + Build$: flag(everything) || flag(bitvec_sexp) + Path: lib/bitvec_sexp + FindlibName: bitvec-sexp + CompiledObject: best + BuildDepends: bitvec, sexplib0 + Modules: Bitvec_sexp diff --git a/oasis/common b/oasis/common index 542aa8eb8..9c87529e3 100644 --- a/oasis/common +++ b/oasis/common @@ -1,6 +1,6 @@ OASISFormat: 0.4 Name: bap -Version: 1.6.0 +Version: 2.0.0-alpha OCamlVersion: >= 4.04.1 Synopsis: BAP Core Library Authors: BAP Team @@ -11,13 +11,10 @@ Plugins: META (0.4) AlphaFeatures: 
ocamlbuild_more_args, compiled_setup_ml BuildTools: ocamlbuild XOCamlbuildExtraArgs: - -j 2 + -j 8 -use-ocamlfind -classic-display - -plugin-tags "'package(findlib),package(core_kernel)'" - -BuildDepends: ppx_jane, core_kernel (>= v0.11 && < v0.12) - + -plugin-tags "'package(findlib)'" PreConfCommand: $rm setup.data PostDistcleanCommand: $rm _tags myocamlbuild.ml setup.ml setup.data diff --git a/oasis/common.tags.in b/oasis/common.tags.in index b7c783d44..24a76a4e2 100644 --- a/oasis/common.tags.in +++ b/oasis/common.tags.in @@ -3,6 +3,4 @@ true: short_paths true: bin_annot true: debug -not <**/bap_elf/*>: predicate(ppx_driver) -not <**/bap_elf/*>: pp(ppx-jane -dump-ast -inline-test-drop) true: warn(+a-4-6-7-9-27-29-32..42-44-45-48-50-60) \ No newline at end of file diff --git a/oasis/core-theory b/oasis/core-theory new file mode 100644 index 000000000..98c883cd6 --- /dev/null +++ b/oasis/core-theory @@ -0,0 +1,24 @@ +Flag core_theory + Description: Build the bap-core-theory library + Default: false + +Library bap_core_theory + Build$: flag(everything) || flag(core_theory) + Path: lib/bap_core_theory + FindlibName: bap-core-theory + CompiledObject: best + BuildDepends: bap-knowledge, core_kernel, + bitvec, bitvec-order, bitvec-sexp, bitvec-binprot + Modules: Bap_core_theory + InternalModules: + Bap_core_theory_basic, + Bap_core_theory_definition, + Bap_core_theory_effect, + Bap_core_theory_empty, + Bap_core_theory_grammar_definition, + Bap_core_theory_IEEE754, + Bap_core_theory_program, + Bap_core_theory_manager, + Bap_core_theory_parser, + Bap_core_theory_value, + Bap_core_theory_var diff --git a/oasis/elementary b/oasis/elementary new file mode 100644 index 000000000..6642c2f3a --- /dev/null +++ b/oasis/elementary @@ -0,0 +1,11 @@ +Flag elementary + Description: Build the bap-elementary library + Default: false + +Library bap_elementary + Build$: flag(everything) || flag(elementary) + Path: lib/bap_elementary + FindlibName: bap-elementary + CompiledObject: best + 
BuildDepends: bap, bap-knowledge, bap-core-theory, core_kernel + Modules: Bap_elementary diff --git a/oasis/frontend b/oasis/frontend index a476c3fae..3a11b24f1 100644 --- a/oasis/frontend +++ b/oasis/frontend @@ -7,4 +7,5 @@ Executable "bap" MainIs: bap_main.ml Build$: flag(everything) || flag(frontend) CompiledObject: best - BuildDepends: bap, bap.plugins, cmdliner, findlib.dynload, parsexp + BuildDepends: core_kernel, ppx_jane, bap-future, regular, + bap, bap.plugins, cmdliner, findlib.dynload, parsexp diff --git a/oasis/ida b/oasis/ida index 3e7435f36..4834b78ba 100644 --- a/oasis/ida +++ b/oasis/ida @@ -10,7 +10,7 @@ Library bap_ida CompiledObject: best Build$: flag(everything) || flag(ida) Modules: Bap_ida - BuildDepends: fileutils, re.posix + BuildDepends: fileutils, re.posix, core_kernel, ppx_jane XMETADescription: make calls into IDA Library bap_ida_plugin @@ -18,7 +18,7 @@ Library bap_ida_plugin Path: plugins/ida FindlibName: bap-plugin-ida CompiledObject: best - BuildDepends: bap, bap-ida + BuildDepends: bap, bap-ida, core_kernel, ppx_jane Modules: Ida_main InternalModules: Bap_ida_config, Bap_ida_service, Bap_ida_info XMETADescription: use ida to provide rooter, symbolizer and reconstructor diff --git a/oasis/knowledge b/oasis/knowledge new file mode 100644 index 000000000..33fa9b4f8 --- /dev/null +++ b/oasis/knowledge @@ -0,0 +1,11 @@ +Flag knowledge + Description: Build the knowledge library + Default: false + +Library knowledge + Build$: flag(everything) || flag(knowledge) + Path: lib/knowledge + FindlibName: bap-knowledge + CompiledObject: best + BuildDepends: core_kernel, monads, ppx_jane + Modules: Bap_knowledge diff --git a/oasis/lisp b/oasis/lisp new file mode 100644 index 000000000..aaac9baae --- /dev/null +++ b/oasis/lisp @@ -0,0 +1,29 @@ +Flag lisp + Description: BAP Lisp + Default: false + +Library bap_lisp + Build$: flag(everything) || flag(lisp) + XMETADescription: microexecution framework + Path: lib/bap_lisp + FindlibName: bap-lisp + 
CompiledObject: best + BuildDepends: parsexp, bap-strings, bap-knowledge, bap-core-theory, + graphlib, ppx_jane, core_kernel + Modules: Bap_lisp + InternalModules: + Bap_lisp, + Bap_lisp__attribute, + Bap_lisp__attributes, + Bap_lisp__context, + Bap_lisp__def, + Bap_lisp__index, + Bap_lisp__loc, + Bap_lisp__parse, + Bap_lisp__program, + Bap_lisp__resolve, + Bap_lisp__source, + Bap_lisp__type, + Bap_lisp__types, + Bap_lisp__var, + Bap_lisp__word diff --git a/oasis/llvm b/oasis/llvm index 236576021..f0c87cecb 100644 --- a/oasis/llvm +++ b/oasis/llvm @@ -31,7 +31,7 @@ Library bap_llvm Bap_llvm_ogre_samples, Bap_llvm_ogre_types CCOpt: $cc_optimization - CCLib: $llvm_lib $cxxlibs $llvm_ldflags + CCLib: $llvm_lib $cxxlibs $llvm_ldflags -lcurses CSources: llvm_disasm.h, llvm_disasm.c, llvm_stubs.c, diff --git a/oasis/mips b/oasis/mips index 26564fa1f..78223f8ad 100644 --- a/oasis/mips +++ b/oasis/mips @@ -7,7 +7,7 @@ Library mips_plugin Path: plugins/mips FindlibName: bap-plugin-mips Build$: flag(everything) || flag (mips) - BuildDepends: bap, bap-abi, bap-c + BuildDepends: bap, bap-abi, bap-c, bap-core-theory InternalModules: Mips, Mips_main, diff --git a/oasis/monads b/oasis/monads index fcc596517..1a9149d8a 100644 --- a/oasis/monads +++ b/oasis/monads @@ -7,7 +7,7 @@ Library monads Path: lib/monads FindlibName: monads CompiledObject: best - BuildDepends: core_kernel + BuildDepends: core_kernel, ppx_jane Modules: Monads InternalModules: Monads_monad, Monads_monoid, diff --git a/oasis/primus-dictionary b/oasis/primus-dictionary index eada8f1bd..3931f0c26 100644 --- a/oasis/primus-dictionary +++ b/oasis/primus-dictionary @@ -5,9 +5,9 @@ Flag primus_dictionary Library primus_dictionary_library_plugin Build$: flag(everything) || flag(primus_dictionary) Path: plugins/primus_dictionary - BuildDepends: bap-primus + BuildDepends: bap-primus, bap, core_kernel FindlibName: bap-plugin-primus_dictionary CompiledObject: best InternalModules: Primus_dictionary_main - 
XMETADescription: provides a key-value storage + XMETADescription: provides a key-value storage XMETAExtraLines: tags="primus, primus-library" \ No newline at end of file diff --git a/oasis/primus-machine b/oasis/primus-machine new file mode 100644 index 000000000..8a1faa4dc --- /dev/null +++ b/oasis/primus-machine @@ -0,0 +1,13 @@ +Flag primus_machine + Description: Build Primus Machine monad + Default: false + +Library bap_primus_machine + Build$: flag(everything) || flag(primus_machine) + XMETADescription: provides Primus Machine monad + XMETAExtraLines: tags="primus" + Path: lib/bap_primus_machine + FindlibName: bap-primus-machine + CompiledObject: best + BuildDepends: bap-knowledge, core_kernel + Modules: Bap_primus_machine diff --git a/oasis/primus-region b/oasis/primus-region index f71f415f6..7089f0afc 100644 --- a/oasis/primus-region +++ b/oasis/primus-region @@ -5,7 +5,7 @@ Flag primus_region Library primus_region_library_plugin Build$: flag(everything) || flag(primus_region) Path: plugins/primus_region - BuildDepends: bap-primus + BuildDepends: bap-primus, bap, core_kernel FindlibName: bap-plugin-primus_region CompiledObject: best InternalModules: Primus_region_main diff --git a/oasis/primus-support b/oasis/primus-support index 0e29bb0c3..229d87462 100644 --- a/oasis/primus-support +++ b/oasis/primus-support @@ -2,7 +2,6 @@ Flag primus_support Description: build supporting components for Primus Default: false - Library primus_loader_plugin Path: plugins/primus_loader Build$: flag(everything) || flag(primus_support) @@ -69,7 +68,7 @@ Library primus_print_plugin Build$: flag(everything) || flag(primus_support) FindlibName: bap-plugin-primus_print CompiledObject: best - BuildDepends: bap-primus, bare + BuildDepends: bap-primus, bare, bap XMETADescription: prints Primus states and observations Modules: Primus_print_main XMETAExtraLines: tags="primus, printer" @@ -79,7 +78,7 @@ Library primus_mark_visited_plugin Build$: flag(everything) || flag(primus_support) 
FindlibName: bap-plugin-primus_mark_visited CompiledObject: best - BuildDepends: bap-primus + BuildDepends: bap-primus, bap XMETADescription: marks terms that were visited by Primus Modules: Primus_mark_visited_main XMETAExtraLines: tags="primus" @@ -90,7 +89,7 @@ Library primus_limit Build$: flag(everything) || flag(primus_support) FindlibName: bap-plugin-primus_limit CompiledObject: best - BuildDepends: bap-primus + BuildDepends: bap-primus, bap XMETADescription: ensures termination by limiting Primus machines Modules: Primus_limit_main XMETAExtraLines: tags="primus" diff --git a/oasis/primus-test b/oasis/primus-test index a640491f9..534bd8656 100644 --- a/oasis/primus-test +++ b/oasis/primus-test @@ -6,7 +6,7 @@ Flag primus_test Library primus_test_library_plugin Build$: flag(everything) || flag(primus_test) Path: plugins/primus_test - BuildDepends: bap-primus + BuildDepends: bap-primus, bap, core_kernel FindlibName: bap-plugin-primus_test CompiledObject: best InternalModules: Primus_test_main diff --git a/oasis/regular b/oasis/regular index 91d93d691..7921ac78f 100644 --- a/oasis/regular +++ b/oasis/regular @@ -7,7 +7,7 @@ Library regular Path: lib/regular FindlibName: regular CompiledObject: best - BuildDepends: core_kernel + BuildDepends: core_kernel, ppx_jane Modules: Regular InternalModules: Regular_bytes, diff --git a/oasis/taint b/oasis/taint index 76b373edf..ab1e83dc9 100644 --- a/oasis/taint +++ b/oasis/taint @@ -24,7 +24,7 @@ Library primus_propagate_taint_plugin Build$: flag(everything) || flag(taint) FindlibName: bap-plugin-primus_propagate_taint CompiledObject: best - BuildDepends: bap-primus, bap-taint + BuildDepends: bap-primus, bap-taint, core_kernel, bap XMETADescription: a compatibility layer between different taint analysis frameworks InternalModules: Primus_propagate_taint_main XMETAExtraLines: tags="dataflow, pass, taint, primus" @@ -34,7 +34,7 @@ Library primus_taint_plugin Build$: flag(everything) || flag(taint) FindlibName: 
bap-plugin-primus_taint CompiledObject: best - BuildDepends: bap-primus, bap-taint + BuildDepends: bap-primus, bap-taint, bap, core_kernel XMETADescription: a taint analysis control interface DataFiles: lisp/*.lisp ($primus_taint_lisp_path) InternalModules: Primus_taint_main, Primus_taint_policies diff --git a/opam/opam b/opam/opam index 5a0224bfc..026c8d740 100644 --- a/opam/opam +++ b/opam/opam @@ -137,7 +137,15 @@ install: [ ["ocamlfind" "remove" "bap-strings"] ["ocamlfind" "remove" "monads"] ["ocamlfind" "remove" "ogre"] - ["ocamlfind" "remove" "bare"] + ["ocamlfind" "remove" "bap-knowledge"] + ["ocamlfind" "remove" "bitvec"] + ["ocamlfind" "remove" "bitvec-order"] + ["ocamlfind" "remove" "bitvec-sexp"] + ["ocamlfind" "remove" "bitvec-binprot"] + ["ocamlfind" "remove" "bap-core-theory"] + ["ocamlfind" "remove" "bap-elementary"] + ["ocamlfind" "remove" "bap-lisp"] + ["ocamlfind" "remove" "bap-primus-machine"] [make "reinstall"] ["cp" "run_tests.native" "%{bin}%/bap_run_tests"] ] @@ -224,6 +232,15 @@ remove: [ ["ocamlfind" "remove" "monads"] ["ocamlfind" "remove" "ogre"] ["ocamlfind" "remove" "bare"] + ["ocamlfind" "remove" "bap-knowledge"] + ["ocamlfind" "remove" "bitvec"] + ["ocamlfind" "remove" "bitvec-order"] + ["ocamlfind" "remove" "bitvec-sexp"] + ["ocamlfind" "remove" "bitvec-binprot"] + ["ocamlfind" "remove" "bap-core-theory"] + ["ocamlfind" "remove" "bap-elementary"] + ["ocamlfind" "remove" "bap-lisp"] + ["ocamlfind" "remove" "bap-primus-machine"] ["rm" "-f" "%{prefix}%/bin/baptop"] ["rm" "-f" "%{prefix}%/bin/ppx-bap"] ["rm" "-rf" "%{prefix}%/share/bap"] diff --git a/plugins/bil/.merlin b/plugins/bil/.merlin index 59b176960..28a83b472 100644 --- a/plugins/bil/.merlin +++ b/plugins/bil/.merlin @@ -1,6 +1,10 @@ +REC B ../../_build/plugins/bil -S . - -PKG bap +B ../../_build/lib/knowledge +B ../../_build/lib/bap_core_theory +FLG -open Bap_core_theory +PKG oUnit -REC +S . 
+B _build +FLG -short-paths \ No newline at end of file diff --git a/plugins/bil/Makefile b/plugins/bil/Makefile new file mode 100644 index 000000000..b7283f4be --- /dev/null +++ b/plugins/bil/Makefile @@ -0,0 +1,7 @@ + +float: + bapbuild -pkgs findlib.dynload,monads,bap-primus,oUnit,bap-core-theory,bap-knowledge bil_float_tests.native + +clean: + rm -rf _build + rm -f bil_float_tests.native diff --git a/plugins/bil/bil_float.ml b/plugins/bil/bil_float.ml new file mode 100644 index 000000000..6a68bb8db --- /dev/null +++ b/plugins/bil/bil_float.ml @@ -0,0 +1,1003 @@ +open Core_kernel +open Bap.Std + +open Bap_knowledge +open Bap_core_theory +open Knowledge.Syntax + + +type 'a t = 'a knowledge + +type ('b,'e,'t,'s) fsort = (('b,'e,'t) Theory.IEEE754.t,'s) Theory.format + Theory.Float.t Theory.Value.sort + +module Value = Knowledge.Value + +module Make(B : Theory.Core) = struct + + open Knowledge.Syntax + + module B = struct + include B + + let one s = succ (zero s) + let ones s = not (zero s) + + let is_one x = + x >>= fun v -> eq x (one (Theory.Value.sort v)) + + let is_all_ones x = + x >>= fun v -> eq x (ones (Theory.Value.sort v)) + + let of_int sort v = + let m = Bitvec.modulus (Theory.Bitv.size sort) in + int sort Bitvec.(int v mod m) + + let is_negative = B.msb + let is_positive x = B.inv (B.msb x) + let is_non_negative x = B.or_ (is_positive x) (B.is_zero x) + let abs x = ite (is_negative x) (neg x) x + + let of_bool = function + | true -> b1 + | false -> b0 + + let testbit x i = lsb (rshift x i) + + let if_ cond ~then_ ~else_ = ite cond then_ else_ + + module Infix = struct + let ( + ) = add + let ( - ) = sub + let ( * ) = mul + let ( / ) = div + let ( = ) = eq + let ( <> ) = neq + let ( < ) = ult + let ( > ) = ugt + let ( <= ) = ule + let ( >= ) = uge + let ( <$ ) = slt + let ( >$ ) = sgt + let ( <=$ ) = sle + let ( >=$ ) = sge + let ( lsr ) = rshift + let ( lsl ) = lshift + let ( land ) = logand + let ( lor ) = logor + let ( lxor ) = logxor + let (&&) = 
and_ + let (||) = or_ + end + + include Infix + + let max x y = ite (x > y) x y + let min x y = ite (x < y) x y + let smax x y = ite (x >$ y) x y + let smin x y = ite (x <$ y) x y + + let leading_one sort = + one sort lsl (of_int sort Caml.(Theory.Bitv.size sort - 1)) + + end + + type 'a t = 'a knowledge + + module Bits = Theory.Bitv + module IEEE754 = Theory.IEEE754 + + let (>>->) x f = + x >>= fun x -> + f (Theory.Value.sort x) x + + let bind a body = + a >>= fun a -> + let sort = Theory.Value.sort a in + Theory.Var.scoped sort @@ fun v -> + B.let_ v !!a (body (B.var v)) + + let (>>>=) = bind + + let exps = Theory.IEEE754.Sort.exps + + let sigs fsort = + let open Theory.IEEE754 in + let spec = Sort.spec fsort in + Theory.Bitv.define spec.p + + let floats fsort = exps fsort, sigs fsort + + let bits = IEEE754.Sort.bits + + let fsign = B.msb + + let exponent fsort bitv = + let open IEEE754 in + let bits = Sort.bits fsort in + let exps = Sort.exps fsort in + let spec = Sort.spec fsort in + B.(low exps (bitv lsr of_int bits spec.t)) + + (* pre: input coef is already of t-bits length *) + let pack_raw fsort sign expn coef = + let open B in + let open IEEE754 in + let bits = Sort.bits fsort in + let bit = Bits.define 1 in + let bits_1 = Bits.define Caml.(Bits.size bits - 1) in + let sign = ite sign (B.one bit) (B.zero bit) in + B.append bits sign (B.append bits_1 expn coef) + + let pack fsort sign expn coef = + let open B in + let {IEEE754.p; t} = IEEE754.Sort.spec fsort in + let is_subnormal = (inv (msb coef)) && (is_one expn) in + ite is_subnormal (pred expn) expn >>>= fun expn -> + if Caml.(p = t) then pack_raw fsort sign expn coef + else + B.low (Bits.define t) coef >>>= fun coef -> + pack_raw fsort sign expn coef >>= fun r -> !!r + + let raw_significand fsort bitv = + let open IEEE754 in + let spec = Sort.spec fsort in + let sigs = Bits.define spec.t in + B.low sigs bitv + + let finite_significand fsort expn bitv = + let open IEEE754 in + let spec = Sort.spec fsort 
in + raw_significand fsort bitv >>>= fun coef -> + if spec.t = spec.p then coef + else + let bit = Bits.define 1 in + let leading_bit = B.(ite (is_zero expn) (zero bit) (one bit)) in + B.append (sigs fsort) leading_bit coef + + let unpack fsort x f = + x >>>= fun x -> + exponent fsort x >>>= fun expn -> + fsign x >>>= fun sign -> + finite_significand fsort expn x >>>= fun coef -> + B.(ite (is_zero expn) (succ expn) expn) >>>= fun expn -> + f sign expn coef + + let unpack_raw fsort x f = + x >>>= fun x -> + exponent fsort x >>>= fun expn -> + raw_significand fsort x >>>= fun coef -> + f (fsign x) expn coef + + let with_sign sign bitv = + let open B in + bitv >>-> fun s bitv -> + let s' = Bits.define Caml.(Bits.size s - 1) in + let bit = Bits.define 1 in + ite sign ((append s (one bit) (zero s')) lor !!bitv) + ((append s (zero bit) (ones s')) land !!bitv) + + let fzero fsort sign = + let open B in + let bits = IEEE754.Sort.bits fsort in + zero bits >>>= fun bitv -> + ite sign ( + ones bits >>>= fun ones -> + not (ones lsr one bits) >>>= fun one -> + one lor bitv) bitv + + let fone fsort sign = + let open IEEE754 in + let {bias; t} = Sort.spec fsort in + let expn = B.of_int (Sort.exps fsort) bias in + let sigs = Bits.define t in + pack_raw fsort sign expn (B.zero sigs) + + let inf fsort sign = + let open B in + let exps = IEEE754.Sort.exps fsort in + pack fsort sign (B.ones exps) (B.zero (sigs fsort)) + + let is_inf fsort x : Theory.bool = + unpack fsort x @@ fun _ expn coef -> + B.(and_ (is_zero coef) (is_all_ones expn)) + + let is_pinf fsort x = + is_inf fsort x >>>= fun inf -> + B.(and_ inf (inv (msb x))) + + let is_ninf fsort x = + is_inf fsort x >>>= fun inf -> + B.(and_ inf (msb x)) + + let is_qnan fsort x = + let open B in + unpack_raw fsort x @@ fun _sign expn coef -> + is_all_ones expn && non_zero coef && msb coef + + let is_snan fsort x = + let open B in + unpack_raw fsort x @@ fun _sign expn coef -> + is_all_ones expn && non_zero coef && inv (msb coef) + + 
let is_nan fsort x = + let open B in + unpack_raw fsort x @@ fun _sign expn coef -> + is_all_ones expn && non_zero coef + + let qnan fsort = + let open B in + let open IEEE754 in + let exps = Sort.exps fsort in + let spec = Sort.spec fsort in + let sigs = Bits.define spec.t in + not (ones sigs lsr one sigs) >>>= fun coef -> + pack_raw fsort B.b1 (ones exps) coef + + let snan fsort = + let open B in + let open IEEE754 in + let exps = Sort.exps fsort in + let spec = Sort.spec fsort in + let sigs = Bits.define spec.t in + pack_raw fsort B.b0 (ones exps) (B.one sigs) + + (* unset a leading bit in coef, no-checks for nan are performed *) + let transform_to_signal fsort x = + let open IEEE754 in + let spec = Sort.spec fsort in + let bits = Sort.bits fsort in + let shift = B.of_int bits (spec.t - 1) in + let mask = B.(not (one bits lsl shift)) in + B.(x land mask) + + (* set a leading bit in coef, no-checks for nan are performed *) + let transform_to_quite fsort x = + let open IEEE754 in + let spec = Sort.spec fsort in + let bits = Sort.bits fsort in + let shift = B.of_int bits (spec.t - 1) in + let mask = B.(one bits lsl shift) in + B.(x lor mask) + + let with_special fsort x f = + let open B in + unpack_raw fsort x @@ fun _sign expn coef -> + is_all_ones expn >>>= fun is_special -> + (is_special && is_zero coef) >>>= fun is_inf -> + (is_special && non_zero coef && inv (msb coef)) >>>= fun is_snan -> + (is_special && msb coef) >>>= fun is_qnan -> + f ~is_inf ~is_snan ~is_qnan + + let is_special fsort x = unpack fsort x @@ fun _ expn _ -> B.is_all_ones expn + let is_finite fsort x = B.inv (is_special fsort x) + let is_finite_nonzero fsort x = + let open B in + unpack_raw fsort x @@ fun _ expn coef -> + inv (is_all_ones expn) >>>= fun ok_expn -> + ok_expn && (non_zero expn || non_zero coef) + + let is_norml fsort x = + unpack_raw fsort x @@ fun _ e _ -> + B.(non_zero e && inv (is_all_ones e)) + let is_subnormal fsort x = + unpack_raw fsort x @@ fun _ e _ -> B.is_zero e + + 
let is_zero x = + let open B in + x >>-> fun s x -> + is_zero ((!!x lsl one s) lsr one s) + + (* TODO: just a stub, need more reliable functions *) + let fsucc fsort x = + let open B in + let exps,sigs = floats fsort in + unpack fsort x @@ fun sign expn coef -> + succ coef >>>= fun coef -> + ite (is_zero coef) (one exps) (zero exps) >>>= fun de -> + ite (is_zero coef) (leading_one sigs) coef >>>= fun coef -> + expn + de >>>= fun expn -> + pack fsort sign expn coef + + (* TODO: just a stub, need more reliable functions *) + let fpred fsort x = + let open B in + let exps,_sigs = floats fsort in + unpack fsort x @@ fun sign expn coef -> + pred coef >>>= fun coef -> + ite (is_all_ones coef) (one exps) (zero exps) >>>= fun de -> + expn - de >>>= fun expn -> + pack fsort sign expn coef + + let precision fsort = let {IEEE754.p} = IEEE754.Sort.spec fsort in p + let bias fsort = let {IEEE754.bias} = IEEE754.Sort.spec fsort in bias + + let match_ ?default cases = + let cases, default = match default with + | Some d -> List.rev cases, d + | None -> + let cases = List.rev cases in + let _,d = List.hd_exn cases in + List.tl_exn cases, d in + List.fold cases ~init:default + ~f:(fun fin (cond, ok) -> B.ite cond ok fin) + + let (-->) x y = x,y + let anything_else = B.b1 + + + let extract_last x n = + let open B in + x >>-> fun xsort x -> + n >>-> fun nsort n -> + let mask = ones xsort lsl !!n in + let size = Bits.size xsort |> B.of_int nsort in + ite (!!n = size) !!x (!!x land (not mask)) + + let is_round_up rm sign last guard round sticky = + let open B in + let case m t f = ite (requal rm m) t f and default = ident in + case rtn (inv sign) @@ + case rtz sign @@ + case rna guard @@ + case rne (guard && (round || sticky || last)) @@ + default b0 + + let guardbits loss lost_bits f = + let open B in + lost_bits >>-> fun sort lost_bits -> + (!!lost_bits > zero sort) >>>= fun has_grd -> + (!!lost_bits > one sort ) >>>= fun has_rnd -> + (!!lost_bits > of_int sort 2) >>>= fun has_stk -> 
+ ite has_grd (pred !!lost_bits) (zero sort) >>>= fun grd_pos -> + ite has_rnd (pred grd_pos) (zero sort) >>>= fun rnd_pos -> + ite has_grd (testbit loss grd_pos) b0 >>>= fun grd -> + ite has_rnd (testbit loss rnd_pos) b0 >>>= fun rnd -> + ite has_stk (non_zero (extract_last loss rnd_pos)) b0 >>>= fun stk -> + f grd rnd stk + + let round rm sign coef loss lost_bits f = + let open B in + guardbits loss lost_bits @@ fun grd rnd stk -> + ite (is_round_up rm sign (lsb coef) grd rnd stk) (succ coef) coef >>>= fun coef' -> + and_ (non_zero coef) (is_zero coef') >>>= fun is_overflow -> + f coef' is_overflow + + (* maximum possible exponent that fits in [n - 1] bits. (one for sign) + and one for special numbers like inf or nan *) + let max_exponent' n = int_of_float (2.0 ** (float_of_int n )) - 2 + let min_exponent' _n = 1 + let max_exponent n = B.of_int n (Bits.size n |> max_exponent') + let min_exponent n = B.of_int n (Bits.size n |> min_exponent') + + (* returns pow of 2 nearest to n and 2 in that power *) + let nearest_pow2 num = + let rec find pow n = + let n' = 2 * n in + if n' >= num then pow + 1, n' + else find (pow + 1) n' in + find 0 1 + + let clz x = + let open B in + x >>-> fun sort x -> + let size = Bits.size sort in + let pow, num = nearest_pow2 size in + let sort' = Bits.define num in + let shifts = List.init pow ~f:(fun p -> Caml.(num / (2 lsl p))) in + let shifts,_ = + List.fold shifts ~init:([],0) + ~f:(fun (acc,prev) curr -> + let total = Caml.(curr + prev) in + (total, curr) :: acc, total) in + let shifts = List.rev shifts in + let rec loop lets x = function + | [] -> List.fold lets ~f:(+) ~init:(zero sort) + | (total, shf) :: shifts -> + ones sort' lsl (of_int sort total) >>>= fun mask -> + ite (is_zero (x land mask)) (x lsl of_int sort shf) x >>>= fun nextx -> + ite (is_zero (x land mask)) (of_int sort shf) (zero sort) >>>= fun nextn -> + loop (nextn :: lets) nextx shifts in + loop [] (B.unsigned sort' !!x) shifts >>>= fun n -> + of_int sort Caml.(num 
- size) >>>= fun dif -> + ite (is_zero !!x) (of_int sort size) (n - dif) + + let possible_lshift expn coef = + let open B in + expn >>-> fun exps expn -> + coef >>-> fun sigs coef -> + clz !!coef >>>= fun clz -> + min_exponent exps >>>= fun mine -> + ite (!!expn < mine) (zero exps) (!!expn - mine) >>>= fun diff -> + unsigned sigs diff >>>= fun diff -> + ite (clz < diff) clz diff + + let norm_finite expn coef f = + let open B in + expn >>-> fun exps expn -> + coef >>-> fun _sigs coef -> + possible_lshift !!expn !!coef >>>= fun shift -> + unsigned exps shift >>>= fun dexpn -> + !!coef lsl shift >>>= fun coef -> + ite (is_zero coef) (min_exponent exps) (!!expn - dexpn) >>>= fun expn -> + f expn coef + + let norm expn coef f = + let open B in + expn >>-> fun _exps expn -> + ite (is_all_ones !!expn) (f !!expn coef) + (norm_finite !!expn coef f) + + let msbn x = + let open B in + x >>-> fun sort x -> + clz !!x >>>= fun clz -> + of_int sort (Bits.size sort) - clz - one sort + + let xor s s' = B.(and_ (or_ s s') (inv (and_ s s'))) + + let guardbits' overflow last loss lost_bits f = + let open B in + guardbits loss lost_bits @@ fun guard' round' sticky' -> + ite overflow last guard' >>>= fun guard -> + ite overflow guard' round' >>>= fun round -> + ite overflow (round' || sticky') sticky' >>>= fun sticky -> + f guard round sticky + + let fadd_finite fsort rm x y = + let open B in + let exps, sigs = floats fsort in + unpack fsort x @@ fun xsign xexpn xcoef -> + unpack fsort y @@ fun _ yexpn ycoef -> + ite (xexpn > yexpn) (xexpn - yexpn) (yexpn - xexpn) >>>= fun lost_bits -> + match_ [ + (xexpn = yexpn) --> zero sigs; + (xexpn > yexpn) --> extract_last ycoef lost_bits; + (xexpn < yexpn) --> extract_last xcoef lost_bits; + ] >>>= fun loss -> + ite (xexpn > yexpn) xcoef (xcoef lsr lost_bits) >>>= fun xcoef -> + ite (yexpn > xexpn) ycoef (ycoef lsr lost_bits) >>>= fun ycoef -> + xcoef + ycoef >>>= fun sum -> + max xexpn yexpn >>>= fun expn -> + ite (sum >= xcoef) expn (succ 
expn) >>>= fun expn -> + guardbits' (sum < xcoef) (lsb sum) loss lost_bits @@ fun guard round sticky -> + ite (sum < xcoef) (sum lsr one sigs) sum >>>= fun coef -> + ite (sum < xcoef) (coef lor leading_one sigs) coef >>>= fun coef -> + is_round_up rm xsign (lsb coef) guard round sticky >>>= fun up -> + ite up (succ coef) coef >>>= fun coef' -> + (is_zero coef' && non_zero coef) >>>= fun rnd_overflow -> (* rounding wrapped the significand to zero *) + ite rnd_overflow (leading_one sigs) coef' >>>= fun coef -> + ite rnd_overflow (succ expn) expn >>>= fun expn -> (* the carry out of the significand must bump the exponent, as in fsub_finite/fmul_finite *) + ite (expn > max_exponent exps) (zero sigs) coef >>>= fun coef -> + norm expn coef @@ fun expn coef -> + pack fsort xsign expn coef + + let bitv_extend bitv ~addend f = + bitv >>-> fun sort bitv -> + let sort' = Bits.define (Bits.size sort + addend) in + B.unsigned sort' !!bitv >>>= fun bitv -> + f bitv sort' + + let common_ground xexpn xcoef yexpn ycoef f = + let open B in + xexpn >>-> fun exps xexpn -> + xcoef >>-> fun sigs xcoef -> + let xexpn = !!xexpn in + let xcoef = !!xcoef in + ite (xexpn > yexpn) (xexpn - yexpn) (yexpn - xexpn) >>>= fun diff -> + ite (is_zero diff) diff (diff - one exps) >>>= fun lost_bits -> + match_ [ + (xexpn > yexpn) --> ( + extract_last ycoef lost_bits >>>= fun loss -> + xcoef lsl one sigs >>>= fun xcoef -> + ycoef lsr lost_bits >>>= fun ycoef -> + f loss lost_bits xcoef ycoef); + (xexpn < yexpn) --> ( + extract_last xcoef lost_bits >>>= fun loss -> + xcoef lsr lost_bits >>>= fun xcoef -> + ycoef lsl one sigs >>>= fun ycoef -> + f loss lost_bits xcoef ycoef); + (xexpn = yexpn) --> f (zero sigs) (zero exps) xcoef ycoef; + ] + + let fsub_finite fsort rm x y = + let open B in + let exps, sigs = floats fsort in + unpack fsort x @@ fun xsign xexpn xcoef -> + unpack fsort y @@ fun _ysign yexpn ycoef -> + bitv_extend xcoef ~addend:1 @@ fun xcoef sigs' -> + bitv_extend ycoef ~addend:1 @@ fun ycoef _ -> + or_ (xexpn < yexpn) (and_ (xexpn = yexpn) (xcoef < ycoef)) >>>= fun swap -> + ite swap (inv xsign) xsign >>>= fun sign -> + ite
(xexpn = yexpn && xcoef = ycoef) b0 sign >>>= fun sign -> + common_ground xexpn xcoef yexpn ycoef @@ fun loss lost_bits xcoef ycoef -> + ite (is_zero loss) (zero sigs') (one sigs') >>>= fun borrow -> + ite swap (ycoef - xcoef - borrow) (xcoef - ycoef - borrow) >>>= fun coef -> + msb coef >>>= fun msbc -> + max xexpn yexpn >>>= fun expn -> + ite (xexpn = yexpn) expn (expn - one exps) >>>= fun expn -> + ite (is_zero coef) (min_exponent exps) (ite msbc (succ expn) expn) >>>= fun expn -> + guardbits loss lost_bits @@ fun guard' round' sticky' -> + ite (round' || sticky') (inv guard') guard' >>>= fun guard' -> + ite msbc (lsb coef) guard' >>>= fun guard -> + ite msbc guard' round' >>>= fun round -> + ite msbc (round' || sticky') sticky' >>>= fun sticky -> + ite msbc (coef lsr one sigs') coef >>>= fun coef -> + is_round_up rm sign (lsb coef) guard round sticky >>>= fun up -> + unsigned sigs coef >>>= fun coef -> + ite up (succ coef) coef >>>= fun coef' -> + (is_zero coef' && non_zero coef) >>>= fun rnd_overflow -> + ite rnd_overflow (leading_one sigs) coef' >>>= fun coef -> + ite rnd_overflow (succ expn) expn >>>= fun expn -> + norm expn coef @@ fun expn coef -> pack fsort sign expn coef + + let add_or_sub_finite is_sub fsort rm x y = + let ( lxor ) = xor in + let s1 = is_sub in + let s2 = fsign x in + let s3 = fsign y in + let is_sub = s1 lxor (s2 lxor s3) in + B.ite is_sub (fsub_finite fsort rm x y) + (fadd_finite fsort rm x y) + + let fsum_special fsort is_sub x y = + let open B in + let not = inv in + let ( lxor ) = xor in + let s1 = is_sub in + let s2 = fsign x in + let s3 = fsign y in + let is_sub = s1 lxor (s2 lxor s3) in + with_special fsort x @@ fun ~is_inf:xinf ~is_snan:xsnan ~is_qnan:xqnan -> + with_special fsort y @@ fun ~is_inf:yinf ~is_snan:ysnan ~is_qnan:yqnan -> + (xsnan || xqnan) >>>= fun xnan -> + (ysnan || yqnan) >>>= fun ynan -> + (xinf && yinf) >>>= fun is_inf -> + match_ [ + (is_sub && is_inf) --> qnan fsort; + (xinf && yinf) --> x; + (xnan && (not 
ynan)) --> transform_to_quite fsort x; + (ynan && (not xnan)) --> transform_to_quite fsort y; + anything_else --> (transform_to_quite fsort x); + ] + + let add_or_sub ~is_sub fsort rm x y = + is_finite fsort x >>>= fun is_x_fin -> + is_finite fsort y >>>= fun is_y_fin -> + let open B in + ite (is_x_fin && is_y_fin) + (add_or_sub_finite is_sub fsort rm x y) + (fsum_special fsort is_sub x y) + + let fsub fsort rm x y = add_or_sub ~is_sub:B.b1 fsort rm x y + let fadd fsort rm x y = add_or_sub ~is_sub:B.b0 fsort rm x y + + let double bitv f = + bitv >>-> fun sort bitv -> + bitv_extend !!bitv ~addend:(Bits.size sort) f + + let normalize_coef coef f = + let open B in + coef >>-> fun _sort coef -> + clz !!coef >>>= fun clz -> + !!coef lsl clz >>>= fun coef -> + f coef clz + + (* Clarification. + The result of (Sx,Ex,Cx) * (Sy,Ey,Cy) is (Sx xor Sy, Ex + Ey - bias, Cx * Cy), + where S,E,C - sign, exponent and coefficent. + Also, say we have 53-bit precision: C = c52 . c51 . c50 ... c0. + We normalize operands by shifting them as left as possible to + be sure what exactly bits the result will occupy. + The bare result is C = c105 . c104 ... c0. If c105 is set then + we have an overflow and right shift needed. 
The result of multiplication + is in 53 bits starting from c104.*) + let fmul_finite fsort rm x y = + let open B in + let double e c f = + double e @@ fun e es -> + double c @@ fun c cs -> + f e es c cs in + let precision = precision fsort in + let exps,sigs = floats fsort in + let mine = min_exponent' (Bits.size exps) in + let maxe = max_exponent' (Bits.size exps) in + unpack fsort x @@ fun xsign xexpn xcoef -> + unpack fsort y @@ fun ysign yexpn ycoef -> + xor xsign ysign >>>= fun sign -> + normalize_coef xcoef @@ fun xcoef dx -> + normalize_coef ycoef @@ fun ycoef dy -> + dx + dy >>>= fun dnorm -> + double xexpn xcoef @@ fun xexpn exps' xcoef sigs' -> + double yexpn ycoef @@ fun yexpn _ ycoef _ -> + of_int sigs' precision >>>= fun prec -> + xexpn + yexpn >>>= fun expn -> + xcoef * ycoef >>>= fun coef -> + msb coef >>>= fun coef_overflowed -> + ite coef_overflowed (succ expn) expn >>>= fun expn -> + of_int exps' (bias fsort) >>>= fun bias -> + bias + unsigned exps' dnorm >>>= fun dexpn -> + ite (dexpn >=$ expn) (dexpn - expn + of_int exps' mine) (zero exps') >>>= fun underflow -> + underflow > of_int exps' precision >>>= fun is_underflow -> + expn - dexpn + underflow >>>= fun expn -> + expn > of_int exps' maxe >>>= fun is_overflow -> + ite coef_overflowed (one sigs') (zero sigs') >>>= fun from_overflow -> + coef lsr (from_overflow + unsigned sigs' underflow) >>>= fun coef -> + coef lsl one sigs' >>>= fun coef -> + low sigs coef >>>= fun loss -> + high sigs coef >>>= fun coef -> + low exps expn >>>= fun expn -> + round rm sign coef loss prec @@ fun coef rnd_overflow -> + ite rnd_overflow (leading_one sigs) coef >>>= fun coef -> + ite rnd_overflow (succ expn) expn >>>= fun expn -> + ((expn > of_int exps maxe) || is_overflow) >>>= fun is_overflow -> + norm expn coef @@ fun expn coef -> + ite is_underflow (zero exps) expn >>>= fun expn -> + ite is_overflow (ones exps) expn >>>= fun expn -> + ite (is_overflow || is_underflow) (zero sigs) coef >>>= fun coef -> + pack 
fsort sign expn coef + + let fmul_special fsort x y = + let open B in + with_special fsort x @@ fun ~is_inf:xinf ~is_snan:xsnan ~is_qnan:xqnan -> + with_special fsort y @@ fun ~is_inf:yinf ~is_snan:_ysnan ~is_qnan:_yqnan -> + fsign x >>>= fun xsign -> + fsign y >>>= fun ysign -> + (xinf && yinf) >>>= fun is_inf -> + match_ [ + (is_zero x && yinf) --> qnan fsort; + (is_zero y && xinf) --> qnan fsort; + is_inf --> with_sign (xor xsign ysign) x; + (xsnan || xqnan) --> transform_to_quite fsort x; + anything_else --> (transform_to_quite fsort y); + ] + + let fmul fsort rm x y = + is_finite fsort x >>>= fun is_x_fin -> + is_finite fsort y >>>= fun is_y_fin -> + B.(ite (is_x_fin && is_y_fin) + (fmul_finite fsort rm x y) + (fmul_special fsort x y)) + + let mask_bit sort i = + let uno = B.one sort in + let shf = B.of_int sort i in + B.(uno lsl shf) + + (* pre: nominator > denominator *) + let long_division prec nomin denom f = + let open B in + let lost_alot = b1, b1, b0 in + let lost_half = b1, b0, b0 in + let lost_zero = b0, b0, b0 in + let lost_afew = b0, b0, b1 in + nomin >>-> fun sort nomin -> + let rec loop i bits nomin = + if Caml.(i < 0) then + List.fold bits ~f:(lor) ~init:(zero sort) >>>= fun coef -> + match_ [ + (nomin > denom) --> f coef lost_alot; + (nomin = denom) --> f coef lost_half; + (nomin = zero sort) --> f coef lost_zero; + anything_else --> (f coef lost_afew); + ] + else + ite (nomin > denom) (mask_bit sort i) (zero sort) >>>= fun bit -> + ite (nomin > denom) (nomin - denom) nomin >>>= fun next_nomin -> + next_nomin lsl one sort >>>= fun next_nomin -> + bind next_nomin (fun next_nomin -> + loop Caml.(i - 1) (bit :: bits) next_nomin) in + loop Caml.(prec - 1) [] !!nomin + + let fdiv_finite fsort rm x y = + let open B in + let norm_nominator exps sigs nomin denom f = + ite (nomin < denom) (nomin lsl one sigs) nomin >>>= fun nomin' -> + ite (nomin < denom) (one exps) (zero exps) >>>= fun dexpn -> + f nomin' dexpn in + let exps,sigs = floats fsort in + let 
prec = precision fsort in + let prec' = of_int exps prec in + let maxe = max_exponent exps in + let mine = min_exponent exps in + unpack fsort x @@ fun xsign xexpn xcoef -> + unpack fsort y @@ fun ysign yexpn ycoef -> + xor xsign ysign >>>= fun sign -> + normalize_coef xcoef @@ fun nomin dx -> + normalize_coef ycoef @@ fun denom dy -> + dy - dx >>>= fun de -> + bitv_extend nomin ~addend:1 @@ fun nomin sigs' -> + bitv_extend denom ~addend:1 @@ fun denom _ -> + norm_nominator exps sigs' nomin denom @@ fun nomin dexpn' -> + long_division prec nomin denom @@ fun coef (guard', round', sticky') -> + unsigned sigs coef >>>= fun coef -> + of_int exps (bias fsort) >>>= fun bias -> + unsigned exps de - dexpn' >>>= fun dexpn -> + xexpn - yexpn >>>= fun expn -> + ((xexpn < yexpn) && (yexpn - xexpn > dexpn + bias - mine)) >>>= fun underflowed -> + ((xexpn < yexpn) && (yexpn - xexpn > prec' + dexpn + bias - mine)) >>>= fun is_underflow -> + ((xexpn > yexpn) && (expn > maxe - dexpn - bias)) >>>= fun is_overflow -> + expn + dexpn + bias >>>= fun expn -> + if_ underflowed + ~then_:( + abs expn + mine >>>= fun fix_underflow -> + extract_last coef fix_underflow >>>= fun loss -> + guardbits loss fix_underflow @@ fun guard round sticky -> + (sticky || round' || sticky') >>>= fun sticky -> + coef lsr fix_underflow >>>= fun coef -> + is_round_up rm sign (lsb coef) guard round sticky >>>= fun up -> + ite up (succ coef) coef >>>= fun coef -> + norm mine coef @@ fun expn coef -> pack fsort sign expn coef) + ~else_:( + is_round_up rm sign (lsb coef) guard' round' sticky' >>>= fun up -> + ite up (succ coef) coef >>>= fun coef -> + ite is_underflow (zero exps) expn >>>= fun expn -> + ite is_overflow (ones exps) expn >>>= fun expn -> + ite (is_overflow || is_underflow) (zero sigs) coef >>>= fun coef -> + norm expn coef @@ fun expn coef -> pack fsort sign expn coef) + + let fdiv_special fsort x y = + let open B in + with_special fsort x @@ fun ~is_inf:xinf ~is_snan:xsnan ~is_qnan:xqnan -> + 
with_special fsort y @@ fun ~is_inf:yinf ~is_snan:ysnan ~is_qnan:yqnan -> + fsign x >>>= fun xsign -> + fsign y >>>= fun ysign -> + (xinf && yinf) >>>= fun is_inf -> (* an operand is finite only if it is neither an infinity nor any NaN, quiet ones included *) + inv (xinf || xsnan || xqnan) >>>= fun is_finx -> + inv (yinf || ysnan || yqnan) >>>= fun is_finy -> + xor xsign ysign >>>= fun sign -> + match_ [ + (is_zero x && is_zero y) --> qnan fsort; + (is_zero x && is_finy) --> fzero fsort sign; + (is_zero y && is_finx) --> inf fsort sign; + (xinf && yinf) --> qnan fsort; + is_inf --> with_sign (xor xsign ysign) x; + (xsnan || xqnan) --> transform_to_quite fsort x; + anything_else --> (transform_to_quite fsort y); + ] + + let fdiv fsort rm x y = + let open B in + ite (is_finite_nonzero fsort x && is_finite_nonzero fsort y) + (fdiv_finite fsort rm x y) + (fdiv_special fsort x y) + + let ftwo fsort = + fone fsort B.b0 >>>= fun one -> + fadd fsort B.rne one one + + (* pre: fsort ">=" fsort' *) + let truncate fsort x rm fsort' = + let open B in + let sigs_sh = Bits.size (sigs fsort') in + let d_bias = Caml.(bias fsort - bias fsort') in + let dst_maxe = max_exponent' (Bits.size @@ exps fsort') in + unpack fsort x @@ fun sign expn coef -> + if_ (is_all_ones expn || is_zero expn) + ~then_:( + low (exps fsort') expn >>>= fun expn -> + high (sigs fsort') coef >>>= fun coef -> + pack fsort' sign expn coef) + ~else_:( + expn - of_int (exps fsort) d_bias >>>= fun expn -> + if_ (expn > of_int (exps fsort) dst_maxe) + ~then_:(inf fsort' sign) + ~else_:( + low (exps fsort') expn >>>= fun expn -> + coef lsl of_int (sigs fsort) sigs_sh >>>= fun truncated -> + high (sigs fsort') truncated >>>= fun truncated -> + high (sigs fsort') coef >>>= fun coef -> + lsb coef >>>= fun last -> + msb truncated >>>= fun guard -> + truncated lsl one (sigs fsort') >>>= fun truncated -> + msb truncated >>>= fun round -> + truncated lsl one (sigs fsort') >>>= fun truncated -> + non_zero truncated >>>= fun sticky -> + is_round_up rm sign last guard round sticky >>>= fun up -> + ite (is_all_ones
coef && up) (succ expn) expn >>>= fun expn -> + ite up (succ coef) coef >>>= fun coef -> + ite (is_all_ones expn) (inf fsort' sign) + (pack fsort' sign expn coef))) + + (* pre: fsort "<=" fsort' *) + let extend fsort x fsort' = + let open B in + let d_sigs = Caml.(Bits.size (sigs fsort') - Bits.size (sigs fsort)) in + let d_bias = Caml.(bias fsort' - bias fsort) in + unpack fsort x @@ fun sign expn coef -> + match_ [ + is_all_ones expn --> ones (exps fsort'); + (expn = min_exponent (exps fsort)) --> min_exponent (exps fsort'); + anything_else --> + (unsigned (exps fsort') expn + of_int (exps fsort') d_bias); + ] >>>= fun expn -> + unsigned (sigs fsort') coef >>>= fun coef -> + (coef lsl (of_int (sigs fsort') d_sigs)) >>>= fun coef -> + pack fsort' sign expn coef + + let convert fsort x rm fsort' = + let size f = Bits.size (IEEE754.Sort.bits f) in + if size fsort = size fsort' then + B.unsigned (bits fsort') x + else if size fsort < size fsort' then + extend fsort x fsort' + else truncate fsort x rm fsort' + + let double_precision fsort = + let bits = Bits.size (bits fsort) in + let bits' = 2 * bits in + let p = Option.value_exn (IEEE754.binary bits') in + IEEE754.Sort.define p + + let gen_cast_float fsort _rmode sign bitv = + let open IEEE754 in + let open B in + let {p;bias} = Sort.spec fsort in + let exps = exps fsort in + let sigs = Bits.define p in + bitv >>-> fun inps bitv -> + of_int exps Caml.(bias + p - 1) >>>= fun expn -> + of_int sigs p >>>= fun prec -> + clz !!bitv >>>= fun clz -> + unsigned sigs clz >>>= fun clz -> + of_int sigs (Bits.size inps) - clz >>>= fun msbn -> + if_ (msbn > prec) + ~then_:(msbn - prec >>>= fun de -> + msbn - one sigs >>>= fun hi -> + hi - prec + one sigs >>>= fun lo -> + extract sigs hi lo !!bitv >>>= fun coef -> + extract sigs (pred lo) (zero sigs) !!bitv >>>= fun loss -> + round rne sign coef loss lo @@ fun coef rnd_overflow -> + ite rnd_overflow (not (ones sigs lsr one sigs)) coef >>>= fun coef -> + ite rnd_overflow (succ 
expn) expn >>>= fun expn -> + expn + unsigned exps de >>>= fun expn -> + norm expn coef @@ fun expn coef -> + pack fsort sign expn coef) + ~else_:(unsigned sigs !!bitv >>>= fun coef -> + norm expn coef @@ fun expn coef -> + pack fsort sign expn coef) + + let cast_float fsort rmode bitv = gen_cast_float fsort rmode B.b0 bitv + + let cast_float_signed fsort rmode bitv = + let open B in + let sign = msb bitv in + let bitv = ite sign (neg bitv) bitv in + gen_cast_float fsort rmode sign bitv + + let cast_int (fsort : _ fsort) outs bitv = + let open B in + let open IEEE754 in + let {p;bias} = Sort.spec fsort in + let exps = exps fsort in + let sigs = Bits.define p in + unpack fsort bitv @@ fun sign expn coef -> + expn - of_int exps bias + one exps >>>= fun bits -> + of_int sigs p - unsigned sigs bits >>>= fun bits -> + coef lsr unsigned sigs bits >>>= fun coef -> + unsigned outs coef >>>= fun coef -> + ite sign (neg coef) coef + + (* returns x in range [1.0; 2.0] *) + let range_reduction fsort x f = + let open B in + let bias = bias fsort in + let low = of_int (exps fsort) Caml.(bias - 0) in + let top = of_int (exps fsort) Caml.(bias + 1) in + unpack_raw fsort x @@ fun sign expn coef -> + match_ [ + (expn = top && non_zero coef) --> (expn - low); + (expn > top) --> (expn - low); + (expn = low || expn = top) --> zero (exps fsort); + (expn < low) --> (low - expn); + ] >>>= fun d_expn -> + (expn >= low) >>>= fun increase -> + ite (expn >= low) (expn - d_expn) (expn + d_expn) >>>= fun expn -> + pack_raw fsort sign expn coef >>>= fun r -> f r d_expn increase + + let sqrt2 fsort' = + let open IEEE754 in + let fsort = Sort.define binary128 in + B.int (bits fsort) + (Bitvec.of_string "0x3fff6a09e667f3bcc908b2fb1366ea95") >>>= fun x -> + convert fsort x B.rne fsort' + + let range_reconstruction fsort x d_expn clz increase = + let idiv = fdiv in + let open B in + sqrt2 fsort >>>= fun sqrt2 -> + unpack_raw fsort x @@ fun sign expn coef -> + if_ (is_zero clz) + ~then_:( + lsb d_expn 
>>>= fun is_odd -> + ite (is_odd && increase) (succ d_expn) + (ite is_odd (pred d_expn) d_expn) >>>= fun d_expn -> + ite increase + (expn + d_expn / of_int (exps fsort) 2) + (expn - d_expn / of_int (exps fsort) 2) >>>= fun expn -> + pack_raw fsort sign expn coef >>>= fun r -> + ite is_odd (idiv fsort rne r sqrt2) r) + ~else_:( (* for subnormals *) + unsigned (exps fsort) clz >>>= fun clz -> + lsb clz >>>= fun is_odd -> + ite is_odd (pred clz) clz >>>= fun clz -> + clz / of_int (exps fsort) 2 >>>= fun clz -> + expn - d_expn / of_int (exps fsort) 2 - clz >>>= fun expn -> + pack_raw fsort sign expn coef >>>= fun r -> + ite is_odd (idiv fsort rne r sqrt2) r) + + let horner fsort x cfs = + let open B in + let ( + ) = add_or_sub_finite b0 fsort rne in + let ( * ) = fmul_finite fsort rne in + let rec sum y = function + | [] -> y + | c :: cs -> + x * y >>>= fun y -> + y + c >>>= fun y -> + sum y cs in + match cfs with + | [] -> assert false + | [c] -> c + | c::cs -> sum c cs + + let normalize_subnormal fsort x f = + unpack fsort x @@ fun sign expn coef -> + clz coef >>>= fun clz -> + B.(coef lsl clz) >>>= fun coef -> + pack fsort sign expn coef >>>= fun r -> f r clz + + (* x + 1 for [0.0;1.0] and degree = 23 *) + let coefs s = + List.map ~f:(fun x -> B.int s @@ Bitvec.of_string x) [ + "0x3fea283d1af5d3b8b0e9889579874140"; + "0xbfede2119c815a1d59de530775f68488"; + "0x3ff0766a0fd57550fbcaf5949c387bfb"; + "0xbff272ebd07e2237538440a72429c8d9"; + "0x3ff408b110927387f462771163341fec"; + "0xbff5242b7488a8af2abfbadce8148c9e"; + "0x3ff606758477172056f0e80bc6f46e09"; + "0xbff69084710b6e59fb3f53199edfb3dd"; + "0x3ff70e24e277a45deb05c37590719576"; + "0xbff74ec376fd27c8ac279836961f6025"; + "0x3ff78aabb74359ac28498e73a2e9989f"; + "0xbff7c74981d8ac6e123f062aee10b24c"; + "0x3ff805e10b3ea64c491b21abdb7f9796"; + "0xbff82fbf8642c69d9c901bcfcd18b8e1"; + "0x3ff8657a1d1b03f0f63a64018888f16c"; + "0xbff8acff29ed5e34f19ffa86dc9e5b9a"; + "0x3ff907fff45a4ee1cba42e6e42229c85"; + 
"0xbff94fffff0ee79b4a473fa1503182d5"; + "0x3ff9bffffff219963def0395eb8c25e4"; + "0xbffa3fffffffbb055e0eddae1c1de75f"; + "0x3ffafffffffffe5a8a7bc5d57eea900e"; + "0xbffbfffffffffffd5373ee69a938caf9"; + "0x3ffdffffffffffffff22a39a2a24d3f4"; + "0x3fff0000000000000000000000000000"; + ] + + let fsqrt fsort rm x = + let fsort' = double_precision fsort in + B.ite (is_zero x) (fzero fsort B.b0) + (normalize_subnormal fsort x @@ fun x clz -> + range_reduction fsort x @@ fun y d_expn increase -> + fone fsort B.b0 >>>= fun one -> + fsub fsort rm y one >>>= fun y -> + extend fsort y fsort' >>>= fun y -> + horner fsort' y (coefs (bits fsort')) >>>= fun y -> + B.unsigned (exps fsort') d_expn >>>= fun d_expn -> + range_reconstruction fsort' y d_expn clz increase >>>= fun r -> + truncate fsort' r rm fsort) + + let test fsort x = + let fsort' = double_precision fsort in + normalize_subnormal fsort x @@ fun x clz -> + range_reduction fsort x @@ fun y d_expn increase -> + fone fsort B.b0 >>>= fun one -> + fsub fsort B.rne y one >>>= fun y -> + extend fsort y fsort' >>>= fun y -> + horner fsort' y (coefs (bits fsort')) >>>= fun y -> + B.unsigned (exps fsort') d_expn >>>= fun d_expn -> + range_reconstruction fsort' y d_expn clz increase >>>= fun r -> + truncate fsort' r B.rne fsort +end diff --git a/plugins/bil/bil_float.mli b/plugins/bil/bil_float.mli new file mode 100644 index 000000000..e48c456f1 --- /dev/null +++ b/plugins/bil/bil_float.mli @@ -0,0 +1,20 @@ +open Bap_knowledge +open Bap_core_theory + +open Theory + +type ('b,'e,'t,'s) fsort = (('b,'e,'t) IEEE754.t,'s) format Float.t Value.sort + +module Make(B : Theory.Core) : sig + + val fadd : ('b,'e,'t,'s) fsort -> rmode -> 's bitv -> 's bitv -> 's bitv + val fsub : ('b,'e,'t,'s) fsort -> rmode -> 's bitv -> 's bitv -> 's bitv + val fmul : ('b,'e,'t,'s) fsort -> rmode -> 's bitv -> 's bitv -> 's bitv + val fdiv : ('b,'e,'t,'s) fsort -> rmode -> 's bitv -> 's bitv -> 's bitv + val fsqrt : ('b,'e,'t,'s) fsort -> rmode -> 's bitv -> 
's bitv + + val cast_int : ('a, 'b, 'c, 'd) fsort -> 'e Bitv.t Value.sort -> 'd bitv -> 'e bitv + val cast_float : ('a, 'b, 'c, 'd) fsort -> rmode -> 'e bitv -> 'd bitv + val cast_float_signed : ('a, 'b, 'c, 'd) fsort -> rmode -> 'e bitv -> 'd bitv + val convert : ('b, 'e, 't, 's) fsort -> 's bitv -> rmode -> ('b, 'a, 'c, 'd) fsort -> 'd bitv +end diff --git a/plugins/bil/bil_float_tests.ml b/plugins/bil/bil_float_tests.ml new file mode 100644 index 000000000..5e2b4e60e --- /dev/null +++ b/plugins/bil/bil_float_tests.ml @@ -0,0 +1,480 @@ +open Core_kernel +open OUnit2 +open Bap_plugins.Std +open Bap_primus.Std +open Bap.Std +open Monads.Std +open Bap_knowledge +open Bap_core_theory + +module G = Bil_float.Make(Theory.Manager) + +[@@@warning "-3"] + +let () = Plugins.run ~exclude:["bil"] () + +let () = Bil_semantics.init () + +let enum_bits w = + let bits = Word.(enum_bits w BigEndian) in + let b_len = Seq.length bits in + let w_len = Word.bitwidth w in + if b_len > w_len then + Seq.drop bits (b_len - w_len) + else bits + +let float_bits w = + let bits = enum_bits w in + let (@@) = sprintf "%s%d" in + Seq.fold bits ~init:"" ~f:(fun s x -> + if x then s @@ 1 + else s @@ 0) + +let float64_bits x = + let w = Word.of_int64 (Int64.bits_of_float x) in + let bits = enum_bits w in + let (@@) = sprintf "%s%d" in + Seq.foldi bits ~init:"" ~f:(fun i acc x -> + let a = + if i = 1 || i = 12 then "_" + else "" in + let s = sprintf "%s%s" acc a in + if x then s @@ 1 + else s @@ 0) + +let deconstruct x = + let wi = Word.to_int_exn in + let y = Int64.bits_of_float x in + let w = Word.of_int64 y in + let expn = Word.extract_exn ~hi:62 ~lo:52 w in + let bias = Word.of_int ~width:11 1023 in + let expn' = Word.(signed (expn - bias)) in + let frac = Word.extract_exn ~hi:51 w in + printf "ocaml %f: bits %s, 0x%LX\n" x (float64_bits x) y; + printf "ocaml %f: biased/unbiased expn %d/%d, coef 0x%x\n" + x (wi expn) (wi expn') (wi frac) + +type bits8 +type bits24 +type bits32 +type float32 = 
((int,bits8,bits24) IEEE754.ieee754,bits32) format float sort + +type bits11 +type bits53 +type bits64 +type float64 = ((int,bits11,bits53) IEEE754.ieee754,bits64) format float sort + +type bits15 +type bits112 +type bits128 +type float128 = ((int,bits15,bits112) IEEE754.ieee754,bits128) format float sort + + +let exps_32 : bits8 bitv sort = Bits.define 8 +let sigs_32 : bits24 bitv sort = Bits.define 24 +let bitv_32 : bits32 bitv sort = Bits.define 32 +let fsort32 : float32 = IEEE754.(Sort.define binary32) + +let exps_64 : bits11 bitv sort = Bits.define 11 +let sigs_64 : bits53 bitv sort = Bits.define 53 +let bitv_64 : bits64 bitv sort = Bits.define 64 +let fsort64 : float64 = IEEE754.(Sort.define binary64) + +let exps_128 : bits15 bitv sort = Bits.define 15 +let sigs_128 : bits112 bitv sort = Bits.define 112 +let bitv_128 : bits128 bitv sort = Bits.define 128 +let fsort128 : float128 = IEEE754.(Sort.define binary128) + +type binop = [ + | `Add + | `Sub + | `Mul + | `Div +] [@@deriving sexp] + +type unop = [ + | `Sqrt + | `Rsqrt + ] [@@deriving sexp] + +type cast_int = [ + | `Of_uint + | `Of_sint +] [@@deriving sexp] + +type cast_float = [ + | `Of_float +] [@@deriving sexp] + +type test = [ + binop | cast_int | cast_float +] [@@deriving sexp] + +let test_name op = Sexp.to_string (sexp_of_test (op :> test)) + + +module Machine = struct + type 'a m = 'a + include Primus.Machine.Make(Monad.Ident) +end +module Main = Primus.Machine.Main(Machine) +module Eval = Primus.Interpreter.Make(Machine) + +let proj = + let nil = Memmap.empty in + Project.Input.create `x86_64 "/bin/true" ~code:nil ~data:nil |> + Project.create |> + ok_exn + +let word_of_float x = Word.of_int64 (Int64.bits_of_float x) +let float_of_word x = Int64.float_of_bits (Word.to_int64_exn x) + +let exp x = + let open Knowledge.Syntax in + let x = x >>| Value.semantics in + match Knowledge.run x Knowledge.empty with + | Error _ -> assert false + | Ok (s,_) -> Semantics.get Bil.Domain.exp s + +let eval 
?(name="") ~expected test _ctxt = + (* Extracts the BIL expression denoted by [test], evaluates it with the Primus interpreter and compares the resulting word with [expected]. On a mismatch both values are printed and the OUnit assertion labelled [name] fails. *) + let open Machine.Syntax in + (* shadows the top-level [float64_bits]: the word is first turned back into a float through its int64 bit pattern *) + let float64_bits w = + let x = Word.signed w |> Word.to_int64_exn in + float64_bits (Int64.float_of_bits x) in + match exp test with + | None -> assert false + | Some e -> + let check = + Eval.exp e >>| fun r -> + let r = Primus.Value.to_word r in + let equal = Word.(r = expected) in + (* NOTE(review): the [let ... in] chain after [then] extends over the whole sequence below, so [assert_bool] is only reached when [equal] is false *) + if not equal then + let () = printf "\nFAIL %s\n" name in + let () = printf "expected: %s\n" (float64_bits expected) in + printf "got : %s\n" (float64_bits r); + printf "got %s\n" (Word.to_string r); + assert_bool name equal in + match Main.run proj check with + | Primus.Normal,_ -> () + | _ -> raise (Failure "Something went wrong") + +(* builders of [Theory.Manager.int] constants: from a word of the given sort, from the bit pattern of a float, and from an int64 (the last two use the 64-bit sort) *) +let knowledge_of_word sort w = Theory.Manager.int sort w +let knowledge_of_float x = knowledge_of_word bitv_64 (word_of_float x) +let knowledge_of_int64 x = knowledge_of_word bitv_64 (Word.of_int64 x) + +(* a 64-bit constant holding the integer [x] itself, not a float encoding of it *) +let gfloat_of_int x = + let bits = Word.of_int ~width:64 x in + knowledge_of_word bitv_64 bits + +(* applies the generated binary operation selected by [op] to [x] and [y]; the expected value is the bit pattern of the native OCaml float result *) +let binop op x y ctxt = + let bits = Int64.bits_of_float in + let name = sprintf "%Lx %s %Lx\n" (bits x) (test_name op) (bits y) in + let real, op = match op with + | `Add -> x +. y, G.fadd + | `Sub -> x -. y, G.fsub + | `Mul -> x *. y, G.fmul + | `Div -> x /. 
y, G.fdiv in + let test = op fsort64 G.rne (knowledge_of_float x) (knowledge_of_float y) in + eval ~name ~expected:(word_of_float real) test ctxt + +let cast_int cast x ctxt = + let name = sprintf "%s %d\n" (test_name cast) x in + let expected = word_of_float (float x) in + let op = match cast with + | `Of_uint -> G.cast_float + | `Of_sint -> G.cast_float_signed in + let test = op fsort64 G.rne (gfloat_of_int x) in + eval ~name ~expected test ctxt + +let cast_float x ctxt = + let name = sprintf "%s %g\n" (test_name `Of_float) x in + let expected = Word.of_int ~width:64 (int_of_float x) in + let test = G.cast_int fsort64 bitv_64 (knowledge_of_float x) in + eval ~name ~expected test ctxt + +let sqrt_exp x ctxt = + let name = sprintf "sqrt %g\n" x in + let expected = Float.sqrt x |> word_of_float in + let x = Theory.Manager.var (Var.define bitv_64 "x") in + let test = G.fsqrt fsort64 G.rne x in + eval ~name ~expected test ctxt + +let sqrt_ x ctxt = + let name = sprintf "sqrt %g %Lx %s" x (Int64.bits_of_float x) (float64_bits x) in + let expected = Float.sqrt x |> word_of_float in + let test = G.fsqrt fsort64 G.rne (knowledge_of_float x) in + eval ~name ~expected test ctxt + +let ( + ) = binop `Add +let ( - ) = binop `Sub +let ( * ) = binop `Mul +let ( / ) = binop `Div + +let of_uint = cast_int `Of_uint +let of_sint = cast_int `Of_sint +let to_int = cast_float + +let make_float s e c = + let s = Word.of_int ~width:1 s in + let e = Word.of_int ~width:11 e in + let c = Word.of_int ~width:52 c in + let w = Word.(concat (concat s e) c) in + Word.signed w |> Word.to_int64_exn |> Int64.float_of_bits + +let neg x = ~-. 
x +let nan = Float.nan +let inf = Float.infinity +let ninf = Float.neg_infinity +let smallest_nonzero = make_float 0 0 1 +let some_small = make_float 0 0 2 +let biggest_subnormal = make_float 0 0 0xFFFF_FFFF_FFFF_F +let smallest_normal = Float.(biggest_subnormal + smallest_nonzero) +let biggest_normal = make_float 0 2046 0xFFFF_FFFF_FFFF_F + +let () = Random.self_init () + +let random = Random.int +let random_elt xs = List.nth_exn xs @@ random (List.length xs) + +(* a random integer in [from, to_); [Caml] is opened locally because the arithmetic operators are rebound to test builders earlier in this file *) +let random_int ~from ~to_ = + let open Caml in + let max = to_ - from in + let x = random max in + x + from + +(* a random double: either assembled from a random sign, exponent field and fraction, or a tiny value drawn below the float whose bit pattern is a small integer; all seven candidates are computed before one is picked *) +let random_float () = + let expn () = random_int ~from:0 ~to_:2046 in + let frac () = Random.int 0xFFFFFFFFFFFFF in + let sign () = Random.int 2 in + let make () = + let expn = expn () in + let frac = frac () in + make_float (sign ()) expn frac in + let small () = + let x = Random.int 42 in + let y = Int64.of_int x in + Random.float (Int64.float_of_bits y) in + random_elt [make (); make (); small (); make (); make (); small (); make ()] + +(* [times] test cases named random<i>, each applying an operation picked from [ops] to random operands; square root gets a non-negative operand *) +let random_floats ~times ops = + List.init times ~f:(fun i -> + let f = + match random_elt ops with + | `Sqrt -> + let x = Float.abs @@ random_float () in + fun (ctxt : test_ctxt) -> sqrt_ x ctxt + | `Add | `Sub | `Mul | `Div as op -> + let x = random_float () in + let y = random_float () in + fun ctxt -> binop op x y ctxt in + (sprintf "random%d" i) >:: f) + +let of_bits = Int64.float_of_bits + +(* converts the binary128 term [x] to binary64 and compares the evaluated word with the bits of the float [expected] *) +let convert_128_to_64 x expected _ctxt = + let open Machine.Syntax in + let from = fsort128 and to_ = fsort64 in + let expected = Word.of_int64 @@ Int64.bits_of_float expected in + match exp (G.convert from x G.rne to_) with + | None -> assert false + | Some e -> + let check = + Eval.exp e >>| fun r -> + let r = Primus.Value.to_word r in + let equal = Word.(r = expected) in + (* the label used to be a copy of the 64-to-32 one, which misattributed failures *) + assert_bool "convert_128_to_64" equal in + match Main.run proj check with + | Primus.Normal,_ -> () + | _ -> raise (Failure "Something went wrong") + +let convert_64_to_32 x _ctxt = + let open 
Machine.Syntax in + let from = fsort64 and to_ = fsort32 in + let expected = Word.of_int32 @@ Int32.bits_of_float x in + let x = knowledge_of_float x in + match exp (G.convert from x G.rne to_) with + | None -> assert false + | Some e -> + let check = + Eval.exp e >>| fun r -> + let r = Primus.Value.to_word r in + let equal = Word.(r = expected) in + assert_bool "convert_64_to_32" equal in + match Main.run proj check with + | Primus.Normal,_ -> () + | _ -> raise (Failure "Something went wrong") + +(* the binary128 bit pattern used by the "one128 to one64" test, which expects 1.0 *) +let one_128 = + knowledge_of_word bitv_128 (Word.of_string "0x3fff0000000000000000000000000000:128u") + +(* the whole test tree: hand-written cases per operation followed by randomly generated ones; test labels are kept in sync with the values actually passed *) +let suite () = + + let almost_inf32 = of_bits 0x47EFFFFFeFFFFFFFL in + let shouldbe_inf32 = of_bits 0x47EFFFFFfFFFFFFFL in + + "Gfloat" >::: [ + + (* of uint *) + "of uint 42" >:: of_uint 42; + "of uint 0" >:: of_uint 0; + "of uint 1" >:: of_uint 1; + "of uint 2" >:: of_uint 2; + "of uint 10" >:: of_uint 10; + "of uint 13213" >:: of_uint 13213; + "of uint 45667" >:: of_uint 45667; + "of uint 98236723" >:: of_uint 98236723; + "of uint 0xFFFF_FFFF_FFFF_FFF" >:: of_uint 0xFFFF_FFFF_FFFF_FFF; + + (* of sint *) + "of sint -42" >:: of_sint (-42); + "of sint 0" >:: of_sint 0; + "of sint -1" >:: of_sint (-1); + "of sint -2" >:: of_sint (-2); + "of sint -10" >:: of_sint (-10); + "of sint -13213" >:: of_sint (-13213); + "of sint -45667" >:: of_sint (-45667); + "of sint -98236723" >:: of_sint (-98236723); + + (* to int *) + "to int 42.42" >:: to_int 42.42; + "to int 0.42" >:: to_int 0.42; + "to int 0.99999999999" >:: to_int 0.99999999999; + "to int 13123120.98882344542" >:: to_int 13123120.98882344542; + "to int -42.42" >:: to_int (-42.42); + "to int -13123120.98882344542" >:: to_int (-13123120.98882344542); + + (* convert float to float *) + "convert almost inf32" >:: convert_64_to_32 almost_inf32; + "should be inf32" >:: convert_64_to_32 shouldbe_inf32; + "one128 to one64" >:: convert_128_to_64 one_128 1.0; + + (* add *) + "0.0 + 0.5" >:: 0.0 + 0.5; + "4.2 + 2.3" >:: 4.2 + 2.3; + 
"4.2 + 2.98" >:: 4.2 + 2.98; + "2.2 + 4.28" >:: 2.2 + 4.28; + "2.2 + 2.46" >:: 2.2 + 2.46; + "2.2 + -4.28" >:: 2.2 + (neg 4.28); + "-2.2 + 4.28" >:: (neg 2.2) + 4.28; + "0.0000001 + 0.00000002" >:: 0.0000001 + 0.00000002; + "123213123.23434 + 56757.05656549151" >:: 123213123.23434 + 56757.05656549151; + "nan + nan" >:: nan + nan; + "inf + inf" >:: inf + inf; + "-inf + -inf" >:: ninf + ninf; + "nan + -inf" >:: nan + ninf; + "-inf + nan" >:: ninf + nan; + "nan + inf" >:: nan + inf; + "inf + nan" >:: inf + nan; + "-inf + inf" >:: ninf + inf; + "inf + -inf" >:: inf + ninf; + "0.0 + small" >:: 0.0 + smallest_nonzero; + "small + small" >:: smallest_nonzero + some_small; + "biggest_sub + small" >:: biggest_subnormal + smallest_nonzero; + "biggest_normal + small" >:: biggest_normal + smallest_nonzero; + "biggest_normal + biggest_subnorm" >:: biggest_normal + biggest_subnormal; + "near inf case" >:: make_float 0 2046 0xFFFF_FFFF_FFFF_FFF + make_float 0 2046 1; + + (* sub *) + "4.2 - 2.28" >:: 4.2 - 2.28; + "4.28 - 2.2" >:: 4.28 - 2.2; + "2.2 - 4.28" >:: 2.2 - 4.28; + "2.2 - 2.6" >:: 2.2 - 2.6; + "0.0 - 0.0" >:: 0.0 - 0.0; + "4.2 - 4.2" >:: 4.2 - 4.2; + "2.2 - -4.28" >:: 2.2 - (neg 4.28); + "-2.2 - 2.46" >:: (neg 2.2) - 2.46; + "-2.2 - -2.46" >:: (neg 2.2) - (neg 2.46); + "2.0 - 2.0" >:: 2.0 - 2.0; + "-2.0 + 2.0" >:: (neg 2.0) + 2.0; + "0.0000001 - 0.00000002" >:: 0.0000001 - 0.00000002; + "0.0 - 0.00000001" >:: 0.0 - 0.0000001; + "123213123.23434 - 56757.05656549151" >:: 123213123.23434 - 56757.05656549151; + "nan - nan" >:: nan - nan; + "inf - inf" >:: inf - inf; + "-inf - -inf" >:: ninf - ninf; + "nan - -inf" >:: nan - ninf; + "-inf - nan" >:: ninf - nan; + "nan - inf" >:: nan - inf; + "inf - nan" >:: inf - nan; + "-inf - inf" >:: ninf - inf; + "inf - -inf" >:: inf - ninf; + "0.0 - small" >:: 0.0 - smallest_nonzero; + "small - 0.0" >:: smallest_nonzero - 0.0; + "small - small" >:: smallest_nonzero - smallest_nonzero; + "small - small'" >:: smallest_nonzero - some_small; + 
"small' - small" >:: some_small - smallest_nonzero; + "smalles_norm - small" >:: smallest_normal - smallest_nonzero; + "biggest_sub - small" >:: biggest_subnormal - smallest_nonzero; + "biggest_normal - small" >:: biggest_normal - smallest_nonzero; + "biggest_normal - biggest_subnorm" >:: biggest_normal - biggest_subnormal; + "biggest_subnorm - biggest_normal" >:: biggest_subnormal - biggest_normal; + "near inf case" >:: make_float 1 2046 0xFFFF_FFFF_FFFF_FFF - make_float 0 2046 1; + + (* mul *) + "1.0 * 2.5" >:: 1.0 * 2.5; + "2.5 * 0.5" >:: 2.5 * 0.5; + "4.2 * 3.4" >:: 4.2 * 3.4; + "0.01 * 0.02" >:: 0.01 * 0.02; + "1.0 * 0.5" >:: 1.0 * 0.5; + "1.0 * -0.5" >:: 1.0 * (neg 0.5); + "- 1.0 * -0.5" >:: (neg 1.0) * (neg 0.5); + "123734.86124324198 * 23967986786.4834517" >:: 123734.86124324198 * 23967986786.4834517; + "nan * nan" >:: nan * nan; + "inf * inf" >:: inf * inf; + "-inf * -inf" >:: ninf * ninf; + "nan * -inf" >:: nan * ninf; + "-inf * nan" >:: ninf * nan; + "nan * inf" >:: nan * inf; + "inf * nan" >:: inf * nan; + "-inf * inf" >:: ninf * inf; + "inf * -inf" >:: inf * ninf; + "0.0 * big" >:: 0.0 * biggest_normal; + "0.0 * small" >:: 0.0 * biggest_subnormal; + "0.0 * small'" >:: 0.0 * smallest_nonzero; + "2.0 * small" >:: 2.0 * smallest_nonzero; + "1123131.45355 * small" >:: 1123131.45355 * smallest_nonzero; + "small * small" >:: smallest_nonzero * some_small; + "smallest normal * small" >:: smallest_normal * smallest_nonzero; + "biggest subnormal * small" >:: biggest_subnormal * smallest_nonzero; + "biggest normal * small" >:: biggest_normal * smallest_nonzero; + "biggest normal * 2.0" >:: biggest_normal * 2.0; + "biggest normal * biggest subnormal" >:: biggest_normal * biggest_subnormal; + "biggest subnormal * small" >:: biggest_subnormal * smallest_nonzero; + "biggest subnormal * biggest subnormal" >:: biggest_subnormal * biggest_subnormal; + "biggest normal * biggest normal" >:: biggest_normal * biggest_normal; + "test with underflow" >:: of_bits 
974381688320862858L * of_bits (-5590604654947855237L); + "test1" >:: of_bits 0xec9059c2619517d5L + of_bits 0x6c52387cdb6aefadL; + "test2" >:: of_bits 0xa10d89faaef35527L - of_bits 0xa130e0fee63e0e6fL; + "test3" >:: of_bits 0x400199999999999aL - of_bits 0x4004cccccccccccdL; + "test4" >:: of_bits 0x7fefffffffffffffL - of_bits 0xfffffffffffffL; + + (* div *) + "2.0 / 0.5" >:: 2.0 / 0.5; + "1.0 / 3.0" >:: 1.0 / 3.0; + "3.0 / 32.0" >:: 3.0 / 32.0; + "324.32423 / 1.2" >:: 324.32423 / 1.2; + "2.4 / 3.123131" >:: 2.4 / 3.123131; + "0.1313134 / 0.578465631" >:: 0.1313134 / 0.578465631; + "9991132.2131363434 / 2435.05656549153" >:: 9991132.2131363434 / 2435.05656549153; + "nan / nan" >:: nan / nan; + "inf / inf" >:: inf / inf; + "-inf / -inf" >:: ninf / ninf; + "nan / -inf" >:: nan / ninf; + "-inf / nan" >:: ninf / nan; + "nan / inf" >:: nan / inf; + "inf / nan" >:: inf / nan; + "-inf / inf" >:: ninf / inf; + "inf / -inf" >:: inf / ninf; + "0.0 / small" >:: 0.0 / smallest_nonzero; + "small / small'" >:: smallest_nonzero / some_small; + "small' / small" >:: some_small / smallest_nonzero; + "small / small" >:: smallest_nonzero / smallest_nonzero; + "smallest_norm / small" >:: smallest_normal / smallest_nonzero; + "biggest_sub / small" >:: biggest_subnormal / smallest_nonzero; + "biggest_normal / small" >:: biggest_normal / smallest_nonzero; + "biggest_normal / biggest_subnorm" >:: biggest_normal / biggest_subnormal; + "biggest_normal / smallest_normal" >:: biggest_normal / smallest_normal; + + ] @ random_floats ~times:100 [`Add; `Sub; `Mul; `Div; `Sqrt] + +let () = run_test_tt_main (suite ()) diff --git a/plugins/bil/bil_float_tests.mli b/plugins/bil/bil_float_tests.mli new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/plugins/bil/bil_float_tests.mli @@ -0,0 +1 @@ + diff --git a/plugins/bil/bil_ir.ml b/plugins/bil/bil_ir.ml new file mode 100644 index 000000000..8ca769bf9 --- /dev/null +++ b/plugins/bil/bil_ir.ml @@ -0,0 +1,338 @@ +open Core_kernel +open 
Bap.Std +open Format + +open Bap_knowledge +open Bap_core_theory + +open Knowledge.Syntax + +include Self() + +type blk = { + name : Theory.Label.t; + defs : def term list; + jmps : jmp term list; +} [@@deriving bin_io] + +type cfg = { + blks : blk list; + entry : Theory.Label.t; +} [@@deriving bin_io] + +type t = cfg option + +module BIR = struct + type t = blk term list + + + let add_def = Blk.Builder.add_def + let add_jmp = Blk.Builder.add_jmp + let add_term add (blks,b) t = add b t; (blks,b) + let add_terms terms add (blks,b) = + List.fold ~init:(blks,b) (List.rev terms) + ~f:(add_term add) + + (* creates a tid block from the IR block, + expands non-representable terms into empty blocks. + postconditions: + - the list is not empty; + - the first element of the list, is the entry + *) + let make_blk {name; defs; jmps} = + let b = Blk.Builder.create ~tid:name () in + ([],b) |> + add_terms defs add_def |> + add_terms jmps add_jmp |> fun (blks,b) -> + Blk.Builder.result b :: blks |> + List.rev + + + (* postconditions: + - the list is not empty + - the first block is the entry block + - the last block is the exit block + *) + let reify {entry; blks} = + List.fold blks ~init:(None,[]) ~f:(fun (s,blks) b -> + match make_blk b with + | [] -> assert false + | blk::blks' -> + if Tid.equal entry (Term.tid blk) + then (Some blk, List.rev_append blks' blks) + else (s, List.rev_append (blk::blks') blks)) |> function + | None,[] -> [] + | None,_ -> failwith "No entry in IR builder" + | Some x, xs -> x :: xs +end + +let pp_cfg ppf ir = + fprintf ppf "%a" (pp_print_list Blk.pp) (BIR.reify ir) + +let inspect cfg = Sexp.Atom (asprintf "%a" pp_cfg cfg) + +let null = KB.Symbol.intern "null" Theory.Program.cls +let is_null x = + null >>| fun null -> + Theory.Label.equal null x + +let domain = KB.Domain.optional ~inspect "graph" + ~equal:(fun x y -> Theory.Label.equal x.entry y.entry) + +let graph = + KB.Class.property Theory.Program.Semantics.cls "ir-graph" domain + 
~persistent:(KB.Persistent.of_binable (module struct + type t = cfg option [@@deriving bin_io] + end)) + +let slot = graph + +module IR = struct + include Theory.Core.Empty + let ret = Knowledge.return + + let blk tid = {name=tid; defs=[]; jmps=[]} + + + (* getter/setter pairs for the two term lists of a block *) + let def = (fun x -> x.defs), (fun x d -> {x with defs = d}) + (* the setter leaves the block untouched when the jump at the head of the list (the last one pushed) has no guard *) + let jmp = (fun x -> x.jmps), (fun x d -> match x.jmps with + | t :: _ when Option.is_none (Jmp.guard t) -> x + | _ -> {x with jmps = d}) + + (* conses [elt] onto the list selected by the getter/setter pair *) + let push_to_blk (get,put) blk elt = + put blk @@ elt :: get blk + + + let push fld elt cfg : cfg = match cfg with + | {blks=[]} -> assert false (* the precondition - we have a block *) + | {blks=blk::blks} -> { + cfg with blks = push_to_blk fld blk elt :: blks + } + + let fresh = KB.Object.create Theory.Program.cls + + let (++) b j = push_to_blk jmp b j + + (* the graph stored in the value, or an empty graph whose entry is the [null] label *) + let reify x : cfg knowledge = match KB.Value.get graph x with + | Some g -> KB.return g + | None -> null >>| fun entry -> { + blks = []; + entry; + } + + let empty = Theory.Effect.empty Theory.Effect.Sort.bot + (* from here on [ret] shadows the [Knowledge.return] alias above *) + let ret cfg = !!(KB.Value.put graph empty (Some cfg)) + let data cfg = ret cfg + let ctrl cfg = ret cfg + + let set v x = + x >>= fun x -> + fresh >>= fun entry -> + fresh >>= fun tid -> + data { + entry; + blks = [{name=entry; jmps=[]; defs=[Def.reify ~tid v x]}] + } + + let goto ?cnd ~tid dst = + Jmp.reify ?cnd ~tid ~dst:(Jmp.resolved dst) () + + (** reifies a [while (<cnd>) <body>] loop to + + {v + head: + goto tail + loop: + <body> + tail: + when <cnd> goto loop + v} + + or to just + + {v + head: + when <cnd> goto head + v} + + if <body> is empty. 
+ *) + let repeat cnd body = + cnd >>= fun cnd -> + body >>= reify >>= function + | {blks=[]} -> + fresh >>= fun head -> + fresh >>= fun tid -> + data { + entry = head; + blks = [{ + name = head; + defs = []; + jmps = [goto ~cnd ~tid head]}]} + | {entry=loop; blks=b::blks} -> + fresh >>= fun head -> + fresh >>= fun tail -> + fresh >>= fun jmp1 -> + fresh >>= fun jmp2 -> + fresh >>= fun jmp3 -> + data { + entry = head; + blks = blk tail ++ goto ~tid:jmp1 ~cnd loop :: + blk head ++ goto ~tid:jmp2 tail :: + b ++ goto ~tid:jmp3 tail :: + blks + } + + let branch cnd yes nay = + fresh >>= fun head -> + fresh >>= fun tail -> + cnd >>= fun cnd -> + yes >>= fun yes -> + nay >>= fun nay -> + reify yes >>= fun yes -> + reify nay >>= fun nay -> + let jump = goto ~cnd in + match yes, nay with + | {entry; blks=[{defs=[]; jmps=[j]} as blk]},{blks=[]} -> ret { + entry; + blks = [{blk with defs=[]; jmps=[Jmp.with_guard j (Some cnd)]}] + } + | {entry=lhs; blks=b::blks},{blks=[]} -> + fresh >>= fun jmp1 -> + fresh >>= fun jmp2 -> + fresh >>= fun jmp3 -> + ret { + entry = head; + blks = + blk tail :: + blk head ++ + jump ~tid:jmp1 lhs ++ + goto ~tid:jmp2 tail :: + b ++ goto ~tid:jmp3 tail :: + blks + } + | {blks=[]}, {entry=rhs; blks=b::blks} -> + fresh >>= fun jmp1 -> + fresh >>= fun jmp2 -> + fresh >>= fun jmp3 -> + ret { + entry = head; + blks = + blk tail :: + blk head ++ + jump ~tid:jmp1 tail ++ + goto ~tid:jmp2 rhs :: + b ++ goto ~tid:jmp3 tail :: + blks + } + | {entry=lhs; blks=yes::ayes}, {entry=rhs; blks=nay::nays} -> + fresh >>= fun jmp1 -> + fresh >>= fun jmp2 -> + fresh >>= fun jmp3 -> + fresh >>= fun jmp4 -> + ret { + entry = head; + blks = + blk tail :: + blk head ++ + jump ~tid:jmp1 lhs ++ + goto ~tid:jmp2 rhs :: + yes ++ goto ~tid:jmp3 tail :: + nay ++ goto ~tid:jmp4 tail :: + List.rev_append ayes nays + } + | {blks=[]}, {blks=[]} -> + fresh >>= fun jmp1 -> + fresh >>= fun jmp2 -> + ret { + entry = head; + blks = [ + blk tail; + blk head ++ jump ~tid:jmp1 tail ++ goto 
~tid:jmp2 tail + ] + } + + let jmp dst = + fresh >>= fun entry -> + dst >>= fun dst -> + fresh >>= fun tid -> + ctrl { + entry; + blks = [blk entry ++ Jmp.reify ~tid ~dst:(Jmp.indirect dst) ()] + } + + let appgraphs fst snd = + match fst, snd with + | {entry; blks}, {blks=[]} + | {blks=[]}, {entry; blks} -> ret {entry; blks} + | {entry; blks={jmps=[]} as x :: xs},{blks=[y]} -> ret { + entry; + blks = {x with defs = y.defs @ x.defs; jmps = y.jmps} :: xs + } + | {entry; blks=x::xs}, {entry=snd; blks=y::ys} -> + fresh >>= fun tid -> ret { + entry; + blks = + y :: + x ++ goto ~tid snd :: + List.rev_append xs ys + } + + let (>>->) x f = x >>= reify >>= f + + let seq fst snd = + fst >>-> fun fst -> + snd >>-> fun snd -> + appgraphs fst snd + + let do_goto dst = + fresh >>= fun entry -> + fresh >>= fun tid -> + ctrl { + entry; + blks = [blk entry ++ goto ~tid dst] + } + + + let blk _entry defs jmps = + defs >>-> fun defs -> + jmps >>-> fun jmps -> + appgraphs defs jmps + (* match defs, jmps with + * | {blks=[]}, {blks=[]} -> ret { + * entry; + * blks = [blk entry] + * } + * | {blks=[]}, {entry=next; blks=b::blks} + * | {entry=next; blks=b::blks}, {blks=[]} -> + * fresh >>= fun tid -> + * ret { + * entry; + * blks = b :: blk entry ++ goto ~tid next :: blks + * } + * | {entry=fst; blks=x::xs}, + * {entry=snd; blks=y::ys} -> + * fresh >>= fun jmp1 -> + * fresh >>= fun jmp2 -> + * ret { + * entry; + * blks = + * y :: + * blk entry ++ goto ~tid:jmp1 fst :: + * x ++ goto ~tid:jmp2 snd :: + * List.rev_append xs ys + * } *) + let goto = do_goto +end + +let reify = function + | None -> [] + | Some g -> BIR.reify g + +let init () = + Theory.register + ~desc:"CFG generator" + ~name:"cfg" + (module IR) + +module Theory = IR diff --git a/plugins/bil/bil_ir.mli b/plugins/bil/bil_ir.mli new file mode 100644 index 000000000..600ba984e --- /dev/null +++ b/plugins/bil/bil_ir.mli @@ -0,0 +1,11 @@ +open Bap.Std +open Bap_core_theory + +type t + +val slot : (Theory.Program.Semantics.cls, t) 
KB.slot + +val reify : t -> blk term list +val init : unit -> unit + +module Theory : Theory.Core diff --git a/plugins/bil/bil_lifter.ml b/plugins/bil/bil_lifter.ml new file mode 100644 index 000000000..d61d1fb08 --- /dev/null +++ b/plugins/bil/bil_lifter.ml @@ -0,0 +1,372 @@ +open Core_kernel +open Bap.Std +open Bap_future.Std +open Bap_knowledge +open Bap_core_theory +open Monads.Std + +open Knowledge.Syntax + +open Theory.Parser +include Self() + +module Call = struct + let prefix = "bil-fixup:" + let extern name = + let dst = sprintf "%s%s" prefix name in + Bil.special dst + + let is_extern name = + String.is_prefix name ~prefix + + let dst = + String.chop_prefix_exn ~prefix +end + +module BilParser = struct + type context = [`Bitv | `Bool | `Mem ] [@@deriving sexp] + let fail exp ctx = + error "ill-formed expression in %a ctxt: %a" + Sexp.pp (sexp_of_context ctx) Exp.pp exp + + type exp = Bil.exp + module Var = Bap.Std.Var + let rec uncat acc : exp -> exp list = function + | Concat ((Concat (x,y)), z) -> uncat (y::z::acc) x + | Concat (x,y) -> x::y::acc + | x -> x::acc + + let bits_of_var v = match Var.typ v with + | Imm x -> x + | _ -> failwith "not a bitv var" + + let byte x = Bil.int (Word.of_int ~width:8 x) + let is_big e = + Bil.int @@ + if e = BigEndian then Word.b1 else Word.b0 + + let is_reg v = match Var.typ v with + | Type.Imm 1 | Type.Mem _ -> false + | _ -> true + + let is_bit v = match Var.typ v with + | Type.Imm 1 -> true + | _ -> false + + let is_mem v = match Var.typ v with + | Type.Mem _ -> true + | _ -> false + + let bitv : type t r. 
(t,exp,r) bitv_parser = + fun (module S) -> function + | Cast (HIGH,n,x) -> S.high n x + | Cast (LOW,n,x) -> S.low n x + | Cast (UNSIGNED,n,x) -> S.unsigned n x + | Cast (SIGNED,n,x) -> S.signed n x + | BinOp(PLUS,x,y) -> S.add x y + | BinOp(MINUS,x,y) -> S.sub x y + | BinOp(TIMES,x,y) -> S.mul x y + | BinOp(DIVIDE,x,y) -> S.div x y + | BinOp(SDIVIDE,x,y) -> S.sdiv x y + | BinOp(MOD,x,y) -> S.modulo x y + | BinOp(SMOD,x,y) -> S.smodulo x y + | BinOp(LSHIFT,x,y) -> S.lshift x y + | BinOp(RSHIFT,x,y) -> S.rshift x y + | BinOp(ARSHIFT,x,y) -> S.arshift x y + | BinOp(AND,x,y) -> S.logand x y + | BinOp(OR,x,y) -> S.logor x y + | BinOp(XOR,x,y) -> S.logxor x y + | UnOp(NEG,x) -> S.neg x + | UnOp(NOT,x) -> S.not x + | Load(m,k,e,s) -> + S.load_word (Size.in_bits s) (is_big e) m k + | Var v -> S.var (Var.name v) (bits_of_var v) + | Int x -> S.int (Word.to_bitvec x) (Word.bitwidth x) + | Let (v,y,z) when is_bit v -> S.let_bit (Var.name v) y z + | Let (v,y,z) when is_reg v -> S.let_reg (Var.name v) y z + | Let (v,y,z) when is_mem v -> S.let_mem (Var.name v) y z + | Ite (x,y,z) -> S.ite x y z + | Extract (hi,lo,x) -> + let s = max 0 (hi-lo+1) in + S.extract s (byte hi) (byte lo) x + | Concat (_,_) as cat -> S.concat (uncat [] cat) + | Unknown (_, Imm s) -> S.unknown s + | BinOp ((EQ|NEQ|LT|LE|SLT|SLE), _, _) as op -> + S.ite op (Int Word.b1) (Int Word.b0) + + (* ill-formed expressions *) + | Let _ + | Store (_, _, _, _, _) + | Unknown (_, (Mem _|Unk)) as exp -> fail exp `Bitv; S.error + + + + let mem : type t. 
(t,exp) mem_parser = + fun (module S) -> function + | Unknown (_,Mem (k,v)) -> + S.unknown (Size.in_bits k) (Size.in_bits v) + | Store (m,k,v,e,_) -> + S.store_word (is_big e) m k v + | Var v -> + let with_mem_types v f = match Var.typ v with + | Mem (ks,vs) -> f (Size.in_bits ks) (Size.in_bits vs) + | _ -> fail (Var v) `Mem; S.error in + with_mem_types v (S.var (Var.name v)) + | Let (v,y,z) when is_bit v -> S.let_bit (Var.name v) y z + | Let (v,y,z) when is_reg v -> S.let_reg (Var.name v) y z + | Let (v,y,z) when is_mem v -> S.let_mem (Var.name v) y z + | Ite (c,x,y) -> S.ite c x y + (* the rest is ill-formed *) + | Let _ + | Unknown (_,_) + | Load (_,_,_,_) + | BinOp (_,_,_) + | UnOp (_,_) + | Int _ + | Cast (_,_,_) + | Extract (_,_,_) + | Concat (_,_) as exp -> fail exp `Mem; S.error + + let float _ _ = assert false + let rmode _ _ = assert false + + let bool : type t r. (t,exp,r) bool_parser = + fun (module S) -> function + | Var x -> S.var (Var.name x) + | Int x -> S.int (Word.to_bitvec x) + | Cast (HIGH,1,x) -> S.high x + | Cast (LOW,1,x) -> S.low x + | BinOp (EQ,x,y) -> S.eq x y + | BinOp (NEQ,x,y) -> S.neq x y + | BinOp (LT,x,y) -> S.lt x y + | BinOp (LE,x,y) -> S.le x y + | BinOp (SLT,x,y) -> S.slt x y + | BinOp (SLE,x,y) -> S.sle x y + | BinOp (OR,x,y) -> S.logor x y + | BinOp (AND,x,y) -> S.logand x y + | BinOp (XOR,x,y) -> S.logxor x y + | UnOp (NOT,x) -> S.not x + | Let (v,y,z) when is_bit v -> S.let_bit (Var.name v) y z + | Let (v,y,z) when is_reg v -> S.let_reg (Var.name v) y z + | Let (v,y,z) when is_mem v -> S.let_mem (Var.name v) y z + | Ite (x,y,z) -> S.ite x y z + | Extract (hi,lo,x) when hi = lo -> S.extract hi x + | Unknown (_,_) -> S.unknown () + | Let _ + | Extract _ + | UnOp (NEG,_) + | Cast (_,_,_) + | Load (_,_,_,_) + | Store (_,_,_,_,_) + | Concat (_,_) + | BinOp ((PLUS|MINUS|TIMES|DIVIDE|SDIVIDE| + MOD|SMOD|LSHIFT|RSHIFT|ARSHIFT),_,_) as exp + -> fail exp `Bool; S.error + + + let stmt : type t r. 
(t,exp,r,stmt) stmt_parser = + fun (module S) -> + let set v x = + let n = Var.name v in + match Var.typ v with + | Imm 1 -> S.set_bit n x + | Imm m -> S.set_reg n m x + | Mem (ks,vs) -> + S.set_mem n (Size.in_bits ks) (Size.in_bits vs) x in + function + | Move (v,x) -> set v x + | Jmp (Int x) -> S.goto (Word.to_bitvec x) + | Jmp x -> S.jmp x + | Special s when Call.is_extern s -> + S.call (Call.dst s) + | Special s -> S.special s + | While (c,xs) -> S.while_ c xs + | If (c,xs,ys) -> S.if_ c xs ys + | CpuExn n -> S.cpuexn n + + + let t = {bitv; mem; stmt; bool; float; rmode} +end + +module Lifter = Theory.Parser.Make(Theory.Manager) +module Optimizer = Theory.Parser.Make(Bil_semantics.Core) +[@@inlined] + +let provide_bir () = + Knowledge.promise Theory.Program.Semantics.slot @@ fun obj -> + KB.collect Theory.Program.Semantics.slot obj >>| fun sema -> + let bir = Bil_ir.reify @@ KB.Value.get Bil_ir.slot sema in + KB.Value.put Term.slot sema bir + + +module Relocations = struct + + type t = { + rels : addr Addr.Map.t; + exts : string Addr.Map.t; + } + + module Fact = Ogre.Make(Monad.Ident) + + module Request = struct + open Image.Scheme + open Fact.Syntax + + let of_aseq s = + Seq.fold s ~init:Addr.Map.empty ~f:(fun m (key,data) -> + Map.set m ~key ~data) + + let arch_width = + Fact.require arch >>= fun a -> + match Arch.of_string a with + | Some a -> Fact.return (Arch.addr_size a |> Size.in_bits) + | None -> Fact.failf "unknown/unsupported architecture" () + + let relocations = + arch_width >>= fun width -> + Fact.collect Ogre.Query.(select (from relocation)) >>= fun s -> + Fact.return + (of_aseq @@ Seq.map s ~f:(fun (addr, data) -> + Addr.of_int64 ~width addr, Addr.of_int64 ~width data)) + + let external_symbols = + arch_width >>= fun width -> + Fact.collect Ogre.Query.(select (from external_reference)) >>| fun s -> + Seq.fold s ~init:Addr.Map.empty ~f:(fun addrs (addr, name) -> + Map.set addrs + ~key:(Addr.of_int64 ~width addr) + ~data:name) + end + + let 
relocations = Fact.eval Request.relocations + let external_symbols = Fact.eval Request.external_symbols + let empty = {rels = Addr.Map.empty; exts = Addr.Map.empty} + + let of_spec spec = + match relocations spec, external_symbols spec with + | Ok rels, Ok exts -> {rels; exts} + | Error e, _ | _, Error e -> Error.raise e + + let span mem = + let start = Memory.min_addr mem in + let len = Memory.length mem in + Seq.init len ~f:(Addr.nsucc start) + + let find_external {exts} mem = + Seq.find_map ~f:(Map.find exts) (span mem) + + let find_internal {rels} mem = + Seq.find_map ~f:(Map.find rels) (span mem) + + let subscribe () = + let open Future.Syntax in + Stream.hd Project.Info.spec >>| + of_spec + + + let override_internal dst = + Stmt.map (object inherit Stmt.mapper + method! map_jmp _ = [Bil.Jmp (Int dst)] + end) + + let override_external name = + Stmt.map (object inherit Stmt.mapper + method! map_jmp _ = [Call.extern name] + end) + + + let fixup info mem bil = + match Future.peek info with + | None -> bil + | Some info -> + match find_internal info mem with + | Some dst -> + override_internal dst bil + | None -> + match find_external info mem with + | Some name -> + override_external name bil + | None -> bil + +end + +module Brancher = struct + include Theory.Core.Empty + + let pack kind dsts = + KB.Value.put Insn.Slot.dests (Theory.Effect.empty kind) dsts + + let get x = KB.Value.get Insn.Slot.dests x + + let union k e1 e2 = + pack k @@ match get e1, get e2 with + | None,s|s,None -> s + | Some e1, Some e2 -> Some (Set.union e1 e2) + + let ret kind dst = + let dsts = Set.singleton (module Theory.Label) dst in + KB.return @@ pack kind (Some dsts) + + let goto dst = ret Theory.Effect.Sort.jump dst + + let jmp _ = + KB.Object.create Theory.Program.cls >>= fun dst -> + ret Theory.Effect.Sort.jump dst + + let seq x y = + x >>= fun x -> + y >>= fun y -> + let k = Theory.Effect.sort x in + KB.return (union k x y) + + let blk _ data ctrl = + data >>= fun data -> + ctrl 
>>= fun ctrl -> + let k = Theory.Effect.Sort.join + [Theory.Effect.sort data] + [Theory.Effect.sort ctrl] in + KB.return (union k data ctrl) + + let branch _cnd yes nay = + yes >>= fun yes -> + nay >>= fun nay -> + let k = Theory.Effect.sort yes in + KB.return (union k yes nay) +end + +let provide_lifter () = + info "providing a lifter for all BIL lifters"; + let relocations = Relocations.subscribe () in + let unknown = Theory.Program.Semantics.empty in + let (>>?) x f = x >>= function + | None -> KB.return unknown + | Some x -> f x in + let lifter obj = + Knowledge.collect Arch.slot obj >>? fun arch -> + Knowledge.collect Memory.slot obj >>? fun mem -> + Knowledge.collect Disasm_expert.Basic.Insn.slot obj >>? fun insn -> + let module Target = (val target_of_arch arch) in + match Target.lift mem insn with + | Error _ -> + Knowledge.return (Insn.of_basic insn) + | Ok bil -> + Bil_semantics.context >>= fun ctxt -> + Knowledge.provide Bil_semantics.arch ctxt (Some arch) >>= fun () -> + Optimizer.run BilParser.t bil >>= fun sema -> + let bil = Insn.bil sema in + let bil = Relocations.fixup relocations mem bil in + Lifter.run BilParser.t bil >>| fun sema -> + let bil = Insn.bil sema in + KB.Value.merge ~on_conflict:`drop_left + sema (Insn.of_basic ~bil insn) in + Knowledge.promise Theory.Program.Semantics.slot lifter + + +let init () = + Bil_ir.init (); + provide_lifter (); + provide_bir (); + Theory.register + ~desc:"computes destinations" + ~name:"dests " + (module Brancher) diff --git a/plugins/bil/bil_lifter.mli b/plugins/bil/bil_lifter.mli new file mode 100644 index 000000000..28ead8dca --- /dev/null +++ b/plugins/bil/bil_lifter.mli @@ -0,0 +1 @@ +val init : unit -> unit diff --git a/plugins/bil/bil_main.ml b/plugins/bil/bil_main.ml index 50766ab12..9f008b49a 100644 --- a/plugins/bil/bil_main.ml +++ b/plugins/bil/bil_main.ml @@ -1,3 +1,4 @@ +open Bap_core_theory open Core_kernel open Bap.Std @@ -73,23 +74,16 @@ let () = `P "$(b,bap)(3)"; ] in let norml = - let doc 
= "Specifies the BIL normalization level. + let doc = "Selects a BIL normalization level. The normalization process doesn't change the semantics of a BIL program, but applies some transformations to simplify it. - There are two BIL normal forms (bnf): bnf1 and bnf2, both - of which are described in details in $(b,bap)(3). - Briefly, $(b,bnf1) produce the BIL code with load expressions - that applied to a memory only. And $(b,bnf2) also adds some - more restrictions like absence of let-expressions and makes - load/store operations sizes equal to one byte. - So, there are next possible options for normalization level: - $(b,0) - disables normalization; $(b,1) - produce BIL in bnf1; - $(b,2) - produce BIL in bnf2" in - Config.(param (enum normalizations) ~default:[bnf1] ~doc "normalization") in + Consult BAP Annotated Reference (BAR) for the detailed + description of the BIL normalized forms." in + Config.(param (enum normalizations) ~default:[] ~doc "normalization") in let optim = - let doc = "Specifies the optimization level.\n - Level $(b,0) disables optimization, and level $(b,1) performs - regular program simplifications, i.e., applies constant folding, + let doc = "Specifies an optimization level.\n + Level $(b,0) disables all optimizations, and level $(b,1) performs + regular program simplifications, e.g., applies constant folding, propagation, and elimination of dead temporary (aka virtual) variables." in Config.(param (enum optimizations) ~default:o1 ~doc "optimization") in let list_passes = @@ -101,6 +95,15 @@ let () = the lifing to BIL code." 
in Config.(param (list pass) ~default:[] ~doc "passes") in Config.when_ready (fun {Config.get=(!)} -> + if !list_passes then print_passes () - else - Bil.select_passes (!norml @ !optim @ !passes)) + else begin + Bil.select_passes (!norml @ !optim @ !passes); + Bil_lifter.init (); + Bil_ir.init(); + Theory.register + ~desc:"denotes programs in terms of BIL expressions and statements" + ~name:"bil" + (module Bil_semantics.Core_with_fp_emulation) + + end) diff --git a/plugins/bil/bil_semantics.ml b/plugins/bil/bil_semantics.ml new file mode 100644 index 000000000..1b755ea10 --- /dev/null +++ b/plugins/bil/bil_semantics.ml @@ -0,0 +1,637 @@ +open Core_kernel +open Bap.Std +open Bap_knowledge +open Bap_core_theory + +[@@@warning "-40"] + +type context = Context +let package = "bil-plugin-internal" +let cls = KB.Class.declare ~package "context" Context +let context = KB.Symbol.intern ~package "context" cls +let inherits slot = + KB.Class.property cls + (KB.Slot.name slot) + (KB.Slot.domain slot) + +let arch = inherits Arch.slot + + +let exp = Exp.slot +let stmt = Bil.slot + +let values = exp +let effects = stmt +let bool = Theory.Bool.t +let bits = Theory.Bitv.define +let size = Theory.Bitv.size +let sort = Theory.Value.sort +(* we need to recurse intelligently, only if optimization + occured, that might open a new optimization opportunity, + and continue recursion only if we have progress. 
+*) +module Simpl = struct + open Bil.Types + + let is0 = Word.is_zero and is1 = Word.is_one + let ism1 x = Word.is_zero (Word.lnot x) + + let zero width = Bil.Int (Word.zero width) + let ones width = Bil.Int (Word.ones width) + let app2 = Bil.Apply.binop + + let exp width = + let concat x y = match x, y with + | Int x, Int y -> Int (Word.concat x y) + | x,y -> Concat (x,y) + + and cast t s x = match x with + | Cast (_,s',_) as x when s = s' -> x + | Cast (t',s',x) when t = t' && s' > s -> Cast (t,s,x) + | Int w -> Int (Bil.Apply.cast t s w) + | _ -> Cast (t,s,x) + + and extract hi lo x = match x with + | Int w -> Int (Word.extract_exn ~hi ~lo w) + | x -> Extract (hi,lo,x) + + and unop op x = match x with + | Int x -> Int (Bil.Apply.unop op x) + | UnOp(op',x) when op = op' -> x + | x -> UnOp(op, x) + + and binop op x y = + let keep op x y = Bil.BinOp(op,x,y) in + let int f = function Bil.Int x -> f x | _ -> false in + let is0 = int is0 and is1 = int is1 and ism1 = int ism1 in + let (=) x y = compare_exp x y = 0 in + match op, x, y with + | op, Int x, Int y -> Int (app2 op x y) + | PLUS,BinOp(PLUS,x,Int y),Int z + | PLUS,BinOp(PLUS,Int y,x),Int z -> + BinOp(PLUS,x,Int (app2 PLUS y z)) + + | PLUS,x,y when is0 x -> y + | PLUS,x,y when is0 y -> x + | MINUS,x,y when is0 x -> UnOp(NEG,y) + | MINUS,x,y when is0 y -> x + | MINUS,x,y when x = y -> zero width + | MINUS,BinOp(MINUS,x,Int y), Int z -> + BinOp(MINUS,x,Int (app2 PLUS y z)) + | MINUS,BinOp(PLUS,x,Int c1),Int c2 + | MINUS,BinOp(PLUS,Int c1,x),Int c2 -> + BinOp(PLUS,x,Int (app2 MINUS c1 c2)) + + | TIMES,x,_ when is0 x -> x + | TIMES,_,y when is0 y -> y + | TIMES,x,y when is1 x -> y + | TIMES,x,y when is1 y -> x + | (DIVIDE|SDIVIDE),x,y when is1 y -> x + | (MOD|SMOD),_,y when is1 y -> zero width + | (LSHIFT|RSHIFT|ARSHIFT),x,y when is0 y -> x + | (LSHIFT|RSHIFT|ARSHIFT),x,_ when is0 x -> x + | ARSHIFT,x,_ when ism1 x -> x + | AND,x,_ when is0 x -> x + | AND,_,y when is0 y -> y + | AND,x,y when ism1 x -> y + | 
AND,x,y when ism1 y -> x + | AND,x,y when x = y -> x + | OR,x,y when is0 x -> y + | OR,x,y when is0 y -> x + | OR,x,_ when ism1 x -> x + | OR,_,y when ism1 y -> y + | OR,x,y when x = y -> x + | XOR,x,y when x = y -> zero width + | XOR,x,y when is0 x -> y + | XOR,x,y when is0 y -> x + | EQ,x,y when x = y -> Int Word.b1 + | NEQ,x,y when x = y -> Int Word.b0 + | (LT|SLT), x, y when x = y -> Int Word.b0 + | op,x,y -> keep op x y + and ite_ c x y = match c with + | Int c -> if Word.(c = b1) then x else y + | _ -> Ite(c,x,y) in + + let run : exp -> exp = function + | BinOp (op,x,y) -> binop op x y + | UnOp (op,x) -> unop op x + | Cast (t,s,x) -> cast t s x + | Ite (x,y,z) -> ite_ x y z + | Extract (h,l,x) -> extract h l x + | Concat (x,y) -> concat x y + | Let _ + | Var _ | Int _ | Unknown (_,_) + | Load _ | Store _ as x -> x in + run +end + +module Basic : Theory.Basic = struct + open Knowledge.Syntax + + module Base = struct + + let ret = Knowledge.return + + let simpl = Simpl.exp + + + let value x = KB.Value.get exp x + + let v s e = KB.Value.put exp (Theory.Value.empty s) e + + let (%:) e s = v s e + + + let exp s x = ret @@ x %: s + let bit x = ret @@ simpl 1 x %: Theory.Bool.t + let mem s v = ret @@ v %: s + let vec s v = ret @@ simpl (Theory.Bitv.size s) v %: s + + let gen s' v = + let s = Theory.Value.Sort.forget s' in + ret @@ match Theory.Bool.refine s with + | Some _ -> simpl 1 v %: s' + | None -> match Theory.Bitv.refine s with + | None -> v %:s' + | Some b -> simpl (Theory.Bitv.size b) v %: s' + + let unk s' = + let s = Theory.Value.Sort.forget s' in + ret @@ match Theory.Bool.refine s with + | Some _ -> Bil.unknown "bits" bool_t %: s' + | None -> match Theory.Bitv.refine s with + | Some b -> + Bil.unknown "bits" (Type.imm (Theory.Bitv.size b)) %: s' + | None -> Bil.unknown "unk" Type.Unk %: s' + + let empty = Theory.Effect.empty Theory.Effect.Sort.bot + let eff d = ret @@ KB.Value.put stmt empty d + let data s = eff s + let ctrl s = eff s + let bool_exp : _ 
-> Bil.exp = value + let bitv_exp : _ -> Bil.exp = value + let var r = exp (Theory.Var.sort r) (Var (Var.reify r)) + + let b0 = bit Bil.(int Word.b0) + let b1 = bit Bil.(int Word.b1) + + let int s w = vec s @@ Bil.(int @@ Word.create w (size s)) + + let effect x = KB.Value.get effects x + + let (>>->) v f = v >>= fun v -> f (sort v) (value v) + + let lift1 mk s f v = v >>-> fun sort x -> mk (s sort) (f x) + + let lift2 mk s f x y = + x >>-> fun sx x -> + y >>-> fun sy y -> + mk (s sx sy) (f x y) + + let lift3 mk s f x y z = + x >>-> fun sx x -> + y >>-> fun sy y -> + z >>-> fun sz z -> + mk (s sx sy sz) (f x y z) + + type 'a sort = 'a Theory.Value.sort + type bit = Theory.Bool.t + + (* typing rules *) + let t_lo1 : 'a sort -> bit sort = fun _ -> bool + let t_lo2 : 'a sort -> 'a sort -> bit sort = fun _ _ -> bool + let t_uop : 'a sort -> 'a sort = fun x -> x + let t_aop : 'a sort -> 'a sort -> 'a sort = fun x _ -> x + let t_sop : 'a sort -> 'b sort -> 'a sort = fun x _ -> x + + (* operators *) + let lo1 x = lift1 (fun _ x -> bit x) t_lo1 x + let lo2 x y = lift2 (fun _ x -> bit x) t_lo2 x y + let uop x = lift1 vec t_uop x + let aop x y = lift2 vec t_aop x y + let sop x y = lift2 vec t_sop x y + + let or_ x = lo2 Bil.(lor) x + let and_ x = lo2 Bil.(land) x + + let inv x = lo1 Bil.(lnot) x + let msb x = lo1 Bil.(cast high 1) x + let lsb x = lo1 Bil.(cast low 1) x + let neg x = uop Bil.(unop neg) x + let not x = uop Bil.(lnot) x + let add x y = aop Bil.(+) x y + let sub x y = aop Bil.(-) x y + let mul x y = aop Bil.( * ) x y + let div x y = aop Bil.(/) x y + let sdiv x y = aop Bil.(/$) x y + let modulo x y = aop Bil.(mod) x y + let smodulo x y = aop Bil.(%$) x y + let logand x y = aop Bil.(land) x y + let logor x y = aop Bil.(lor) x y + let logxor x y = aop Bil.(lxor) x y + let ule x y = lo2 Bil.(<=) x y + let sle x y = lo2 Bil.(<=$) x y + let eq x y = lo2 Bil.(=) x y + let neq x y = lo2 Bil.(<>) x y + let slt x y = lo2 Bil.(<$) x y + let ult x y = lo2 Bil.(<) x y + + 
let sgt x y = slt y x + let ugt x y = ult y x + let sge x y = sle y x + let uge x y = ule y x + + let small s x = Bil.Int (Word.of_int ~width:(size s) x) + let mk_zero s = Bil.Int (Word.zero (size s)) + + let is_zero x = + x >>= fun x -> + let s = sort x in + bit @@ Bil.(bitv_exp x = mk_zero s) + + let non_zero x = + x >>= fun x -> + let s = sort x in + bit @@ Bil.(bitv_exp x <> mk_zero s) + + let shiftr b x y = + b >>-> fun _s b -> + x >>-> fun xs x -> + y >>-> fun _ y -> + vec xs @@ + if Exp.equal b (Bil.int Word.b0) then Bil.(x lsr y) + else + let ones = Word.ones (size xs) in + let mask = Bil.(lnot (int ones lsr y)) in + Bil.(ite b ((x lsr y) lor mask) (x lsr y)) + + let shiftl b x y = + b >>-> fun _s b -> + x >>-> fun xs x -> + y >>-> fun _ y -> + vec xs @@ + if Exp.equal b (Bil.int Word.b0) then Bil.(x lsl y) + else + let simpl = simpl (size xs) in + let ones = Word.ones (size xs) in + let shifted = simpl Bil.(int ones lsl y) in + let mask = simpl Bil.(lnot shifted) in + let lhs = simpl Bil.(x lsl y) in + let yes = simpl Bil.(lhs lor mask) in + let nay = simpl Bil.(x lsl y) in + Bil.(ite b yes nay) + + let app_bop lift op x y = lift (Bil.binop op) x y + let arshift x y = app_bop sop Bil.arshift x y + let rshift x y = app_bop sop Bil.rshift x y + let lshift x y = app_bop sop Bil.lshift x y + + let ite cnd yes nay = + cnd >>= fun cnd -> + yes >>-> fun s yes -> + nay >>-> fun _ nay -> + gen s (Bil.ite (bool_exp cnd) yes nay) + + let (>>:=) v f = v >>= fun v -> f (effect v) + + let branch cnd yes nay = + cnd >>= fun cnd -> + yes >>= fun yes -> + nay >>:= fun nay -> + eff Bil.[If (bool_exp cnd,effect yes,nay)] + + let make_cast s t x = + x >>-> fun _ x -> vec s Bil.(cast t (size s) x) + + let high s = make_cast s Bil.high + let low s = make_cast s Bil.low + let signed s = make_cast s Bil.signed + let unsigned s = make_cast s Bil.unsigned + + let mask_high res n = + let width = size res in + let n = Word.of_int ~width n in + let w = Word.(lnot (ones width lsr n)) 
in + int res (Word.to_bitvec w) + + let cast res b x = + b >>= fun b -> + x >>= fun x -> + let sort = sort x in + let src = bitv_exp x in + let fill = bool_exp b in + let diff = size res - size sort in + let cast kind = vec res Bil.(Cast (kind,size res,src)) in + match compare diff 0,fill with + | 0,_ -> vec res src + | 1, Bil.Int b -> + if Word.(b = b0) then cast UNSIGNED + else ite (msb !!x) + (cast SIGNED) + (logor (cast UNSIGNED) (mask_high res diff)) + | 1, _ -> + ite !!b + (logor (cast UNSIGNED) (mask_high res diff)) + (cast UNSIGNED) + | _ -> vec res (Cast (LOW,size res,src)) + + let append s ex ey = + ex >>= fun ex -> + ey >>= fun ey -> + let sx = sort ex and sy = sort ey in + let x = bitv_exp ex and y = bitv_exp ey in + match compare (size sx + size sy) (size s) with + | 0 -> vec s (Concat (x,y)) + | 1 -> + let extra = size s - size sx - size sy in + vec s @@ Cast(UNSIGNED,extra,(Concat (x,y))) + | _ -> + if size s < size sx + then vec s (Cast (LOW,size s, x)) + else vec s (Cast (LOW,size s, Concat (x,y))) + + let rec uncat acc : Bil.exp -> Bil.exp list = function + | Concat ((Concat (x,y)), z) -> uncat (y::z::acc) x + | Concat (x,y) -> x::y::acc + | x -> x::acc + + let concat s vs = match vs with + | [] -> unk s + | _ -> + Knowledge.List.all vs >>= fun vs -> + let sz = List.fold ~init:0 vs ~f:(fun sz x -> + sz + size (sort x)) in + let x = List.reduce_exn ~f:(fun x y -> Bil.Concat (x,y)) @@ + List.map vs ~f:bitv_exp in + cast s b0 (vec (bits sz) x) + + let load mem key = + mem >>-> fun sort mem -> + key >>-> fun _ key -> + let vals = Theory.Mem.vals sort in + match Size.of_int_opt (size vals) with + | Some sz -> + exp vals Bil.(load mem key BigEndian sz) + | None -> unk vals + + let store m k d = + m >>-> fun ms m -> + k >>-> fun _ k -> + d >>-> fun ds d -> + match Size.of_int_opt (size ds) with + | Some rs -> + exp ms Bil.(store ~mem:m ~addr:k d BigEndian rs) + | _ -> unk ms + + let perform s = ret (Theory.Effect.empty s) + let pass = data [] + let skip 
= ctrl [] + + let seq x y = + x >>= fun x -> + y >>= fun y -> + eff (effect x @ effect y) + + let blk _ x y = + x >>:= fun x -> + y >>:= fun y -> + eff (x @ y) + + let recursive_simpl = Exp.simpl ~ignore:Bap.Std.Eff.[load;store;read] + + let let_ var rhs body = + let v = Var.reify var in + rhs >>-> fun _ rhs -> + body >>-> fun bs body -> + match rhs with + | (Int _) as rhs -> + exp bs @@ recursive_simpl @@ Let (v,rhs,body) + | _ -> gen bs @@ Let (v,rhs,body) + + let set var rhs = + rhs >>-> fun _ rhs -> + let var = Var.reify var in + data [Bil.Move (var,rhs)] + + let repeat cnd body = + cnd >>= fun cnd -> + body >>:= fun body -> + data [Bil.While (bool_exp cnd, body)] + + let jmp dst = + dst >>= fun dst -> ctrl [Bil.Jmp (bitv_exp dst)] + + let goto lbl = ctrl [ + Bil.special @@ Format.asprintf "(goto %a)" Tid.pp lbl + ] + + + end + + include Theory.Basic.Make(Base) + include Base + + let loadw rs cnd mem key = + match Size.of_int_opt (size rs) with + | None -> loadw rs cnd mem key + | Some sz -> + cnd >>= fun cnd -> + key >>= fun key -> + mem >>-> fun _ mem -> + let dir = bool_exp cnd and key = bitv_exp key in + let bel = vec rs @@ Load (mem,key,BigEndian,sz) + and lel = vec rs @@ Load (mem,key,LittleEndian,sz) in + match dir with + | Int dir -> if Word.(dir = b1) then bel else lel + | _ -> ite !!cnd bel lel + + let storew cnd mem key elt = + elt >>-> fun es e -> + match Size.of_int_opt (size es) with + | None -> storew cnd mem key elt + | Some sz -> + cnd >>| value >>= fun dir -> + key >>| value >>= fun key -> + mem >>-> fun sort mem -> + let bes = exp sort @@ Store (mem,key,e,BigEndian,sz) + and les = exp sort @@ Store (mem,key,e,LittleEndian,sz) in + match dir with + | Int dir -> if Word.(dir = b1) then bes else les + | _ -> ite (bit dir) bes les + + let extract s hi lo x = + hi >>= fun hi -> + lo >>= fun lo -> + x >>= fun e -> + match value hi,value lo with + | Int h, Int l -> + let h = Word.to_int_exn h + and l = Word.to_int_exn l in + if h - l + 1 = size s + 
then vec s @@ Bil.(extract ~hi:h ~lo:l (value e)) + else extract s !!hi !!lo !!e + | _ -> extract s !!hi !!lo !!e + + let arch lbl = + KB.collect Arch.slot lbl >>= function + | Some _ as r -> !!r + | None -> context >>= KB.collect arch + + let goto lbl = + KB.collect Theory.Label.addr lbl >>= fun dst -> + arch lbl >>= fun arch -> + match dst, arch with + | Some addr, Some arch -> + let size = Size.in_bits (Arch.addr_size arch) in + let dst = Word.create addr size in + ctrl Bil.[Jmp (Int dst)] + | _ -> KB.collect Theory.Label.ivec lbl >>= function + | Some ivec -> ctrl Bil.[CpuExn ivec] + | None -> KB.collect Theory.Label.name lbl >>= fun name -> + let dst = match name with + | Some name -> sprintf "(call %s)" name + | None -> (Format.asprintf "(goto %a)" Tid.pp lbl) in + ctrl Bil.[Special dst] +end + + +module Core : Theory.Core = struct + include Theory.Core.Empty + include Basic +end + +module FBil = Bil_float.Make(Core) + +module FPEmulator = struct + open Knowledge.Syntax + type 'a t = 'a knowledge + + let supported = Theory.IEEE754.[ + binary16; + binary32; + binary64; + binary80; + binary128; + ] + + let ieee754_of_sort s = + List.find supported ~f:(fun p -> + Theory.Value.Sort.same s (Theory.IEEE754.Sort.define p)) + + let resort s x = KB.Value.refine x s + + let fbits x = + x >>| fun x -> resort (Theory.Float.size (sort x)) x + + let float s x = + x >>| fun x -> resort s x + + + let fop : type f. 
+ (_ -> _ -> _ Theory.bitv -> _ Theory.bitv -> _ Theory.bitv) -> + _ -> f Theory.float -> f Theory.float -> f Theory.float = + fun op rm x y -> + x >>= fun x -> + y >>= fun y -> + let xs = sort x in + match ieee754_of_sort xs with + | None -> Core.unk xs + | Some ({Theory.IEEE754.k} as p) -> + let bs = bits k in + let x = resort bs x and y = resort bs y in + let s = Theory.IEEE754.Sort.define p in + float xs (op s rm !!x !!y) + + let fadd rm = fop FBil.fadd rm + let fsub rm = fop FBil.fsub rm + let fmul rm = fop FBil.fmul rm + let fdiv rm = fop FBil.fdiv rm + + let fuop : type f. + _ -> + _ -> f Theory.float -> f Theory.float = + fun op rm x -> + x >>= fun x -> + let xs = sort x in + match ieee754_of_sort xs with + | None -> Core.unk xs + | Some ({Theory.IEEE754.k} as p) -> + let bs = bits k in + let x = resort bs x in + let s = Theory.IEEE754.Sort.define p in + float xs (op s rm !!x) + + let fsqrt rm x = fuop FBil.fsqrt rm x + + open Core + + let small s x = + let m = Bitvec.modulus (size s) in + int s Bitvec.(int x mod m) + + let classify {Theory.IEEE754.w; t} v ~fin ~inf ~nan = + let ws = bits w and fs = bits t in + let expn = extract ws (small ws (t+w-1)) (small ws t) v in + let frac = extract fs (small fs (t-1)) (small fs 0) v in + let ones = small ws ~-1 in + let zero = small fs 0 in + let is_fin = inv (eq expn ones) in + let is_sub = eq expn (small ws 0) in + let is_pos = msb v in + ite is_fin + (fin ~is_sub) + (ite (eq frac zero) (inf ~is_pos) + (nan ~is_tss:(msb frac))) + + + let tmp x f = + x >>= fun x -> + Theory.Var.scoped (sort x) @@ fun v -> + let_ v !!x (f (var v)) + + + let make_cast_float cast s m v = + match ieee754_of_sort s with + | None -> Core.unk s + | Some p -> + cast (Theory.IEEE754.Sort.define p) m v >>| resort s + + let cast_float s m v = make_cast_float FBil.cast_float s m v + let cast_sfloat s m v = make_cast_float FBil.cast_float_signed s m v + + let forder x y = + x >>= fun x -> + y >>= fun y -> + let xs = sort x in + match 
ieee754_of_sort xs with + | None -> Core.unk bool + | Some ({Theory.IEEE754.k; w; t}) -> + let bs = bits k and ms = bits (k-1)in + let x = resort bs x and y = resort bs y in + let ws = bits w and fs = bits t in + let ones = small ws ~-1 in + let zero = small fs 0 in + let expn v = extract ws (small ws (t+w-1)) (small ws t) v in + let frac v = extract fs (small fs (t-1)) (small fs 0) v in + let magn v = extract ms (small ms (k-2)) (small ms 0) v in + let not_nan v = or_ (neq (expn v) ones) (neq (frac v) zero) in + tmp (magn !!x) @@ fun mx -> + tmp (magn !!y) @@ fun my -> + tmp (msb !!x) @@ fun x_is_neg -> + tmp (msb !!y) @@ fun y_is_neg -> + let x_is_pos = inv x_is_neg and y_is_pos = inv y_is_neg in + List.reduce_exn ~f:and_ [ + not_nan !!x; + not_nan !!y; + inv (and_ (is_zero mx) (is_zero my)); + inv (and_ x_is_pos y_is_neg); + or_ + (and_ x_is_neg y_is_pos) + (ite x_is_neg (ult my mx) (ult mx my)) + ] +end + +module Core_with_fp_emulation = struct + include Core + include FPEmulator +end diff --git a/plugins/bil/bil_semantics.mli b/plugins/bil/bil_semantics.mli new file mode 100644 index 000000000..46b36d422 --- /dev/null +++ b/plugins/bil/bil_semantics.mli @@ -0,0 +1,10 @@ +open Bap.Std +open Bap_core_theory + +type context +val context : context KB.obj KB.t +val arch : (context, arch option) KB.slot + + +module Core : Theory.Core +module Core_with_fp_emulation : Theory.Core diff --git a/plugins/byteweight/byteweight_main.ml b/plugins/byteweight/byteweight_main.ml index 4036f0d70..9654f7c18 100644 --- a/plugins/byteweight/byteweight_main.ml +++ b/plugins/byteweight/byteweight_main.ml @@ -11,22 +11,24 @@ module Sigs = Bap_byteweight_signatures let create_finder path length threshold arch comp = match Sigs.load ?comp ?path ~mode:"bytes" arch with | Error `No_signatures -> - info "signature database is not available"; + info "the signature database is not available"; info "advice - use `bap-byteweight` to install signatures"; - Or_error.errorf "no signatures" + info 
"advice - alternatively, use `opam install bap-signatures'"; + Or_error.errorf "signatures are unavailable" | Error (`Corrupted err) -> let path = Option.value path ~default:Sigs.default_path in - error "signature database is corrupted: %s" err; + error "the signature database is corrupted: %s" err; info "advice - delete signatures at `%s'" path; info "advice - use `bap-byteweight` to install signatures"; - Or_error.errorf "corrupted database" - | Error (`No_entry err) -> - error "no signatures for specified compiler and architecture"; - info "advice - try to use default compiler entry"; - info "advice - create new entries with `bap-byteweight' tool"; - Or_error.errorf "no entry" + info "advice - alternatively, use `opam install bap-signatures'"; + Or_error.errorf "signatures are corrupted" + | Error (`No_entry _) -> + error "no signatures for the specified compiler and/or architecture"; + info "advice - try to use the default compiler entry"; + info "advice - create new entries using the `bap-byteweight' tool"; + Or_error.errorf "compiler is not supported by signatures" | Error (`Sys_error err) -> - error "signature loading was prevented by a system error: %s" err; + error "failed to load the signatures because of a system error: %s" err; Or_error.errorf "system error" | Ok data -> let bw = Binable.of_string (module BW) (Bytes.to_string data) in @@ -43,22 +45,25 @@ let main path length threshold comp = Set.union roots @@ Addr.Set.of_list (finder mem)) in let find_roots arch mem = match finder arch with | Error _ as err -> - warning "unable to provide rooter service"; + warning "will not provide roots"; err | Ok finder -> match find finder mem with | roots when Set.is_empty roots -> - info "no roots was found"; - info "advice - check your compiler's signatures"; + info "no roots were found"; + info "advice - check your signatures"; Ok (Rooter.create Seq.empty) | roots -> Ok (roots |> Set.to_sequence |> Rooter.create) in - let rooter = - let open Project.Info in - 
Stream.Variadic.(apply (args arch $ code) ~f:find_roots) in if sigs_exists path then - Rooter.Factory.register name rooter - else - let () = warning "signature database is not available" in - info "advice - use `bap-byteweight` to install signatures" + let inputs = Stream.zip Project.Info.arch Project.Info.code in + Stream.observe inputs (fun (arch,mem) -> + match find_roots arch mem with + | Ok roots -> Rooter.provide roots + | Error _ -> ()) + else begin + warning "the signature database is not available"; + info "advice - use `bap-byteweight` to install signatures"; + info "advice - alternatively, use `opam install bap-signatures'"; + end let () = @@ -67,11 +72,11 @@ let () = `P - "This plugin provides a rooter (function start identification) - service using the BYTEWEIGHT algorithm described in [1]. The - plugin operates on a byte level. The $(b,SEE ALSO) section - contains links for other plugins, that provides rooters"; - + "This plugin identifies function starts, partially \ + implementing on the BYTEWEIGHT algorithm described in \ + [1]. Only the byte level matching is implemented. The $(b,SEE \ + ALSO) section contains links for other plugins, that provides \ + rooters"; `P "[1]: Bao, Tiffany, et al. \"Byteweight: Learning to recognize functions in binary code.\" 23rd USENIX Security Symposium (USENIX @@ -86,7 +91,7 @@ let () = let doc = "Minimum score for the function start" in Config.(param float ~default:0.9 "threshold" ~doc) in let sigsfile : string option Config.param = - let doc = "Path to the signature file. No needed by default, \ + let doc = "Path to the signature file. Not needed by default, \ usually it is enough to run `bap-byteweight update'." 
in Config.(param (some non_dir_file) "sigs" ~doc) in let compiler : string option Config.param = diff --git a/plugins/constant_tracker/.merlin b/plugins/constant_tracker/.merlin index e92307058..26d8b94f2 100644 --- a/plugins/constant_tracker/.merlin +++ b/plugins/constant_tracker/.merlin @@ -1,3 +1,4 @@ PKG bap PKG bap-primus FLG -short-paths,-open Bap.Std,-open Bap_primus.Std +REC \ No newline at end of file diff --git a/plugins/ida/bap_ida_info.ml b/plugins/ida/bap_ida_info.ml index 422643348..9203de374 100644 --- a/plugins/ida/bap_ida_info.ml +++ b/plugins/ida/bap_ida_info.ml @@ -8,18 +8,18 @@ type new_kind = [ `idat | `idat64 | `ida | `ida64 ] [@@deriving sexp, enumerat type ida_kind = [ old_kind | new_kind ] [@@deriving sexp] type ida = { - headless : ida_kind; - graphical : ida_kind; - headless64 : ida_kind; - graphical64 : ida_kind; - version : version; - } [@@deriving sexp] + headless : ida_kind; + graphical : ida_kind; + headless64 : ida_kind; + graphical64 : ida_kind; + version : version; +} [@@deriving sexp] type t = { - path : string; - ida : ida; - is_headless : bool; - } + path : string; + ida : ida; + is_headless : bool; +} type mode = [ `m32 | `m64 ] @@ -76,12 +76,12 @@ module Check = struct let run {path; ida} = let require_kind = require_kind path in Result.all_unit [ - require_ida path; - require_ida_python path; - require_kind ida.graphical; - require_kind ida.graphical64; - require_kind ida.headless; - require_kind ida.headless64; ] + require_ida path; + require_ida_python path; + require_kind ida.graphical; + require_kind ida.graphical64; + require_kind ida.headless; + require_kind ida.headless64; ] let check_integrity ida = let files = [ida.graphical; ida.headless; @@ -103,8 +103,8 @@ let check ida = match Check.run ida with | Ok () as ok -> ok | Error fail -> - Or_error.errorf "IDA check failed with error code %d" - (code_of_failure fail) + Or_error.errorf "IDA check failed with error code %d" + (code_of_failure fail) let exists_kind path 
kind = Sys.file_exists (path / string_of_kind kind) @@ -117,12 +117,12 @@ let create_ida path = | `m32 -> [`idaq; `ida] | `m64 -> [`idaq64; `ida64] in List.find ~f:(exists_kind path) kinds |> - function - | Some k -> Ok k - | None -> - let kinds = List.map ~f:string_of_kind kinds in - let files = String.concat ~sep:"/" kinds in - Error (File_not_found files) in + function + | Some k -> Ok k + | None -> + let kinds = List.map ~f:string_of_kind kinds in + let files = String.concat ~sep:"/" kinds in + Error (File_not_found files) in let version_of_headless = function | `idal -> Vold | _ -> Vnew in @@ -148,21 +148,23 @@ let create path is_headless = match create' path is_headless with | Ok r -> Ok r | Error fail -> - warning "%s" (string_of_failure fail); - Or_error.errorf "IDA detection failed with error code %d" (code_of_failure fail) + warning "%s" (string_of_failure fail); + Or_error.errorf "IDA detection failed with error code %d: %s" + (code_of_failure fail) + (string_of_failure fail) (* Note, we always launch headless ida in case of IDA Pro 7 *) let ida32 info = match info.ida.version with | Vnew -> info.ida.headless | Vold -> - if info.is_headless then info.ida.headless - else info.ida.graphical + if info.is_headless then info.ida.headless + else info.ida.graphical let ida64 info = match info.ida.version with | Vnew -> info.ida.headless64 | Vold -> - if info.is_headless then info.ida.headless64 - else info.ida.graphical64 + if info.is_headless then info.ida.headless64 + else info.ida.graphical64 let ida_of_suffix info filename = let ext = FilePath.replace_extension in @@ -176,14 +178,14 @@ let ida_of_mode info = function | `m64 -> ida64 info let find_ida info mode target = - let kind = match mode with - | Some mode -> ida_of_mode info mode - | None -> - match ida_of_suffix info target with - | Some ida -> ida - | None -> ida64 info in - let s = Sexp.to_string (sexp_of_ida_kind kind) in - Filename.concat info.path s + let kind = match mode with + | Some mode -> 
ida_of_mode info mode + | None -> + match ida_of_suffix info target with + | Some ida -> ida + | None -> ida64 info in + let s = Sexp.to_string (sexp_of_ida_kind kind) in + Filename.concat info.path s let is_headless t = t.is_headless let path t = t.path diff --git a/plugins/ida/ida_main.ml b/plugins/ida/ida_main.ml index 5d600a10f..b071e6a36 100644 --- a/plugins/ida/ida_main.ml +++ b/plugins/ida/ida_main.ml @@ -1,3 +1,4 @@ +open Bap_knowledge open Core_kernel open Regular.Std open Bap_future.Std @@ -19,9 +20,7 @@ module Symbols = Data.Make(struct module type Target = sig type t val of_blocks : (string * addr * addr) seq -> t - module Factory : sig - val register : string -> t source -> unit - end + val provide : Knowledge.agent -> t -> unit end let digest = Caml.Digest.file @@ -58,13 +57,16 @@ let extract path arch = List.map syms ~f:(fun (n,s,e) -> n, addr s, addr e) |> Seq.of_list +let ida_symbolizer = + let reliability = Knowledge.Agent.reliable in + Knowledge.Agent.register ~reliability + ~package:"bap.std" "ida-symbolizer" + ~desc:"Provides information from IDA Pro" + let register_source (module T : Target) = - let source = - let open Project.Info in - let extract file arch = Or_error.try_with ~backtrace:true (fun () -> - extract file arch |> T.of_blocks) in - Stream.merge file arch ~f:extract in - T.Factory.register name source + let inputs = Stream.zip Project.Info.file Project.Info.arch in + Stream.observe inputs @@ fun (file,arch) -> + T.provide ida_symbolizer (T.of_blocks (extract file arch)) type perm = [`code | `data] [@@deriving sexp] @@ -217,16 +219,16 @@ let get_resolve_fun file arch = (IdaBrancher.resolve brancher) let register_brancher_source () = - let source = - let create_brancher file arch = Or_error.try_with (fun () -> - Brancher.create (get_resolve_fun file arch)) in - Project.Info.(Stream.merge file arch ~f:create_brancher) in - Brancher.Factory.register name source + let inputs = + Stream.zip Project.Info.file Project.Info.arch in + 
Stream.observe inputs @@ fun (file,arch) -> + Brancher.provide @@ Brancher.create (get_resolve_fun file arch) let main () = - register_source (module Rooter); + register_source (module struct include Rooter + let provide _ data = provide data + end); register_source (module Symbolizer); - register_source (module Reconstructor); register_brancher_source (); Project.Input.register_loader name loader @@ -303,5 +305,5 @@ module Cmdline = struct match Info.create ida_path is_headless with | Ok info -> Bap_ida_service.register info !mode; main () | Error e -> - error "%S. Service not registered." (Error.to_string_hum e)) + error "%S. Service not registered." (Error.to_string_hum e)) end diff --git a/plugins/mips/mips.ml b/plugins/mips/mips.ml index 08b39d4d0..e0903c5f6 100644 --- a/plugins/mips/mips.ml +++ b/plugins/mips/mips.ml @@ -1,5 +1,6 @@ open Core_kernel open Bap.Std +open Bap_core_theory (* This CPU model and instruction set is based on the * "MIPS Architecture For Programmers @@ -36,8 +37,12 @@ module Std = struct let lifters = String.Table.create () - let register name lifter = - String.Table.change lifters name ~f:(fun _ -> Some lifter) + + let delayed_opcodes = Hashtbl.create (module String) + + let register ?delay name lifter = + Option.iter delay ~f:(fun d -> Hashtbl.add_exn delayed_opcodes name d); + Hashtbl.add_exn lifters name lifter let (>>) = register @@ -82,5 +87,18 @@ module Std = struct end include Model - end + +let () = + let provide_delay obj = + let open KB.Syntax in + KB.collect Arch.slot obj >>= function + | Some #Arch.mips -> + KB.collect Theory.Program.Semantics.slot obj >>| fun insn -> + let name = KB.Value.get Insn.Slot.name insn in + Hashtbl.find_and_call Std.delayed_opcodes name + ~if_found:(fun delay -> + KB.Value.put Insn.Slot.delay insn (Some delay)) + ~if_not_found:(fun _ -> insn) + | _ -> KB.return Insn.empty in + KB.promise Theory.Program.Semantics.slot provide_delay diff --git a/plugins/mips/mips_abi.ml b/plugins/mips/mips_abi.ml 
index 80a5ee39e..c35bd9680 100644 --- a/plugins/mips/mips_abi.ml +++ b/plugins/mips/mips_abi.ml @@ -97,7 +97,6 @@ let strip_leading_dot s = let demangle demangle prog = Term.map sub_t prog ~f:(fun sub -> let name = demangle (Sub.name sub) in - Tid.set_name (Term.tid sub) name; Sub.with_name sub name) let set_abi proj m = diff --git a/plugins/mips/mips_branch.ml b/plugins/mips/mips_branch.ml index d292911d7..a9264737d 100644 --- a/plugins/mips/mips_branch.ml +++ b/plugins/mips/mips_branch.ml @@ -1,3 +1,4 @@ +open Bap.Std open Mips.Std (* BAL rs, offset @@ -371,13 +372,14 @@ let bne cpu ops = (* NAL * No-op and Link, MIPS32 Release 6, deprecated * Page 358 *) -let nal cpu ops = +let nal cpu _ = let step = unsigned const byte 8 in RTL.[ cpu.gpr 31 := cpu.cia + step; ] let () = + let (>>) = register ~delay:1 in "BAL" >> bal; "BEQ" >> beq; "BEQL" >> beql; @@ -403,7 +405,6 @@ let () = "BGTZC" >> bgtzc; "BEQZC" >> beqzc; "BNEZC" >> bnezc; - "BGEZC" >> bgezc; "BGTZ" >> bgtz; "BGTZL" >> bgtzl; "J" >> jump; diff --git a/plugins/mips/mips_conditional.ml b/plugins/mips/mips_conditional.ml index 506171ab4..b15566375 100644 --- a/plugins/mips/mips_conditional.ml +++ b/plugins/mips/mips_conditional.ml @@ -91,10 +91,9 @@ let selnez cpu ops = ] let () = - "SLT" >> slt; + "SLT" >> slt; "SLTi" >> slti; "SLTu" >> sltu; "SLTiu" >> sltiu; "SELEQZ" >> seleqz; "SELNEQ" >> selnez; - diff --git a/plugins/mips/mips_cpu.ml b/plugins/mips/mips_cpu.ml index 542a7b2c0..c8b77170a 100644 --- a/plugins/mips/mips_cpu.ml +++ b/plugins/mips/mips_cpu.ml @@ -52,9 +52,15 @@ let make_cpu addr_size endian memory = mips_fail "%s with number %d not found" name n in let gpr n = find "GPR" gpri n in let fpr n = find "FPR" fpri n in - let word_width, word_bitwidth = match addr_size with - | `r32 -> unsigned const byte 32, word - | `r64 -> unsigned const byte 64, doubleword in - { load; store; jmp; cia; word_width; word_bitwidth; + let word_width, delay, word_bitwidth = match addr_size with + | `r32 -> + 
unsigned const byte 32, + unsigned const byte 4, + word + | `r64 -> + unsigned const byte 64, + unsigned const byte 4, + doubleword in + { load; store; jmp; cia; word_width; delay; word_bitwidth; reg; gpr; fpr; hi; lo; } diff --git a/plugins/mips/mips_rtl.ml b/plugins/mips/mips_rtl.ml index e39665a47..5823cfc17 100644 --- a/plugins/mips/mips_rtl.ml +++ b/plugins/mips/mips_rtl.ml @@ -52,6 +52,7 @@ let rec bil_exp = function | Concat (x, y) -> Bil.(bil_exp x ^ bil_exp y) | Binop (op, x, y) -> Bil.binop op (bil_exp x) (bil_exp y) | Extract (hi, lo, x) -> Bil.extract hi lo (bil_exp x) + | Cast (_,1,x) -> Bil.(cast low 1 (bil_exp x)) | Cast (Signed, width, x) -> Bil.(cast signed width (bil_exp x)) | Cast (Unsigned, width, x) -> Bil.(cast unsigned width (bil_exp x)) | Unop (op, x) -> Bil.unop op (bil_exp x) @@ -72,15 +73,16 @@ let var_of_exp e = match e.body with module Exp = struct let cast x width sign = - let nothing_to_cast = - (x.sign = sign && x.width = width) || - (x.width = width && width = 1) in - if nothing_to_cast then x - else - if x.width = 1 then - {width; sign; body = Cast (x.sign, width, x.body)} - else - {width; sign; body = Cast (sign, width, x.body)} + let same_sign = x.sign = sign + and same_size = x.width = width in + match same_sign, same_size with + | true,true -> x (* nothing is changed *) + | false,true -> {x with sign} (* size is preserved - no BIL cast *) + | true,false (* size is changed, *) + | false,false -> (* possibly with sign *) + if x.width = 1 + then {width; sign; body = Cast (x.sign, width, x.body)} + else {width; sign; body = Cast (sign, width, x.body)} let cast_width x width = cast x width x.sign @@ -107,12 +109,14 @@ module Exp = struct let sign = derive_sign lhs.sign rhs.sign in binop_with_signedness sign op lhs rhs + let logop_with_cast op lhs rhs = + {(binop_with_cast op lhs rhs) with width = 1} + let concat lhs rhs = let width = lhs.width + rhs.width in let body = Concat (lhs.body, rhs.body) in { sign = Unsigned; width; body; } 
- let bit_result x = cast_width x 1 let plus = binop_with_cast Bil.plus let minus = binop_with_cast Bil.minus @@ -121,16 +125,16 @@ module Exp = struct let sdivide = binop_with_cast Bil.sdivide let modulo = binop_with_cast Bil.modulo let smodulo = binop_with_cast Bil.smodulo - let lt x y = bit_result (binop_with_cast Bil.lt x y) - let gt x y = bit_result (binop_with_cast Bil.lt y x) - let eq x y = bit_result (binop_with_cast Bil.eq x y) - let le x y = bit_result (binop_with_cast Bil.le x y) - let ge x y = bit_result (binop_with_cast Bil.le y x) - let neq x y = bit_result (binop_with_cast Bil.neq x y) - let slt x y = bit_result (binop_with_cast Bil.slt x y) - let sgt x y = bit_result (binop_with_cast Bil.slt y x) - let slte x y = bit_result (binop_with_cast Bil.sle x y) - let sgte x y = bit_result (binop_with_cast Bil.sle y x) + let lt x y = (logop_with_cast Bil.lt x y) + let gt x y = (logop_with_cast Bil.lt y x) + let eq x y = (logop_with_cast Bil.eq x y) + let le x y = (logop_with_cast Bil.le x y) + let ge x y = (logop_with_cast Bil.le y x) + let neq x y = (logop_with_cast Bil.neq x y) + let slt x y = (logop_with_cast Bil.slt x y) + let sgt x y = (logop_with_cast Bil.slt y x) + let slte x y = (logop_with_cast Bil.sle x y) + let sgte x y = (logop_with_cast Bil.sle y x) let lshift = unsigned_binop Bil.lshift let rshift = unsigned_binop Bil.rshift diff --git a/plugins/mips/mips_types.ml b/plugins/mips/mips_types.ml index 7cf6d3853..f9f166d23 100644 --- a/plugins/mips/mips_types.ml +++ b/plugins/mips/mips_types.ml @@ -12,6 +12,7 @@ type cpu = { jmp : exp -> rtl; cia : exp; word_width : exp; + delay : exp; word_bitwidth : bitwidth; reg : (op -> exp) ec; gpr : int -> exp; @@ -19,4 +20,3 @@ type cpu = { hi : exp; lo : exp; } - diff --git a/plugins/objdump/objdump_main.ml b/plugins/objdump/objdump_main.ml index d232cd0bd..ee0e975f3 100644 --- a/plugins/objdump/objdump_main.ml +++ b/plugins/objdump/objdump_main.ml @@ -1,98 +1,104 @@ +open Bap_core_theory open Core_kernel 
open Bap_future.Std open Bap.Std -open Regular.Std -open Format -open Option.Monad_infix open Objdump_config include Self() -let objdump_opts = "-rd --no-show-raw-insn" +open KB.Syntax -let objdump_cmds = +let default_objdump_opts = "-rd --no-show-raw-insn" + +let objdump_cmds demangler= objdump :: List.map targets ~f:(fun p -> p^"-objdump") |> String.Set.stable_dedup_list |> - List.map ~f:(fun cmd -> cmd ^ " " ^ objdump_opts) - - -(* expected format: [num] <[name]>: - Note the use of "^\s" to stop greedy globing of the re "+" - If you are not getting what you think you should, - this regular expression is a good place to start with debugging. -*) -let func_start_re = "([0-9A-Fa-f^\\s]+) <(.*)>:" - -let re r = - Re_pcre.re r |> Re.compile |> Re.execp -[@@warning "-D"] - -let objdump_strip = - String.strip ~drop:(function '<' | '>' | ':' | ' ' -> true | _ -> false) - -let text_to_addr l = - objdump_strip l |> (^) "0x" |> Int64.of_string + List.map ~f:(fun cmd -> + sprintf "%s %s %s" cmd default_objdump_opts @@ + match demangler with + | Some "disabled" -> "" + | None -> "-C" + | Some other -> "--demangle="^other) -let is_section_start s = - String.is_substring s ~substring:"Disassembly of section" +(* func_start ::= + | addr,space, "<", name, ">", ":" + | addr,space, "<", name, "@plt", ">", ":" *) +let parse_func_start = + let parse = + let func_start_re = {|([0-9A-Fa-f]+?) 
<(.*?)(@plt)?>:|} in + Re.Pcre.re func_start_re |> Re.compile |> Re.exec in + let parse_addr input ~start ~stop = + Z.of_substring_base 16 input ~pos:start ~len:(stop - start) in + fun input ~accept -> try + let groups = parse input in + let addr = parse_addr input + ~start:(Re.Group.start groups 1) + ~stop:(Re.Group.stop groups 1) + and name = Re.Group.get groups 2 in + info "%s => %s" (Z.format "%x" addr) name; + accept name addr + with _ -> () -(** "Disassembly of section .fini:" -> ".fini" *) -let section_name s = - match String.split_on_chars ~on:[' '; ':'] s with - | _ :: _ :: _ :: name :: _ -> Some name - | _ -> None - -let parse_func_start section l = - if re func_start_re l then - let xs = String.split_on_chars ~on:[' '; '@'] l in - match xs with - | addr::name::[] (* name w/o @plt case *) - | addr::name::_::[] -> (* name@plt case *) - let name = objdump_strip name in - if Some name = section then None - else - Some (name, text_to_addr addr) - | _ -> None - else None - -let popen cmd = +let run cmd ~f : _ Base.Continue_or_stop.t = let env = Unix.environment () in - let ic,oc,ec = Unix.open_process_full cmd env in - let r = In_channel.input_lines ic in - In_channel.iter_lines ec ~f:(fun msg -> debug "%s" msg); - match Unix.close_process_full (ic,oc,ec) with - | Unix.WEXITED 0 -> Some r + let stdin,stdout,stderr = Unix.open_process_full cmd env in + In_channel.iter_lines stdin ~f; + match Unix.close_process_full (stdin,stdout,stderr) with + | Unix.WEXITED 0 -> Stop () | Unix.WEXITED n -> - info "command `%s' terminated abnormally with exit code %d" cmd n; - None + info "`%s' has failed with %d" cmd n; + Continue () | Unix.WSIGNALED _ | Unix.WSTOPPED _ -> (* a signal number is internal to OCaml, so don't print it *) info "command `%s' was terminated by a signal" cmd; - None + Continue () + +let with_objdump_output demangler ~file ~f = + objdump_cmds demangler |> + List.fold_until ~init:() ~f:(fun () objdump -> + let cmd = sprintf "%s %S" objdump file in + run 
cmd ~f) + ~finish:ident + +let agent = + KB.Agent.register ~package:"bap.std" "objdump-symbolizer" + +let provide_roots funcs = + let promise_property slot = + KB.promise slot @@ fun label -> + KB.collect Theory.Label.addr label >>| function + | None -> None + | Some addr -> + let addr = Bitvec.to_bigint addr in + Option.some_if (Hashtbl.mem funcs addr) true in + promise_property Theory.Label.is_valid; + promise_property Theory.Label.is_subroutine -let run_objdump arch file = - let popen = fun cmd -> popen (cmd ^ " " ^ file) in - let names = Addr.Table.create () in - let width = Arch.addr_size arch |> Size.in_bits in - let add (name,addr) = - Hashtbl.set names ~key:(Addr.of_int64 ~width addr) ~data:name in - let () = match List.find_map objdump_cmds ~f:popen with - | None -> () - | Some lines -> - List.fold ~init:None lines ~f:(fun sec line -> - if is_section_start line then section_name line - else - let () = Option.iter (parse_func_start sec line) ~f:add in - sec) |> ignore in - if Hashtbl.length names = 0 +let provide_objdump demangler file = + let funcs = Hashtbl.create (module struct + type t = Z.t + let compare = Z.compare and hash = Z.hash + let sexp_of_t x = Sexp.Atom (Z.to_string x) + end) in + let accept name addr = Hashtbl.set funcs addr name in + with_objdump_output demangler ~file ~f:(parse_func_start ~accept); + if Hashtbl.length funcs = 0 then warning "failed to obtain symbols"; - Ok (Symbolizer.create (Hashtbl.find names)) + let symbolizer = Symbolizer.create @@ fun addr -> + Hashtbl.find funcs @@ + Bitvec.to_bigint (Word.to_bitvec addr) in + Symbolizer.provide agent symbolizer; + provide_roots funcs -let main () = - Stream.merge Project.Info.arch Project.Info.file ~f:run_objdump |> - Symbolizer.Factory.register name +let main demangler = + Stream.observe Project.Info.file @@ + provide_objdump demangler let () = + let demangler = + let doc = "Specify the demangler name. \ + Set to $(i,disabled) to disable demangling." 
in + Config.(param ~doc (some string) "demangler") in Config.manpage [ `S "DESCRIPTION"; `P "This plugin provides a symbolizer based on objdump. \ @@ -100,10 +106,10 @@ let () = is potentially fragile to changes in objdumps output."; `S "EXAMPLES"; `P "To view the symbols after running the plugin:"; - `P "$(b, bap --symbolizer=objdump --dump-symbols) $(i,executable)"; - `P "To use the internal extractor and *not* this plugin:"; - `P "$(b, bap --symbolizer=internal --dump-symbols) $(i,executable)"; + `P "$(b, bap) $(i,executable) --dump-symbols "; + `P "To view symbols without this plugin:"; + `P "$(b, bap) $(i,executable) --no-objdump --dump-symbols"; `S "SEE ALSO"; `P "$(b,bap-plugin-ida)(1)" ]; - Config.when_ready (fun _ -> main ()) + Config.when_ready (fun {get=(!!)} -> main !!demangler) diff --git a/plugins/optimization/optimization_main.ml b/plugins/optimization/optimization_main.ml index 6025b44a8..4cb382936 100644 --- a/plugins/optimization/optimization_main.ml +++ b/plugins/optimization/optimization_main.ml @@ -213,7 +213,8 @@ let () = (Since flags are rarely used non-locally). Finally, on level 3 we extend our analysis to all variables." 
in - Config.(param int ~default:2 ~doc "level") in + Config.(param int ~default:0 ~doc "level") in Config.when_ready (fun {Config.get=(!)} -> - Project.register_pass ~deps:["api"] ~autorun:true (run !level)) + if !level > 0 + then Project.register_pass ~deps:["api"] ~autorun:true (run !level)) diff --git a/plugins/powerpc/powerpc_rtl.ml b/plugins/powerpc/powerpc_rtl.ml index 74991fc0c..a435fba12 100644 --- a/plugins/powerpc/powerpc_rtl.ml +++ b/plugins/powerpc/powerpc_rtl.ml @@ -53,6 +53,7 @@ let rec bil_exp = function | Concat (x, y) -> Bil.(bil_exp x ^ bil_exp y) | Binop (op, x, y) -> Bil.binop op (bil_exp x) (bil_exp y) | Extract (hi, lo, x) -> Bil.extract hi lo (bil_exp x) + | Cast (_, 1, x) -> Bil.(cast low 1 (bil_exp x)) | Cast (Signed, width, x) -> Bil.(cast signed width (bil_exp x)) | Cast (Unsigned, width, x) -> Bil.(cast unsigned width (bil_exp x)) | Unop (op, x) -> Bil.unop op (bil_exp x) @@ -73,15 +74,16 @@ let var_of_exp e = match e.body with module Exp = struct let cast x width sign = - let nothing_to_cast = - (x.sign = sign && x.width = width) || - (x.width = width && width = 1) in - if nothing_to_cast then x - else - if x.width = 1 then - {width; sign; body = Cast (x.sign, width, x.body)} - else - {width; sign; body = Cast (sign, width, x.body)} + let same_sign = x.sign = sign + and same_size = x.width = width in + match same_sign, same_size with + | true,true -> x (* nothing is changed *) + | false,true -> {x with sign} (* size is preserved - no BIL cast *) + | true,false (* size is changed, *) + | false,false -> (* possibly with sign *) + if x.width = 1 + then {width; sign; body = Cast (x.sign, width, x.body)} + else {width; sign; body = Cast (sign, width, x.body)} let cast_width x width = cast x width x.sign @@ -108,13 +110,14 @@ module Exp = struct let sign = derive_sign lhs.sign rhs.sign in binop_with_signedness sign op lhs rhs + let logop_with_cast op lhs rhs = + {(binop_with_cast op lhs rhs) with width = 1} + let concat lhs rhs = let width = 
lhs.width + rhs.width in let body = Concat (lhs.body, rhs.body) in { sign = Unsigned; width; body; } - let bit_result x = cast_width x 1 - let derive_op x y op_u op_s = match derive_sign x.sign y.sign with | Signed -> op_s @@ -124,21 +127,22 @@ module Exp = struct let minus = binop_with_cast Bil.minus let times = binop_with_cast Bil.times + let lt x y = let op = derive_op x y Bil.lt Bil.slt in - bit_result (binop_with_cast op x y) + logop_with_cast op x y let gt x y = let op = derive_op x y Bil.lt Bil.slt in - bit_result (binop_with_cast op y x) + logop_with_cast op y x let le x y = let op = derive_op x y Bil.le Bil.sle in - bit_result (binop_with_cast op x y) + logop_with_cast op x y let ge x y = let op = derive_op x y Bil.le Bil.sle in - bit_result (binop_with_cast op y x) + logop_with_cast op y x let divide x y = let op = derive_op x y Bil.divide Bil.sdivide in @@ -148,8 +152,8 @@ module Exp = struct let op = derive_op x y Bil.modulo Bil.smodulo in binop_with_cast op x y - let eq x y = bit_result (binop_with_cast Bil.eq x y) - let neq x y = bit_result (binop_with_cast Bil.neq x y) + let eq x y = logop_with_cast Bil.eq x y + let neq x y = logop_with_cast Bil.neq x y let lshift = binop_with_cast Bil.lshift let rshift x y = diff --git a/plugins/primus_approximation/.merlin b/plugins/primus_approximation/.merlin new file mode 100644 index 000000000..28a83b472 --- /dev/null +++ b/plugins/primus_approximation/.merlin @@ -0,0 +1,10 @@ +REC +B ../../_build/plugins/bil +B ../../_build/lib/knowledge +B ../../_build/lib/bap_core_theory +FLG -open Bap_core_theory +PKG oUnit + +S . 
+B _build +FLG -short-paths \ No newline at end of file diff --git a/plugins/primus_approximation/Makefile b/plugins/primus_approximation/Makefile new file mode 100644 index 000000000..b925657c0 --- /dev/null +++ b/plugins/primus_approximation/Makefile @@ -0,0 +1,7 @@ + +float: + bapbuild -pkgs monads,bap-primus,bap-core-theory,bap-knowledge approximation.plugin + bapbundle install approximation.plugin +clean: + rm -rf _build + rm -f approximation.plugin diff --git a/plugins/primus_approximation/approximation.ml b/plugins/primus_approximation/approximation.ml new file mode 100644 index 000000000..5d0fd15a4 --- /dev/null +++ b/plugins/primus_approximation/approximation.ml @@ -0,0 +1,260 @@ +open Core_kernel +open Bap_primus.Std +open Bap.Std +open Monads.Std +open Bap_knowledge +open Bap_core_theory +open Theory +open Knowledge.Syntax + +include Self() + +module CT = Theory.Manager + +let word_of_float x = Word.of_int64 (Int64.bits_of_float x) +let floats = Array.map ~f:word_of_float + +(* Remez's coefficients for sin over 0 to pi. 
+ Cos is implemented as sin(x + pi/2) *) +let table = Hashtbl.of_alist_exn (module String) [ + "sin", floats [| -2.1872506537704514e-10; 1.0000000200128274; + -3.0302439009453449e-7; -1.6666487069033565e-1; + -5.4940340462839474e-6; 8.3432342039387503e-3; + -1.1282260464920005e-5; -1.8998652204112551e-4; + -4.1639895115337191e-6; 4.0918788476303817e-6; + -2.6049442541122577e-7; + |]; + ] + +type op = Sin | Cos [@@deriving sexp] +let parse_op : string -> op option = fun s -> + Option.try_with (fun () -> op_of_sexp (Sexp.of_string s)) + +let size_of_var var = + let size = Bap.Std.Var.typ var in + match size with + | Type.Imm size -> size + | _ -> assert false + +let make_float_value fsort x = + let core_theory_i = CT.int (IEEE754.Sort.bits fsort) x in + CT.float fsort core_theory_i + +let with_fresh_var exp body = + exp >>= fun a -> + let sort = Value.sort a in + Var.scoped sort @@ fun v -> + CT.let_ v !!a (body v) + +let (>>>=) = with_fresh_var + +let (>>->) x f = + x >>= fun x -> + f (Value.sort x) x + + +module Reduction_Constant = struct + + let pi_float = 4. *. Float.atan 1. + let half_pi = 2. *. Float.atan 1. + let pi2 = 8. *. Float.atan 1. + let one_over_pi2 = 1. /. 
pi2 + + let pi_mul_2 fsort = + let wf = word_of_float pi2 in + let float_create = make_float_value fsort in + float_create wf + + let one_over_2pi fsort = + let wf = word_of_float one_over_pi2 in + let float_create = make_float_value fsort in + float_create wf + + let pi_div_2 fsort = + let wf = word_of_float half_pi in + let float_create = make_float_value fsort in + float_create wf + + let pi fsort = + let wf = word_of_float pi_float in + let float_create = make_float_value fsort in + float_create wf + + let fone fs = + let bs = Floats.size fs in + let one = Word.one (Bits.size bs) in + CT.(float fs (int bs one)) + + let fzero fs = + let bs = Floats.size fs in + let z = Word.zero (Bits.size bs) in + CT.(float fs (int bs z)) +end + +module Range_Reduction = struct + open CT + + let fast_and_dirty_is_neg x = msb (fbits x) + + let fast_and_dirty_is_fpos x = + and_ (inv (msb (fbits x))) (non_zero (fbits x)) + + let fast_and_dirty_ceil rm x = + x >>-> fun s x -> + let ix = cast_int (Floats.size s) rm !!x in + let a = cast_float s rm ix in + let b = cast_float s rm (succ ix) in + ite (eq (fbits !!x) (fbits a)) !!x b + + let fast_and_dirty_floor rm x = + x >>-> fun s x -> + cast_float s rm @@ + cast_int (Floats.size s) rm !!x + + let floor = fast_and_dirty_floor + let is_fpos = fast_and_dirty_is_fpos + let is_fneg = fast_and_dirty_is_neg + + let fmod r x y = + fmul r x y >>>= fun d -> + floor r (var d) >>>= fun c -> + fmul r y (var c) >>>= fun z -> + fsub r x (var z) + + let to_pos_angle rm x = + x >>-> fun sort x -> + Reduction_Constant.pi_mul_2 sort >>>= fun pi_2 -> + ite (is_fneg !!x) (fsub rm (var pi_2) !!x) !!x + + (* Sine is an odd function. 
*) + let odd_function rm x return = + x >>-> fun sort x -> + fsub rm !!x (Reduction_Constant.pi sort) >>>= fun x_m_pi -> + is_fpos (var x_m_pi) >>>= fun is_pos -> + ite (var is_pos) (var x_m_pi) !!x >>>= fun reduced -> + return (var reduced) (var is_pos) +end + +module Range_Reconstruction = struct + let sign_flip rm sign x = + CT.fmul rm sign x +end + +module Sin = struct + open CT + + let range_reduce rm x return = + x >>-> fun sort x -> + Reduction_Constant.one_over_2pi sort >>>= fun one_over_2pi -> + Range_Reduction.fmod rm !!x (var one_over_2pi) >>>= fun n -> + Range_Reduction.to_pos_angle rm (var n) >>>= fun pn -> + Range_Reduction.odd_function rm (var pn) return + + + let fast_and_dirty_fneg x = + x >>-> fun s x -> + let is = Floats.size s in + let width = Bits.size is in + let bit_p = Word.of_int ~width (width - 1) in + let mask = Word.(one width lsl bit_p) in + float s (logxor (fbits !!x) (int is mask)) + + let fneg = fast_and_dirty_fneg + + let build ?rm:(rm=rne) x c poly_eval = + range_reduce rm x @@ fun n needs_corr -> + poly_eval c n >>>= fun p -> + ite needs_corr (fneg (var p)) (var p) +end + +module Horner + : sig + (* [build op v coef] computes an approximation of [op] using coef with variable v *) + val build : op -> Bap.Std.Var.t -> word array -> exp option + end += struct + open CT + + let exp x = + let open Knowledge.Syntax in + let x = x >>| Value.semantics in + match Knowledge.run x Knowledge.empty with + | Error _ -> assert false + | Ok (s,_) -> Semantics.get Bil.Domain.exp s + + let approximate ~coefs ?rm:(rm=rne) x = + let rank = Array.length coefs - 1 in + let rec sum i y = + if i >= 0 then + fmul rm x y >>>= fun y -> + fadd rm (var y) coefs.(i) >>>= fun y -> + sum (i - 1) (var y) + else y in + sum (rank-1) coefs.(rank) + + let build func var coefs = + let size = size_of_var var in + let fsort = IEEE754.Sort.define (IEEE754.binary size |> Option.value_exn) in + let float_create = make_float_value fsort in + let c = Array.map ~f:float_create 
coefs in + let v = Var.define fsort "v" in + let polynomial_evaluaton = (fun c n -> approximate c n) in + let formula = match func with + | Sin -> Sin.build CT.(var v) c polynomial_evaluaton + | Cos -> + let shifted_var = CT.(fadd rne (var v) (Reduction_Constant.pi_div_2 fsort)) in + Sin.build shifted_var c polynomial_evaluaton in + exp formula +end + + +module Approximate(Machine : Primus.Machine.S) = struct + + module Eval = Primus.Interpreter.Make(Machine) + module Value = Primus.Value.Make(Machine) + + let int_of_value x = + Primus.Value.to_word x |> Word.to_int_exn + + [@@@warning "-P"] + let run [name; size; x;] = + let open Machine.Syntax in + let size = int_of_value size in + Value.Symbol.of_value name >>= fun name -> + let elementry_function = + match parse_op name with + | Some op -> op + | _ -> assert false in + let coefficients = + let name = match elementry_function with + | Sin | Cos -> "sin" in + match Hashtbl.find table name with + | Some coefs -> coefs + | _ -> assert false in + let vx = Bap.Std.Var.create "v" (Type.imm size) in + let exp = Horner.build elementry_function vx coefficients in + match exp with + | None -> assert false + | Some e -> + Eval.set vx x >>= fun () -> + printf "%a\n" Exp.ppo e; + Eval.exp e +end + +module Main(Machine : Primus.Machine.S) = struct + module Lisp = Primus.Lisp.Make(Machine) + open Primus.Lisp.Type.Spec + + let def name types closure = + Lisp.define ~types name closure + + let init () = + Machine.sequence [ + def "approximate" (tuple [sym; int; int] @-> int) (module Approximate); + ] +end + +let main () = + Primus.Machine.add_component (module Main) + +let () = Config.when_ready (fun _ -> main ()) diff --git a/plugins/primus_approximation/approximation.mli b/plugins/primus_approximation/approximation.mli new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/primus_approximation/lisp/math.lisp b/plugins/primus_approximation/lisp/math.lisp new file mode 100644 index 000000000..8484d2ec3 --- 
/dev/null +++ b/plugins/primus_approximation/lisp/math.lisp @@ -0,0 +1,11 @@ +(defun sin (x) + (declare (external "sin")) + (approximate 'sin 64 x)) + +(defun cos (x) + (declare (external "cos")) + (approximate 'cos 64 x)) + +(defun log (x) + (declare (external "log")) + (approximate 'log 64 x)) diff --git a/plugins/primus_approximation/lisp/test.lisp b/plugins/primus_approximation/lisp/test.lisp new file mode 100644 index 000000000..9ac747a45 --- /dev/null +++ b/plugins/primus_approximation/lisp/test.lisp @@ -0,0 +1,11 @@ +(require math) + +(defmethod init () + (msg "hello from math $0" (sin 4611686018427387904:64)) + (msg "hello from math sin(316.159265358979326) = $0" (sin 4644269548807471931:64)) + (msg "hello from math sin(7853981635.97448254) = $0" (sin 4755029503896426363:64)) + (msg "hello from math sin(8.28318530717958623) = $0" (sin 4620852636837615244:64)) +) + + + diff --git a/plugins/primus_test/lisp/check-value.lisp b/plugins/primus_test/lisp/check-value.lisp index 56b3a8c84..a254d7a7b 100644 --- a/plugins/primus_test/lisp/check-value.lisp +++ b/plugins/primus_test/lisp/check-value.lisp @@ -38,6 +38,7 @@ ;; Observed Signals + (defmethod loaded (addr value) (check-value/dereferenced addr value)) diff --git a/plugins/print/print_main.ml b/plugins/print/print_main.ml index 854e9a7c7..82725e8e8 100644 --- a/plugins/print/print_main.ml +++ b/plugins/print/print_main.ml @@ -1,3 +1,4 @@ +open Bap_core_theory open Core_kernel open Regular.Std open Graphlib.Std @@ -79,9 +80,12 @@ let extract_program subs secs proj = should_print subs (Sub.name sub) && should_print secs (sec_name mem bir sub)) -let print_bir subs secs ppf proj = +let print_bir subs secs sema ppf proj = + let pp = match sema with + | None -> Program.pp + | Some cs -> Program.pp_slots cs in Text_tags.with_mode ppf "attr" ~f:(fun () -> - Program.pp ppf (extract_program subs secs proj)) + pp ppf (extract_program subs secs proj)) module Adt = struct let pr ch = Format.fprintf ch @@ -238,13 +242,24 
@@ let pp_addr ppf a = let setup_tabs ppf = pp_print_as ppf 50 ""; - pp_set_tab ppf () [@ocaml.warning "-3"] + pp_set_tab ppf () + +let sorted_blocks nodes = + let init = Set.empty (module Block) in + Seq.fold nodes ~init ~f:Set.add |> + Set.to_sequence + +let sort_fns fns = + let fns = Array.of_list_rev fns in + Array.sort fns ~compare:(fun (_,b1,_) (_,b2,_) -> + Block.compare b1 b2); + Seq.of_array fns let print_disasm pp_insn subs secs ppf proj = let memory = Project.memory proj in let syms = Project.symbols proj in - pp_open_tbox ppf () [@ocaml.warning "-3"]; - setup_tabs ppf [@ocaml.warning "-3"]; + pp_open_tbox ppf (); + setup_tabs ppf; Memmap.filter_map memory ~f:(Value.get Image.section) |> Memmap.to_sequence |> Seq.iter ~f:(fun (mem,sec) -> Symtab.intersecting syms mem |> @@ -254,14 +269,13 @@ let print_disasm pp_insn subs secs ppf proj = | _ when not(should_print secs sec) -> () | fns -> fprintf ppf "@\nDisassembly of section %s@\n" sec; - List.iter fns ~f:(fun (name,entry,cfg) -> + Seq.iter (sort_fns fns) ~f:(fun (name,entry,cfg) -> fprintf ppf "@\n%a: <%s>@\n" pp_addr (Block.addr entry) name; - Graphlib.reverse_postorder_traverse (module Graphs.Cfg) - ~start:entry cfg |> Seq.iter ~f:(fun blk -> - let mem = Block.memory blk in - fprintf ppf "%a:@\n" pp_addr (Memory.min_addr mem); - Block.insns blk |> List.iter ~f:(pp_insn ppf)))); - pp_close_tbox ppf () [@ocaml.warning "-3"] + sorted_blocks (Graphs.Cfg.nodes cfg) |> Seq.iter ~f:(fun blk -> + let mem = Block.memory blk in + fprintf ppf "%a:@\n" pp_addr (Memory.min_addr mem); + Block.insns blk |> List.iter ~f:(pp_insn ppf)))); + pp_close_tbox ppf () let pp_bil fmt ppf (mem,insn) = let pp_bil ppf = Bil.Io.print ~fmt ppf in @@ -275,13 +289,17 @@ let pp_insn fmt ppf (mem,insn) = Insn.Io.print ~fmt ppf insn; fprintf ppf "@\n" -let main attrs ansi_colors demangle symbol_fmts subs secs = +let pp_knowledge ppf _ = + KB.pp_state ppf @@ + Toplevel.current () + +let main attrs ansi_colors demangle symbol_fmts subs 
secs doms = let ver = version in let pp_syms = Data.Write.create ~pp:(print_symbols subs secs demangle symbol_fmts) () in Project.add_writer ~desc:"print symbol table" ~ver "symbols" pp_syms; - let pp_bir = Data.Write.create ~pp:(print_bir subs secs) () in + let pp_bir = Data.Write.create ~pp:(print_bir subs secs doms) () in let pp_adt = Data.Write.create ~pp:Adt.pp_project () in List.iter attrs ~f:Text_tags.Attr.show; @@ -310,6 +328,11 @@ let main attrs ansi_colors demangle symbol_fmts subs secs = Data.Write.create ~pp:(print_disasm (pp_insn "pretty") subs secs) () in let pp_disasm_sexp = Data.Write.create ~pp:(print_disasm (pp_insn "sexp") subs secs) () in + + let pp_knowledge = Data.Write.create ~pp:(pp_knowledge) () in + + Project.add_writer ~ver "knowledge" + ~desc:"dumps the knowledge base" pp_knowledge; Project.add_writer ~ver "cfg" ~desc:"print rich CFG for each procedure" pp_cfg; Project.add_writer ~ver "asm" @@ -383,5 +406,14 @@ let () = let secs : string list Config.param = let doc = "Only display information for section $(docv)" in Config.(param_all string "section" ~docv:"NAME" ~doc) in + let semantics : string list option Config.param = + let doc = + "Display the $(docv) semantics of the program. If used without + an argument then all semantic values associated with terms will + be printed. Otherwise only the selected (if present) will be + printed." 
in + Config.(param (some (list string)) ~as_flag:(Some []) + ~doc ~docv:"SEMANTICS-LIST" "semantics") in Config.when_ready (fun {Config.get=(!)} -> - main !bir_attr !ansi_colors !demangle !print_symbols !subs !secs) + main !bir_attr !ansi_colors !demangle !print_symbols !subs !secs + !semantics) diff --git a/plugins/read_symbols/read_symbols_main.ml b/plugins/read_symbols/read_symbols_main.ml index 629feb220..25c04caff 100644 --- a/plugins/read_symbols/read_symbols_main.ml +++ b/plugins/read_symbols/read_symbols_main.ml @@ -1,3 +1,5 @@ +open Bap_knowledge + open Core_kernel open Regular.Std open Bap_future.Std @@ -17,22 +19,25 @@ let extract name arch = module type Target = sig type t val of_blocks : (string * addr * addr) seq -> t - module Factory : sig - val register : string -> t source -> unit - end + val provide : Knowledge.agent -> t -> unit end +let agent = + let reliability = Knowledge.Agent.authorative in + Knowledge.Agent.register ~package:"bap.std" + ~reliability "user-symbolizer" + let register syms = - let name = "file" in - let register (module T : Target) = - let source = Stream.map Project.Info.arch (fun arch -> - Or_error.try_with (fun () -> - extract syms arch |> - Seq.of_list |> T.of_blocks)) in - T.Factory.register name source in - register (module Rooter); - register (module Symbolizer); - register (module Reconstructor) + let provide (module T : Target) = + Stream.observe Project.Info.arch (fun arch -> + extract syms arch |> + Seq.of_list |> T.of_blocks |> + T.provide agent) in + provide (module struct + include Rooter + let provide _ s = provide s + end); + provide (module Symbolizer) let () = let () = Config.manpage [ diff --git a/plugins/relocatable/rel_symbolizer.ml b/plugins/relocatable/rel_symbolizer.ml index c7235448c..c6ea4e20f 100644 --- a/plugins/relocatable/rel_symbolizer.ml +++ b/plugins/relocatable/rel_symbolizer.ml @@ -1,3 +1,4 @@ +open Bap_knowledge open Core_kernel open Bap.Std open Bap_future.Std @@ -23,35 +24,22 @@ module Rel 
= struct let external_symbols = arch_width >>= fun width -> - Fact.collect Ogre.Query.( - select (from external_reference)) >>= fun s -> + Fact.collect + Ogre.Query.(select (from external_reference)) >>= fun s -> Fact.return (of_aseq @@ Seq.map s ~f:(fun (addr, data) -> Addr.of_int64 ~width addr, data)) end -let find start len exts = - Seq.find_map ~f:(Map.find exts) @@ Seq.init len ~f:(Addr.nsucc start) - -let create cfg exts = - let insns = Disasm.create cfg |> Disasm.insns in - Seq.fold insns ~init:Addr.Map.empty - ~f:(fun calls (m,_) -> - let min = Memory.min_addr m in - let len = Memory.length m in - match find min len exts with - | None -> calls - | Some name -> Map.set calls min name) +let agent = Knowledge.Agent.register + ~package:"bap.std" "relocation-symbolizer" let init () = - let open Project.Info in - Stream.Variadic.(apply (args cfg $ spec) ~f:(fun cfg spec -> - let name = - match Fact.eval Rel.external_symbols spec with - | Ok exts -> Map.find (create cfg exts) - | _ -> fun _ -> None in - Ok (Symbolizer.create name))) |> - Symbolizer.Factory.register "relocatable" + Stream.observe Project.Info.spec @@ fun spec -> + let name = match Fact.eval Rel.external_symbols spec with + | Ok exts -> Map.find exts + | _ -> fun _ -> None in + Symbolizer.provide agent (Symbolizer.create name) let () = Config.manpage [ diff --git a/plugins/x86/x86_abi.ml b/plugins/x86/x86_abi.ml index 1f6baca7f..206a22c70 100644 --- a/plugins/x86/x86_abi.ml +++ b/plugins/x86/x86_abi.ml @@ -1,4 +1,4 @@ -open Core_kernel +open Core_kernel.Std open Bap.Std open Bap_c.Std open Bap_future.Std @@ -226,7 +226,6 @@ let dispatch default sub attrs proto = let demangle demangle prog = Term.map sub_t prog ~f:(fun sub -> let name = demangle (Sub.name sub) in - Tid.set_name (Term.tid sub) name; Sub.with_name sub name) let setup ?(abi=fun _ -> None) () = diff --git a/plugins/x86/x86_lifter.ml b/plugins/x86/x86_lifter.ml index 7a465f0f2..1f4d935b9 100644 --- a/plugins/x86/x86_lifter.ml +++ 
b/plugins/x86/x86_lifter.ml @@ -1037,7 +1037,7 @@ module ToIR = struct cf := cast high 1 dst; ] (* else *) [ if_ (count = one) [ - oF := cast high 1 dst lxor ((cast high 1 dst) lsl int_exp 1 word_size); + oF := cast high 1 dst lxor (cast high 1 (dst lsl int_exp 1 word_size)); ] (* else *) [ oF := unknown "OF undefined after rotate of more then 1 bit" bool_t; ] diff --git a/src/bap_main.ml b/src/bap_main.ml index 87cdcb66d..d33791bf2 100644 --- a/src/bap_main.ml +++ b/src/bap_main.ml @@ -1,3 +1,5 @@ +open Bap_knowledge + open Core_kernel open Bap_plugins.Std open Bap_future.Std @@ -15,34 +17,6 @@ module Recipe = Bap_recipe exception Failed_to_create_project of Error.t [@@deriving sexp] exception Pass_not_found of string [@@deriving sexp] -let find_source (type t) (module F : Source.Factory.S with type t = t) - field o = Option.(field o >>= F.find) - -let brancher = find_source (module Brancher.Factory) brancher -let reconstructor = - find_source (module Reconstructor.Factory) reconstructor - -let merge_streams ss ~f : 'a Source.t = - Stream.concat_merge ss - ~f:(fun s s' -> match s, s' with - | Ok s, Ok s' -> Ok (f s s') - | Ok _, Error er - | Error er, Ok _ -> Error er - | Error er, Error er' -> - Error (Error.of_list [er; er'])) - -let merge_sources create field (o : Bap_options.t) ~f = match field o with - | [] -> None - | names -> match List.filter_map names ~f:create with - | [] -> assert false - | ss -> Some (merge_streams ss ~f) - -let symbolizer = - merge_sources Symbolizer.Factory.find symbolizers ~f:(fun s1 s2 -> - Symbolizer.chain [s1;s2]) - -let rooter = - merge_sources Rooter.Factory.find rooters ~f:Rooter.union let print_formats_and_exit () = Bap_format_printer.run `writers (module Project); @@ -74,8 +48,8 @@ let args filename argv = String.Hash_set.sexp_of_t inputs |> Sexp.to_string_mach -let digest o = - Data.Cache.digest ~namespace:"project" "%s%s" +let digest ~namespace o = + Data.Cache.digest ~namespace "%s%s" (Caml.Digest.(file o.filename |> 
to_hex)) (args o.filename Sys.argv) @@ -105,44 +79,75 @@ let process options project = | `stdout,fmt,ver -> Project.Io.show ~fmt ?ver project) -let extract_format filename = - let fmt = match String.rindex filename '.' with - | None -> filename - | Some n -> String.subo ~pos:(n+1) filename in - match Bap_fmt_spec.parse fmt with - | `Error _ -> None, None - | `Ok (_,fmt,ver) -> Some fmt, ver - -let main o = +let knowledge_cache () = + let reader = Data.Read.create + ~of_bigstring:Knowledge.of_bigstring () in + let writer = Data.Write.create + ~to_bigstring:Knowledge.to_bigstring () in + Data.Cache.Service.request reader writer + +let project_state_cache () = + let module State = struct + type t = Project.state [@@deriving bin_io] + end in + let of_bigstring = Binable.of_bigstring (module State) in + let to_bigstring = Binable.to_bigstring (module State) in + let reader = Data.Read.create ~of_bigstring () in + let writer = Data.Write.create ~to_bigstring () in + Data.Cache.Service.request reader writer + +let import_knowledge_from_cache digest = + let digest = digest ~namespace:"knowledge" in + info "looking for knowledge with digest %a" + Data.Cache.Digest.pp digest; + let cache = knowledge_cache () in + match Data.Cache.load cache digest with + | None -> () + | Some state -> + info "importing knowledge from cache"; + Toplevel.set state + +let load_project_state_from_cache digest = + let digest = digest ~namespace:"project" in + let cache = project_state_cache () in + Data.Cache.load cache digest + +let save_project_state_to_cache digest state = + let digest = digest ~namespace:"project" in + let cache = project_state_cache () in + Data.Cache.save cache digest state + +let store_knowledge_in_cache digest = + let digest = digest ~namespace:"knowledge" in + info "caching knowledge with digest %a" + Data.Cache.Digest.pp digest; + let cache = knowledge_cache () in + Toplevel.current () |> + Data.Cache.save cache digest + + +let main ({filename; loader; disassembler} as 
opts) = + let digest = digest opts in + import_knowledge_from_cache digest; + let state = load_project_state_from_cache digest in let proj_of_input input = - let rooter = rooter o - and brancher = brancher o - and reconstructor = reconstructor o - and symbolizer = symbolizer o in - Project.create input ~disassembler:o.disassembler - ?brancher ?rooter ?symbolizer ?reconstructor |> function + Project.create ?state input ~disassembler |> function | Error err -> raise (Failed_to_create_project err) - | Ok project -> - Project.Cache.save (digest o) project; - project in - let proj_of_file ?ver ?fmt file = - In_channel.with_file file - ~f:(fun ch -> Project.Io.load ?fmt ?ver ch) in - let project = match Project.Cache.load (digest o) with - | Some proj -> - Project.restore_state proj; - proj - | None -> match o.source with - | `Project -> - let fmt,ver = extract_format o.filename in - proj_of_file ?fmt ?ver o.filename - | `Memory arch -> - proj_of_input @@ - Project.Input.binary arch ~filename:o.filename - | `Binary -> - proj_of_input @@ - Project.Input.file ~loader:o.loader ~filename: o.filename in - process o project + | Ok project -> project in + let project = + match opts.source with + | `Project -> failwith "Unsupported feature: project of file" + | `Memory arch -> + proj_of_input @@ + Project.Input.binary arch ~filename + | `Binary -> + proj_of_input @@ + Project.Input.file ~loader ~filename in + if Option.is_none state then begin + store_knowledge_in_cache digest; + save_project_state_to_cache digest (Project.state project); + end; + process opts project let program_info = let doc = "Binary Analysis Platform" in @@ -201,11 +206,12 @@ let program_info = `P "$(b,bap-mc)(1), $(b,bap-byteweight)(1), $(b,bap)(3)" ] in Term.info "bap" ~version:Config.version ~doc ~man + let program _source = let create passopt - _ _ a b c d e f g i j k = (Bap_options.Fields.create - a b c d e f g i j k []), passopt in + _ _ a b c d e f = (Bap_options.Fields.create + a b c d e f []), 
passopt in let open Bap_cmdline_terms in let passopt : string list Term.t = let doc = @@ -223,11 +229,7 @@ let program _source = $(loader ()) $(dump_formats ()) $source_type - $verbose - $(brancher ()) - $(symbolizers ()) - $(rooters ()) - $(reconstructor ())), + $verbose), program_info let parse_source argv = @@ -349,17 +351,35 @@ let nice_pp_error fmt er = let open R in match r with | With_backtrace (r, backtrace) -> - Format.fprintf fmt "%a\n" pp r; - Format.fprintf fmt "Backtrace:\n%s" @@ String.strip backtrace + Format.fprintf fmt "%a\n" pp r; + Format.fprintf fmt "Backtrace:\n%s" @@ String.strip backtrace | String s -> Format.fprintf fmt "%s" s | r -> pp_sexp fmt (R.sexp_of_t r) in Format.fprintf fmt "%a" pp (R.of_info (Error.to_info er)) +let setup_gc () = + let opts = Caml.Gc.get () in + info "Setting GC parameters"; + Caml.Gc.set { + opts with + window_size = 20; + minor_heap_size = 1024 * 1024; + major_heap_increment = 64 * 1024 * 1024; + space_overhead = 200; + } + +let has_env var = match Sys.getenv var with + | exception _ -> false + | _ -> true + let () = let () = try if Sys.getenv "BAP_DEBUG" <> "0" then Printexc.record_backtrace true with Caml.Not_found -> () in + if not (has_env "OCAMLRUNPARAM" || has_env "CAMLRUNPARAM") + then setup_gc () + else info "GC parameters are overriden by a user"; Sys.(set_signal sigint (Signal_handle exit)); let argv = load_recipe () in Log.start ?logdir:(get_logdir argv)(); diff --git a/src/bap_mc.ml b/src/bap_mc.ml index 6f0d127bd..8d1900569 100644 --- a/src/bap_mc.ml +++ b/src/bap_mc.ml @@ -3,6 +3,7 @@ open Format open Bap.Std open Bap_plugins.Std open Mc_options +open Bap_core_theory include Self() exception Bad_user_input @@ -12,12 +13,13 @@ exception Create_mem of Error.t exception No_input exception Unknown_arch exception Trailing_data of int +exception Inconsistency of KB.conflict module Program(Conf : Mc_options.Provider) = struct open Conf module Dis = Disasm_expert.Basic - let bad_insn addr state mem start = 
+ let bad_insn addr state _ start = let stop = Addr.(Dis.addr state - addr |> to_int |> ok_exn) in raise (Bad_insn (Dis.memory state, start, stop)) @@ -49,7 +51,7 @@ module Program(Conf : Mc_options.Provider) = struct | "" | "\n" -> exit 0 | "\\x" -> to_binary input | "0x" -> to_binary ~map:escape_0x input - | x -> to_binary ~map:prepend_slash_x input + | _ -> to_binary ~map:prepend_slash_x input let create_memory arch s addr = let endian = Arch.endian arch in @@ -63,42 +65,56 @@ module Program(Conf : Mc_options.Provider) = struct List.map ~f:sexp_of_kind |> List.iter ~f:(printf "%a@." Sexp.pp) + let new_insn arch mem insn = + let open KB.Syntax in + KB.Object.create Theory.Program.cls >>= fun code -> + KB.provide Arch.slot code (Some arch) >>= fun () -> + KB.provide Memory.slot code (Some mem) >>= fun () -> + KB.provide Dis.Insn.slot code (Some insn) >>| fun () -> + code + + let lift arch mem insn = + match KB.run Theory.Program.cls (new_insn arch mem insn) KB.empty with + | Ok (code,_) -> KB.Value.get Theory.Program.Semantics.slot code + | Error conflict -> raise (Inconsistency conflict) + + let print_insn_size should_print mem = if should_print then let len = Memory.length mem in printf "%#x@\n" len let print_insn insn_formats insn = - let insn = Insn.of_basic insn in List.iter insn_formats ~f:(fun fmt -> Insn.with_printer fmt (fun () -> printf "%a@." Insn.pp insn)) - let bil_of_insn lift mem insn = - match lift mem insn with - | Ok bil -> bil - | Error e -> [Bil.special @@ sprintf "Lifter: %s" @@ - Error.to_string_hum e] - - let print_bil lift mem insn = - let bil = bil_of_insn lift mem in + let print_bil insn = + let bil = Insn.bil insn in List.iter options.bil_formats ~f:(fun fmt -> - printf "%s@." (Bytes.to_string @@ Bil.to_bytes ~fmt (bil insn))) + printf "%s@." 
(Bytes.to_string @@ Bil.to_bytes ~fmt bil)) - let print_bir lift mem insn = - let bil = bil_of_insn lift mem insn in - let bs = Blk.from_insn (Insn.of_basic ~bil insn) in + let print_bir insn = + let bs = Blk.from_insn insn in List.iter options.bir_formats ~f:(fun fmt -> printf "%s" @@ String.concat ~sep:"\n" (List.map bs ~f:(fun b -> Bytes.to_string @@ Blk.to_bytes ~fmt b))) - let print arch mem insn = - let module Target = (val target_of_arch arch) in + let print_sema sema = + Option.iter options.semantics ~f:(function + | [] -> printf "%a@\n" KB.Value.pp sema + | cs -> + let pp = KB.Value.pp_slots cs in + printf "%a@\n" pp sema ) + + let print arch mem code = + let insn = lift arch mem code in print_insn_size options.show_insn_size mem; print_insn options.insn_formats insn; - print_bil Target.lift mem insn; - print_bir Target.lift mem insn; - if options.show_kinds then print_kinds insn + print_bil insn; + print_bir insn; + print_sema insn; + if options.show_kinds then print_kinds code let main () = let arch = match Arch.of_string options.arch with @@ -188,6 +204,13 @@ module Cmdline = struct Arg.(value & opt_all ~vopt:"pretty" string [] & info ["show-bir"] ~doc) + let semantics = + let doc = + "Show instruction semantics. If an option value is specified, + then outputs only the semantics with the given name." in + Arg.(value & opt ~vopt:(Some []) (some (list string)) None & + info ["show-semantics"] ~doc) + let addr = let doc = "Specify an address of first byte" in Arg.(value & opt string "0x0" & info ["addr"] ~doc) @@ -196,8 +219,8 @@ module Cmdline = struct let doc = "Stop after the first instruction is decoded" in Arg.(value & flag & info ["only-one"] ~doc) - let create a b c d e f g h i j = - Mc_options.Fields.create a b c d e f g h i j + let create a b c d e f g h i j k = + Mc_options.Fields.create a b c d e f g h i j k let src = let doc = "String to disassemble. 
If not specified read stdin" in @@ -231,7 +254,7 @@ module Cmdline = struct `S "SEE ALSO"; `P "$(b,bap)(1), $(b,bap-llvm)(1), $(b,llvm-mc)(1)"] in Term.(const create $(disassembler ()) $src $addr $only_one $arch $show_insn_size - $insn_formats $bil_formats $bir_formats $show_kinds), + $insn_formats $semantics $bil_formats $bir_formats $show_kinds), Term.info "bap-mc" ~doc ~man ~version:Config.version let exitf n = @@ -263,6 +286,9 @@ let _main : unit = | Ok () -> exit 0 | Error err -> exitf 64 "%s\n" Error.(to_string_hum err) with + | Inconsistency conflict -> + exitf 67 "Lifters failed with a conflict: %a" + KB.Conflict.pp conflict | Bad_user_input -> exitf 65 "Could not parse: malformed input" | No_input -> exitf 66 "Could not read from stdin" diff --git a/src/bap_options.ml b/src/bap_options.ml index a0715d10e..b425836d0 100644 --- a/src/bap_options.ml +++ b/src/bap_options.ml @@ -11,9 +11,5 @@ type t = { dump : fmt_spec list; source : source; verbose : bool; - brancher : string option; - symbolizers : string list; - rooters : string list; - reconstructor : string option; passes : string list; } [@@deriving sexp, fields] diff --git a/src/mc_options.ml b/src/mc_options.ml index 8c52c9198..8691d57d9 100644 --- a/src/mc_options.ml +++ b/src/mc_options.ml @@ -8,6 +8,7 @@ type t = { arch : string; show_insn_size : bool; insn_formats : string list; + semantics : string list option; bil_formats : string list; bir_formats : string list; show_kinds: bool; diff --git a/testsuite b/testsuite index 16065e0f2..7dfd7d0d4 160000 --- a/testsuite +++ b/testsuite @@ -1 +1 @@ -Subproject commit 16065e0f2dc8367a0ea8ccce9556658e0745cda4 +Subproject commit 7dfd7d0d4684991583fd08a8941fb6fea69f2154