Skip to content

Commit

Permalink
fix indentation
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Sep 19, 2023
1 parent 2fcc62c commit bcf1656
Showing 1 changed file with 125 additions and 133 deletions.
258 changes: 125 additions & 133 deletions SciLean/Core/Meta/GenerateRevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -123,146 +123,141 @@ def generateRevCDeriv (constName : Name) (argIds : ArraySet Nat) : MetaM Unit :=
withLocalDeclQ `instW .instImplicit q(SemiInnerProductSpace $K $W) fun instW => do
withLocalDeclQ (u:=levelOne) `w .default W fun w => do

-- argFuns are selected arguments parametrized by `W`
let argFunDecls ←
args.mapM (fun arg => do
let name ← arg.fvarId!.getUserName
-- argFuns are selected arguments parametrized by `W`
let argFunDecls ←
args.mapM (fun arg => do
let name ← arg.fvarId!.getUserName
let bi : BinderInfo := .default
let type ← mkArrow W (← inferType arg)
pure (name, bi, fun _ : Array Expr => pure (f:=MetaM) type))

withLocalDecls argFunDecls fun argFuns => do

let argFunApps := argFuns.map (fun argFun => argFun.app w)

let xs' := Array.mergeSplit splitIds argFunApps otherArgs

let f ← mkLambdaFVars #[w] (mkAppN (← mkConst' constName) xs')
let lhs ← mkAppM ``revCDeriv #[K,f]

let argFunPropDecls ←
argFuns.mapM (fun argFun => do
let name := (← argFun.fvarId!.getUserName).appendBefore "h"
let bi : BinderInfo := .default
let type ← mkArrow W (← inferType arg)
let type ← mkAppM ``HasAdjDiff #[K,argFun]
pure (name, bi, fun _ : Array Expr => pure (f:=MetaM) type))

withLocalDecls argFunDecls fun argFuns => do

let argFunApps := argFuns.map (fun argFun => argFun.app w)

let xs' := Array.mergeSplit splitIds argFunApps otherArgs

let f ← mkLambdaFVars #[w] (mkAppN (← mkConst' constName) xs')
let lhs ← mkAppM ``revCDeriv #[K,f]

let argFunPropDecls ←
argFuns.mapM (fun argFun => do
let name := (← argFun.fvarId!.getUserName).appendBefore "h"
let bi : BinderInfo := .default
let type ← mkAppM ``HasAdjDiff #[K,argFun]
pure (name, bi, fun _ : Array Expr => pure (f:=MetaM) type))

withLocalDecls argFunPropDecls fun argFunProps => do

let constId := mkIdent constName
let (rhs, proof) ← rewriteByConv lhs (← `(conv| (unfold $constId; autodiff)))

IO.println s!"lhs: {← ppExpr lhs}"
IO.println s!"rhs: {← ppExpr rhs}"

if ¬(← isDefEq (← mkEq lhs rhs) (← inferType proof)) then
throwError "generated proof is not type correct, expected proof of\n{← ppExpr (← mkEq lhs rhs)}\nbut got proof of\n{← ppExpr (← inferType proof)}"

if let .lam _ _ b _ := rhs then
let b := b.instantiate1 w

let transArgFuns ← argFuns.mapM (fun argFun => mkAppM ``revCDeriv #[K, argFun, w])

let transArgFunDecls ←
argFuns.mapIdxM (fun i argFun => do
let name := (← argFun.fvarId!.getUserName)
let bi : BinderInfo := .default
let type ← inferType transArgFuns[i]!
pure (name, bi, fun _ : Array Expr => pure (f:=MetaM) type))

withLocalDecls transArgFunDecls fun transArgFunVars => do

-- find all occurances of `<∂ (w':=w), argFunᵢ w` and replace it with recently introduced fvar
let b' ← eliminateTransArgFun b argFuns transArgFuns transArgFunVars
if b'.containsFVar w.fvarId! then
throwError s!"transformed function {← ppExpr b'} still contains {← ppExpr w}"

let idx ← firstExplicitNonTypeIdx xs

let xs' := Array.mergeSplit splitIds transArgFunVars otherArgs
let fvars := xs'[0:idx] ++ (#[W,instW] : Array Expr) ++ xs'[idx:]
let transFun ← instantiateMVars (← mkLambdaFVars fvars b')
let transFunName := constName.append "arg_" |>.append "revCDeriv"
IO.println s!"revCDeriv def fun\n{← ppExpr transFun}"

let transFunInfo : DefinitionVal :=
{
name := transFunName
type := (← inferType transFun)
value := transFun
hints := .regular 0
safety := .safe
levelParams := info.levelParams
}

addAndCompile (.defnDecl transFunInfo)


let xs' := Array.mergeSplit splitIds argFuns otherArgs
let fvars := xs'[0:idx] ++ (#[W, instW] : Array Expr) ++ xs'[idx:] ++ argFunProps
let ruleProof ← instantiateMVars (← mkLambdaFVars fvars proof)
let ruleName := constName.append "arg_" |>.append "revCDeriv_rule"
IO.println s!"revCDeriv rule\n{← ppExpr (← inferType ruleProof)}"

let ruleInfo : TheoremVal :=
{
name := ruleName
type := (← inferType ruleProof)
value := ruleProof
levelParams := info.levelParams
}

addDecl (.thmDecl ruleInfo)

-- turn ftransArgFunVars to let bindings
let mut lctx ← getLCtx
for transArgFunVar in transArgFunVars, transArgFun in transArgFuns do
lctx := lctx.modifyLocalDecl transArgFunVar.fvarId!
fun decl =>
match decl with
| .cdecl index fvarId userName type _ kind =>
.ldecl index fvarId userName type transArgFun false kind
| _ => unreachable!

withLCtx lctx (← getLocalInstances) do

let xs' := Array.mergeSplit splitIds transArgFunVars otherArgs
let fvars := xs'[0:idx] ++ (#[W,instW] : Array Expr) ++ xs'[idx:]
-- TODO: !!!replace transformedFun with new declaration!!!
let rhs ← mkLambdaFVars ((#[w] : Array Expr) ++ transArgFunVars) (← mkAppOptM transFunName (fvars.map .some))

let xs' := Array.mergeSplit splitIds argFuns otherArgs
let fvars := xs'[0:idx] ++ (#[W, instW] : Array Expr) ++ xs'[idx:] ++ argFunProps
let ruleDef ← instantiateMVars (← mkForallFVars fvars (← mkEq lhs rhs))
let ruleDefName := constName.append "arg_" |>.append "revCDeriv_rule_def"
IO.println s!"revCDeriv rule def\n{← ppExpr ruleDef}"

let ruleDefInfo : TheoremVal :=
{
name := ruleDefName
type := ruleDef
value := ruleProof
levelParams := info.levelParams
}

addDecl (.thmDecl ruleDefInfo)

pure ()

pure ()
else
throwError "transformed function should be in the form `fun w => ...` but got\n{← ppExpr rhs}"
pure ()
withLocalDecls argFunPropDecls fun argFunProps => do

let constId := mkIdent constName
let (rhs, proof) ← rewriteByConv lhs (← `(conv| (unfold $constId; autodiff)))

def mymul {K : Type} [IsROrC K] (x y : K) := x * y
IO.println s!"lhs: {← ppExpr lhs}"
IO.println s!"rhs: {← ppExpr rhs}"

if ¬(← isDefEq (← mkEq lhs rhs) (← inferType proof)) then
throwError "generated proof is not type correct, expected proof of\n{← ppExpr (← mkEq lhs rhs)}\nbut got proof of\n{← ppExpr (← inferType proof)}"

let .lam _ _ b _ := rhs
| throwError "transformed function should be in the form `fun w => ...` but got\n{← ppExpr rhs}"

let b := b.instantiate1 w

let transArgFuns ← argFuns.mapM (fun argFun => mkAppM ``revCDeriv #[K, argFun, w])

let transArgFunDecls ←
argFuns.mapIdxM (fun i argFun => do
let name := (← argFun.fvarId!.getUserName)
let bi : BinderInfo := .default
let type ← inferType transArgFuns[i]!
pure (name, bi, fun _ : Array Expr => pure (f:=MetaM) type))

withLocalDecls transArgFunDecls fun transArgFunVars => do

-- find all occurances of `<∂ (w':=w), argFunᵢ w` and replace it with recently introduced fvar
let b' ← eliminateTransArgFun b argFuns transArgFuns transArgFunVars
if b'.containsFVar w.fvarId! then
throwError s!"transformed function {← ppExpr b'} still contains {← ppExpr w}"

let idx ← firstExplicitNonTypeIdx xs

let xs' := Array.mergeSplit splitIds transArgFunVars otherArgs
let fvars := xs'[0:idx] ++ (#[W,instW] : Array Expr) ++ xs'[idx:]
let transFun ← instantiateMVars (← mkLambdaFVars fvars b')
let transFunName := constName.append "arg_" |>.append "revCDeriv"
IO.println s!"revCDeriv def fun\n{← ppExpr transFun}"

let transFunInfo : DefinitionVal :=
{
name := transFunName
type := (← inferType transFun)
value := transFun
hints := .regular 0
safety := .safe
levelParams := info.levelParams
}

addAndCompile (.defnDecl transFunInfo)

let xs' := Array.mergeSplit splitIds argFuns otherArgs
let fvars := xs'[0:idx] ++ (#[W, instW] : Array Expr) ++ xs'[idx:] ++ argFunProps
let ruleProof ← instantiateMVars (← mkLambdaFVars fvars proof)
let ruleName := constName.append "arg_" |>.append "revCDeriv_rule"
IO.println s!"revCDeriv rule\n{← ppExpr (← inferType ruleProof)}"

let ruleInfo : TheoremVal :=
{
name := ruleName
type := (← inferType ruleProof)
value := ruleProof
levelParams := info.levelParams
}

addDecl (.thmDecl ruleInfo)

-- turn ftransArgFunVars to let bindings
let mut lctx ← getLCtx
for transArgFunVar in transArgFunVars, transArgFun in transArgFuns do
lctx := lctx.modifyLocalDecl transArgFunVar.fvarId!
fun decl =>
match decl with
| .cdecl index fvarId userName type _ kind =>
.ldecl index fvarId userName type transArgFun false kind
| _ => unreachable!

withLCtx lctx (← getLocalInstances) do

let xs' := Array.mergeSplit splitIds transArgFunVars otherArgs
let fvars := xs'[0:idx] ++ (#[W,instW] : Array Expr) ++ xs'[idx:]
-- TODO: !!!replace transformedFun with new declaration!!!
let rhs ← mkLambdaFVars ((#[w] : Array Expr) ++ transArgFunVars) (← mkAppOptM transFunName (fvars.map .some))

let xs' := Array.mergeSplit splitIds argFuns otherArgs
let fvars := xs'[0:idx] ++ (#[W, instW] : Array Expr) ++ xs'[idx:] ++ argFunProps
let ruleDef ← instantiateMVars (← mkForallFVars fvars (← mkEq lhs rhs))
let ruleDefName := constName.append "arg_" |>.append "revCDeriv_rule_def"
IO.println s!"revCDeriv rule def\n{← ppExpr ruleDef}"

let ruleDefInfo : TheoremVal :=
{
name := ruleDefName
type := ruleDef
value := ruleProof
levelParams := info.levelParams
}

addDecl (.thmDecl ruleDefInfo)

pure ()

pure ()

def mymul {K : Type} [IsROrC K] (x y : K) := x * y

variable {K : Type} [IsROrC K] {W : Type v} [SemiInnerProductSpace K W]

set_default_scalar K


set_option trace.Meta.Tactic.fprop.discharge true in
set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.simp.unify true in
Expand All @@ -273,7 +268,7 @@ example (x y : W → K) (hx : HasAdjDiff K x) (hy : HasAdjDiff K y)
by
unfold mymul
ftrans only
#exit


set_option pp.funBinderTypes true in
-- set_option trace.Meta.Tactic.ftrans.step true in
Expand All @@ -283,12 +278,9 @@ set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.simp.unify true in
set_option trace.Meta.Tactic.ftrans.step true in
#eval show MetaM Unit from do

-- generateRevCDeriv ``norm2 #[4].toArraySet
generateRevCDeriv ``mymul #[2,3].toArraySet



#print mymul.arg_.revCDeriv
#check mymul.arg_.revCDeriv_rule
#check mymul.arg_.revCDeriv_rule_def

0 comments on commit bcf1656

Please sign in to comment.