Skip to content

Commit

Permalink
bug fix in structural inverse and test file for it
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Sep 18, 2023
1 parent 039a514 commit ca0c3e5
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 3 deletions.
46 changes: 43 additions & 3 deletions SciLean/Tactic/StructuralInverse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ set_option linter.unusedVariables false

open Lean Meta Qq

initialize registerTraceClass `Meta.Tactic.structuralInverse.step

/--
This comparison is used to select the FVarIdSet with the least number of element but non-empty one! Thus zero size sets are bigger then everything else
-/
Expand All @@ -33,6 +35,30 @@ private structure SystemInverse where
xVals : Array Expr


private def equationsToString (yVals fVals : Array Expr) : MetaM String :=
yVals.zip fVals
|>.joinlM (fun (y,val) => do pure s!"{← ppExpr y} = {← ppExpr val}")
(fun s s' => pure (s ++ "\n" ++ s'))


private def partiallyResolvedSystemToString
(lctx : LocalContext) (xResVars' yVals : Array Expr)
(eqs : Array (Nat × Expr × FVarIdSet)) : MetaM String := do
withLCtx lctx (← getLocalInstances) do
let lets ← xResVars'.joinlM (fun var => do pure s!"let {← var.fvarId!.getUserName} := {← ppExpr (← var.fvarId!.getValue?).get!}")
(fun s s' => pure (s ++ "\n" ++ s'))
let (yVals', fVals') := eqs.filterMap (fun (i,val,idset) => if idset.size = 0 then none else .some (yVals[i]!, val))
|>.unzip
pure s!"{lets}\n{← equationsToString yVals' fVals'}"

private def afterBackwardPassSystemToString
(lctx : LocalContext) (xResVars' xVars'' xVars xVals : Array Expr) := do
withLCtx lctx (← getLocalInstances) do
let lets ← (xResVars' ++ xVars'').joinlM (fun var => do pure s!"let {← var.fvarId!.getUserName} := {← ppExpr (← var.fvarId!.getValue?).get!}")
(fun s s' => pure (s ++ "\n" ++ s'))
pure s!"{lets}\n{← equationsToString xVars xVals}"


/--
Solves the system of m-equations in n-variables
```
Expand All @@ -47,6 +73,8 @@ If `n>m` then the values `xᵢ` can depend also on other `xₖ`. The set `n-m` x
-/
private partial def invertValues (xVars yVals fVals : Array Expr) : MetaM (Option (SystemInverse × Array MVarId)) := do

trace[Meta.Tactic.structuralInverse.step] "inverting system in variables {← xVars.mapM ppExpr}\n{← equationsToString yVals fVals}"

let xIdSet : FVarIdSet := .fromArray (xVars.map (fun x => x.fvarId!)) _

-- data is and array of (yId, value, set of xId aprearing in value)
Expand Down Expand Up @@ -83,6 +111,8 @@ private partial def invertValues (xVars yVals fVals : Array Expr) : MetaM (Optio
let xVar' := varArr[xVarId]!
let varArrOther := varArr.eraseIdx xVarId

trace[Meta.Tactic.structuralInverse.step] "resolving {xVar'} from {yVals[j]!} = {← withLCtx lctx instances <| ppExpr yVal}"

let xName ← xVar'.fvarId!.getUserName

-- new value of x but it can still depend on x that have not been resolved
Expand All @@ -98,8 +128,11 @@ private partial def invertValues (xVars yVals fVals : Array Expr) : MetaM (Optio
pure (f, .some goal)

if let .some goal := goal? then
trace[Meta.Tactic.structuralInverse.step] "new obligation {← withLCtx lctx instances <| ppExpr (← goal.getType)}"
goals := goals.push goal

trace[Meta.Tactic.structuralInverse.step] "resolved {← ppExpr xVar'} with {← withLCtx lctx instances <| ppExpr xVal'}"

-- xRes is a function resolving x given all unresolved xs
let xResId ← withLCtx lctx instances <| mkFreshFVarId
let xResVar := Expr.fvar xResId
Expand All @@ -109,7 +142,7 @@ private partial def invertValues (xVars yVals fVals : Array Expr) : MetaM (Optio
inferType xResVal
lctx := lctx.mkLetDecl xResId (xName.appendAfter "'") xResType xResVal

let xVal'' := mkAppN xResVar varArr[1:]
let xVal'' := mkAppN xResVar varArrOther

xVars' := xVars'.push xVar'
xVals' := xVals'.push xVal''
Expand All @@ -127,6 +160,12 @@ private partial def invertValues (xVars yVals fVals : Array Expr) : MetaM (Optio
else
(j, default, {})

let (yVals',fVals') :=
eqs.filterMap (fun (i,val,idset) => if idset.size = 0 then none else .some (yVals[i]!, val))
|>.unzip
trace[Meta.Tactic.structuralInverse.step] "system after resolving {← ppExpr xVar'}\n{← partiallyResolvedSystemToString lctx xResVars' yVals eqs}"


let mut xVars'' : Array Expr := #[]

-- backward pass
Expand All @@ -143,11 +182,12 @@ private partial def invertValues (xVars yVals fVals : Array Expr) : MetaM (Optio
let xVal'' := xVal'.replaceFVars xVars'[0:i] xVars''
xVars'' := xVars''.push xVar''

lctx := lctx.mkLetDecl xId'' (← xId.getUserName) (← xId.getType) xVal''

lctx := lctx.mkLetDecl xId'' ((← xId.getUserName).appendAfter "''") (← xId.getType) xVal''

let xVals := xVars.map (fun xVar => xVar.replaceFVars xVars' xVars'')

trace[Meta.Tactic.structuralInverse.step] "system after backward pass\n{← afterBackwardPassSystemToString lctx xResVars' xVars'' xVars xVals}"

let resolvedIdSet : FVarIdSet := .fromArray (xVars'.map (fun x => x.fvarId!)) _
let unresolvedIdSet := xIdSet.diff resolvedIdSet

Expand Down
59 changes: 59 additions & 0 deletions test/structural_inverse.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import SciLean.Tactic.StructuralInverse
import SciLean.Data.Idx

open SciLean

open Lean Meta Qq

/--
info: fun ij1 y =>
let ij0' := fun ij1 => Function.invFun (fun ij0 => ij0 + ij1.1) y;
let ij0'' := ij0' ij1;
(ij0'', ij1)
-/
#guard_msgs in
#eval show MetaM Unit from do

let f := q(fun ij : Idx 10 × Idx 5 => ij.1 + ij.2.1)
-- let f := q(fun ij : Int × Int × Int => (ij.1 + ij.2.2, ij.2.1 + ij.1 + ij.2.2))
-- let f := q(fun ij : Int × Int × Int => (ij.2.2, ij.1))
let .some (.right f', _) ← structuralInverse f
| throwError "failed to invert"
IO.println (← ppExpr f'.invFun)

/--
info: fun ij1 y =>
let ij0' := fun ij2 => Function.invFun (fun ij0 => ij0 + ij2) y.fst;
let ij2' := fun ij1 => Function.invFun (fun ij2 => ij1 + ij0' ij2 + ij2) y.snd;
let ij2'' := ij2' ij1;
let ij0'' := ij0' ij2'';
(ij0'', ij1, ij2'')
-/
#guard_msgs in
#eval show MetaM Unit from do

let f := q(fun ij : Int × Int × Int => (ij.1 + ij.2.2, ij.2.1 + ij.1 + ij.2.2))
let .some (.right f', _) ← structuralInverse f
| throwError "failed to invert"
IO.println (← ppExpr f'.invFun)

/--
info: fun ij1 y => (y.snd, ij1, y.fst)
-/
#guard_msgs in
#eval show MetaM Unit from do
let f := q(fun ij : Int × Int × Int => (ij.2.2, ij.1))
let .some (.right f', _) ← structuralInverse f
| throwError "failed to invert"
IO.println (← ppExpr f'.invFun)


/--
info: fun y => (y.snd.fst, y.snd.snd, y.fst)
-/
#guard_msgs in
#eval show MetaM Unit from do
let f := q(fun ij : Int × Int × Int => (ij.2.2, ij.1, ij.2.1))
let .some (.full f', _) ← structuralInverse f
| throwError "failed to invert"
IO.println (← ppExpr f'.invFun)

0 comments on commit ca0c3e5

Please sign in to comment.