-
Notifications
You must be signed in to change notification settings - Fork 1
/
infer.sml
219 lines (193 loc) · 7.61 KB
/
infer.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
(* An example of type inference for a tiny ML-like language.
* Our language includes
*
* - Functions
* - Booleans
* - If expressions
* - Let abstraction
*
* This means the types are function types [t -> t] and booleans
* [bool] as well as type variables for polymorphic types
*
* To simplify the process of representing variables, we use something
* called DeBruijn indices (pronounced "DeBrown"). With these a
* variable is a number that counts how many binders it is away from
* where it was bound
*
* So if we had
*
*  ___________
* |           |
* fn => fn => 1 + 0
*       |_________|
*
* The 0 refers to the inner binder and the 1 to the binder 1 binder
* away. Here's the wikipedia article:
* http://www.wikiwand.com/en/De_Bruijn_index
*
* For types we'll just use numbers to indicate binding. Since we
* never have nested scopes (all the binders for a type are fully
* shifted to the left) we don't need a clever scheme like DeBruijn
* indices. We don't allow the user to annotate types: we infer
* absolutely everything.
*)
structure TypeInfer :> TYPEINFER =
struct
(* Remove duplicate elements from a list. When an element occurs more
 * than once only its last occurrence is kept, and the surviving
 * elements stay in their original relative order. *)
fun dedup items =
  let
    (* [kept] already holds the deduplicated tail to the right of [x] *)
    fun keepIfNew (x, kept) =
      if List.exists (fn y => x = y) kept
      then kept
      else x :: kept
  in
    List.foldr keepIfNew [] items
  end
(* Type variables are represented as plain integers. Inside this
 * structure they are minted by [fresh], which never hands out the same
 * integer twice, so a number identifies one variable globally. *)
type tvar = int
(* Supply of fresh type variables.
 *
 * [fresh ()] returns the current value of a private counter (starting
 * at 0) and then advances the counter by one, so successive calls
 * yield 0, 1, 2, ... The counter is a mutable reference hidden by
 * [local], so nothing else in the structure can read or reset it. *)
local
  val counter = ref 0
in
  fun fresh () : tvar =
    let
      val current = !counter
    in
      counter := current + 1;
      current
    end
end
(* A normal (not polymorphic) type:
 *   TBool       - the type of booleans
 *   TArr (a, b) - the function type [a -> b]
 *   TVar v      - the type variable numbered [v] *)
datatype monotype = TBool
| TArr of monotype * monotype
| TVar of tvar
(* A polymorphic type: a monotype paired with the list of type
 * variables it binds. [mintNewMonoType] instantiates one by replacing
 * each bound variable with a fresh one. *)
datatype polytype = PolyType of int list * monotype
(* Expressions of our language. Variables are DeBruijn indices, so the
 * binding forms ([Fn] and [Let]) carry no variable name:
 *   True, False  - boolean literals
 *   Var i        - the variable whose context entry is at index [i]
 *   App (f, a)   - application of [f] to [a]
 *   Let (e, b)   - [b] is checked with one extra variable (index 0)
 *                  standing for [e]
 *   Fn b         - a function; index 0 inside [b] is its argument
 *   If (c, t, e) - conditional on [c] with branches [t] and [e] *)
datatype exp = True
| False
| Var of int
| App of exp * exp
| Let of exp * exp
| Fn of exp
| If of exp * exp * exp
(* What we know about one variable in scope: either a polytype or a
 * monotype. In [constrain], [Let] pushes a [PolyTypeVar] for the bound
 * expression and [Fn] pushes a [MonoTypeVar] for the argument. *)
datatype info = PolyTypeVar of polytype
| MonoTypeVar of monotype
(* The typing context: a list of [info] values in which the ith entry
 * (counting from 0) describes the DeBruijn variable i. *)
type context = info list
(* [subst ty' var ty] rebuilds [ty] with every occurrence of the type
 * variable [var] replaced by [ty']. Other variables and [TBool] are
 * left as they are. *)
fun subst ty' var (TVar var') =
      if var = var' then ty' else TVar var'
  | subst ty' var (TArr (dom, cod)) =
      TArr (subst ty' var dom, subst ty' var cod)
  | subst _ _ TBool = TBool
(* List every type variable occurring in a type, left to right. A
 * variable that occurs several times is listed several times. *)
fun freeVars (TVar v) = [v]
  | freeVars (TArr (dom, cod)) = freeVars dom @ freeVars cod
  | freeVars TBool = []
(* Instantiate a polytype: walk its bound variables from first to last
 * and replace each one in the body with a brand new type variable
 * obtained from [fresh]. The result is an ordinary monotype. *)
fun mintNewMonoType (PolyType (bound, body)) =
  let
    fun instantiate ([], ty) = ty
      | instantiate (v :: rest, ty) =
          instantiate (rest, subst (TVar (fresh ())) v ty)
  in
    instantiate (bound, body)
  end
(* Turn a monotype into a polytype by binding every variable of [ty]
 * that does not occur free in the context [ctx]. A context entry's
 * free variables are all variables of a monotype entry, or the
 * non-bound variables of a polytype entry. The bound list is
 * deduplicated with [dedup]. *)
fun generalizeMonoType ctx ty =
  let
    fun absentFrom vars v = not (List.exists (fn w => v = w) vars)

    fun freeInEntry entry =
      case entry of
          MonoTypeVar mono => freeVars mono
        | PolyTypeVar (PolyType (bound, mono)) =>
            List.filter (absentFrom bound) (freeVars mono)

    val ctxFree = List.concat (List.map freeInEntry ctx)
    val generalizable = List.filter (absentFrom ctxFree) (freeVars ty)
  in
    PolyType (dedup generalizable, ty)
  end
(* Raised by [lookupVar] with the offending index when a variable has
 * no entry in the context *)
exception UnboundVar of int

(* Fetch the type of DeBruijn variable [var] from [ctx]. A monotype
 * entry is returned unchanged; a polytype entry is instantiated with
 * fresh variables first. An index with no entry in the list raises
 * [UnboundVar var]. *)
fun lookupVar var ctx =
  let
    val entry =
      List.nth (ctx, var) handle Subscript => raise UnboundVar var
  in
    case entry of
        MonoTypeVar mty => mty
      | PolyTypeVar pty => mintNewMonoType pty
  end
(* A big part of the type checker is the generation of
 * "constraints". These are assertions that some type is equivalent to
 * another. They are represented with just a tuple of types, which is to
 * be read "the left component unifies with the right component".
 *
 * Once we solve these constraints we get a [sol], which is a list
 * pairing type variables with the monotypes they become.
 *
 * [UnificationError] is raised by [unify] and carries the constraint
 * that could not be satisfied.
 *)
type constr = (monotype * monotype)
type sol = (tvar * monotype) list
exception UnificationError of constr
(* Replace the type variable [var] by [ty] on both sides of every
 * constraint in the list *)
fun substConstrs ty var (cs : constr list) : constr list =
  let
    fun rewrite t = subst ty var t
  in
    List.map (fn (lhs, rhs) => (rewrite lhs, rewrite rhs)) cs
  end
(* Apply a solution to a type. The entries are used one after another,
 * starting at the head of the list: each entry's variable is replaced
 * by its monotype in the result of the previous step. *)
fun applySol sol ty =
  case sol of
      [] => ty
    | (v, replacement) :: rest => applySol rest (subst replacement v ty)
(* Apply a solution to every entry of a context. For a polytype entry
 * the solution is applied to the body and the list of bound variables
 * is kept as it is. *)
fun applySolCxt sol cxt =
  let
    fun rewrite ty = applySol sol ty

    fun update (MonoTypeVar m) = MonoTypeVar (rewrite m)
      | update (PolyTypeVar (PolyType (bs, m))) =
          PolyTypeVar (PolyType (bs, rewrite m))
  in
    List.map update cxt
  end
(* Extend a solution with the binding [v := ty]. The existing solution
 * is applied to [ty] first, so the right-hand side of the new entry is
 * already normal with respect to the entries that follow it. *)
fun addSol v ty sol =
  let
    val normalized = applySol sol ty
  in
    (v, normalized) :: sol
  end
(* [occursIn v ty] is true exactly when the type variable [v] appears
 * somewhere inside [ty] *)
fun occursIn v ty =
  case ty of
      TVar v' => v = v'
    | TArr (dom, cod) => occursIn v dom orelse occursIn v cod
    | TBool => false
(* Solve a list of constraints, producing a solution that makes the two
 * sides of every constraint equal.
 *
 * Constraints are processed from the head of the list:
 *   - bool against bool needs nothing;
 *   - two arrows are split into a constraint on the domains and one on
 *     the codomains;
 *   - a variable against a type binds the variable, after substituting
 *     it away in the remaining constraints.
 *
 * Raises [UnificationError c] when constraint [c] relates a variable to
 * a non-variable type that contains it (occurs check), or relates two
 * types with different head constructors. *)
fun unify ([] : constr list) : sol = []
  | unify (c :: rest) =
      let
        (* Record [v := ty] and solve the rest with [v] eliminated *)
        fun bind v ty = addSol v ty (unify (substConstrs ty v rest))

        (* Same as [bind], but refuse to build an infinite type *)
        fun bindChecked v ty =
          if occursIn v ty
          then raise UnificationError c
          else bind v ty
      in
        case c of
            (TBool, TBool) => unify rest
          | (TArr (l, r), TArr (l', r')) =>
              unify ((l, l') :: (r, r') :: rest)
          | (TVar i, TVar j) =>
              if i = j then unify rest else bind i (TVar j)
          | (TVar i, ty) => bindChecked i ty
          | (ty, TVar i) => bindChecked i ty
          | _ => raise UnificationError c
      end
(* Compose two solutions. [sol1 <+> sol2] is the solution that behaves
 * like applying [sol1] first and [sol2] afterwards; every use in this
 * structure passes the earlier solution on the left and the one
 * discovered later on the right.
 *
 * The later solution is pushed through the right-hand sides of the
 * earlier one. Without that step an earlier binding such as [b := a]
 * would survive unchanged next to a later [a := bool], and because
 * [applySol] walks the list from the head, a type mentioning [b] would
 * be rewritten only as far as [a] instead of [bool].
 *
 * This assumes the right-hand sides of [sol2] do not mention variables
 * bound by [sol1], which holds when [sol2] was computed from types and
 * contexts that [sol1] had already been applied to.
 *)
fun <+> (sol1, sol2) =
  let
    fun boundBySol1 v = List.exists (fn (v', _) => v = v') sol1

    (* Earlier bindings, normalized with respect to the later solution *)
    val sol1' = map (fn (v, ty) => (v, applySol sol2 ty)) sol1

    (* Entries of sol2 for variables sol1 already rewrites can never
     * fire once sol1' has been applied, so they are dropped *)
    val sol2' = List.filter (fn (v, _) => not (boundBySol1 v)) sol2
  in
    sol1' @ sol2'
  end
infixr 3 <+>
(* Generate all the constraints and solve them for a given expression.
 *
 * [constrain ctx e] returns a pair [(ty, sol)]: a monotype for [e] and
 * the solution accumulated while checking it. The returned type has not
 * necessarily been rewritten by [sol] (several cases return a type that
 * was computed before the last unification), so callers apply [sol] to
 * it themselves, as [infer] does.
 *
 * Raises [UnboundVar] (through [lookupVar]) for a variable without a
 * context entry and [UnificationError] (through [unify]) when two types
 * cannot be made equal.
 *)
fun constrain ctx True = (TBool, [])
| constrain ctx False = (TBool, [])
(* A variable has the type recorded in the context; lookupVar
 * instantiates it with fresh variables if the entry is a polytype *)
| constrain ctx (Var i) = (lookupVar i ctx, [])
(* A function: check the body with a fresh type variable standing for
 * the argument, then build the arrow from the solved argument type *)
| constrain ctx (Fn body) =
let val argTy = TVar (fresh ())
val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
in (TArr (applySol sol argTy, rTy), sol) end
(* A conditional: each sub-expression is checked in the context
 * rewritten by the solutions found before it *)
| constrain ctx (If (i, t, e)) =
let val (iTy, sol1) = constrain ctx i
val (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
val (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
(* <+> is declared infixr, so this is sol1 <+> (sol2 <+> sol3) *)
val sol = sol1 <+> sol2 <+> sol3
(* The condition must be bool and the two branches must agree *)
val sol = sol <+> unify [ (applySol sol iTy, TBool)
, (applySol sol tTy, applySol sol eTy)]
in
(tTy, sol)
end
(* An application: domTy and ranTy are fresh variables standing for the
 * argument and result types of the function being applied *)
| constrain ctx (App (l, r)) =
let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
val (funTy, sol1) = constrain ctx l
val (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
val sol = sol1 <+> sol2
(* The function must have type domTy -> ranTy and the argument must
 * have type domTy *)
val sol = sol <+> unify [(applySol sol funTy,
applySol sol (TArr (domTy, ranTy)))
, (applySol sol argTy, applySol sol domTy)]
in (ranTy, sol) end
(* A let: the type of the bound expression is generalized over the
 * variables that are not free in the (rewritten) context, and the body
 * is checked with that polytype as the entry for variable 0 *)
| constrain ctx (Let (e, body)) =
let val (eTy, sol1) = constrain ctx e
val ctx' = applySolCxt sol1 ctx
val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
in (rTy, sol1 <+> sol2) end
(* Entry point: infer the type of a closed expression. The expression
 * is checked in the empty context, the resulting solution is applied to
 * its type, and every type variable that remains is bound to give the
 * final polytype. *)
fun infer e =
  case constrain [] e of
      (ty, sol) => generalizeMonoType [] (applySol sol ty)
end