Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow unpacking ADTs in types (dependent projections as atoms) #290

Merged
merged 14 commits into from
Dec 15, 2020

Conversation

danieldjohnson
Copy link
Collaborator

@danieldjohnson danieldjohnson commented Nov 23, 2020

This PR will fix #285 and #286, and should hopefully also make it possible to solve #258 in the future.

In the core Dex IR, all types must be reduced Atoms, which means that they should be comparable by syntactic equality and shouldn't involve doing any additional computation. This previously made it impossible to unpack ADTs in types, because unpacking ADTs required emitting an Unpack decl, and thus could not be reduced into a single atom.

This PR (based on discussion with @dougalm) replaces Unpack decls with a ProjectElt atom, which allows directly referencing a particular subcomponent of a Var. ProjectElt atoms must always be in fully-reduced form, which I partially enforce by using the definition ProjectElt (NE.NonEmpty Int) Var: this ensures that only Vars can be projected, and not other atoms (since any other atom should be immediately reduced away). The substitution rules for ProjectElt also carefully maintain this invariant by "reducing away" whenever the variable they reference is substituted for a concrete value.

Some other miscellaneous changes that were required to make this work properly:

  • The "always-reduced" invariant also requires moving the reduction machinery back into Embed.hs, because any time a projection is built, we have to ensure that it is fully reduced.
  • In the presence of dependent constructors like AsList, the type of a ProjectElt atom may also be a ProjectElt atom that references earlier components of the type. This makes the getType implementation for ProjectElt a bit complicated. (This is another reason why the always-reduced invariant is necessary even for values not used as types, since getType on those values doesn't have enough context to do reduction on its own when obtaining the type of a dependent projection.)
  • Simplification didn't used to simplify through type signatures in some cases, but doing so is necessary to ensure that ProjectElt atoms get properly reduced at simplification time.
  • I've modified the implicit-implicit-args modification to ignore identifiers used as functions, since we know that those couldn't possibly be of type Type. This matches the behavior of Idris and makes it possible to call functions inside type annotations without bizarre errors.

I've also modified the UPi constructor to take a pattern instead of a single binder, so that unpacking syntax in def definitions works as expected. For now, the concrete syntax for pi types themselves still only takes a single binder, but this should be fairly straightforward to extend in the future (perhaps at the same time as fixing #282).

Still to do:

  • Fold fst and snd into this same design, so that tuples can be unpacked in type signatures.
  • Possibly directly expose ProjectElt in the user syntax? Although maybe this isn't worth doing, since the user can always write a function to do this themselves using the existing, type-checked unpacking syntax (e.g. you can produce ProjectElt [1, 0] v by writing something like (\(Foo (Bar _ x) _). x) v.
  • Maybe change the pretty-printing so that it's easier to interpret? Perhaps instead of ProjectElt [1, 0] v it would look like %project [[_, @], _] v or something.
  • Figure out what to do with autodiff (probably straightforward but I haven't really thought about it yet)
  • Refactor some things
  • Add more tests

To try to reduce the surface for bugs, ProjectElt syntactically requires its
argument to be a Var, indexed by a nonempty list of indices. This means that
substituting into a ProjectElt requires immediately reducing it.

Note that getting the type of a ProjectElt atom is a bit subtle, because if
we are extracting a value from an existential ADT DataCon, the type of the
projected result may itelf include earlier bindings in the DataCon, which
must also be converted to ProjectElt atoms.
After this change, we should use projections instead of unpacks in most
places. There's a few remaining bugs due to differences in simplifications
between decls and atoms.
Now that types can involve projections, we need to fully simplify the type
arguments to type constructors instead of simply using substEmbedR. This fixes
one of the broken tests.
There is likely a better implementation that re-uses an existing destination
if we have the right structure. But this seems to work for now, and fixes the
broken tests.
All unpacks can be represented as let decls followed by projections.
…king.

Modify implicit args: To allow function calls inside type annotations, this
change makes it so that any lowercase name that is used in function position
of an application (e.g. `f x`) will NOT be added as an automatic implicit type
parameter. This is an improvement because such a type application would never
typecheck anyway if `f` was inferred to have type `Type`. This also matches
the behavior of Idris.

Reduce projections: We assume during typechecking that `getType` always returns
a fully reduced type. But `getType` of a projection may produce another
projection with the same root variable. Thus, whenver we create a new
projection, we have to reduce the variable. (It's not clear that this is the
best way to do this, but it seems to work for now.)

Get list extraction example working: With these two changes, it becomes possible
to construct functions that do dependent projections, for instance by converting
a `List a` into a table. The syntax is a bit unwieldy for this, but that should
be easy to fix.
UPi atoms now can take a pattern instead of a single binder. If the pattern
is more complex than a single binder, that pattern is then bound while
converting the UPi into a Pi atom. Note that the Pi representation is the
same as it was; the returned type must still be reducible to an atom for
the conversion to succeed. However, with the new projection atoms, unpacking
of ADTs will still be reducible.

The parser implementation for "def"-style functions has been modified to allow
using patterns, which means it is now possible to bring values from an ADT
into scope in the type for a "def"-style function. For now, the parser does not
support patterns in explicit pi type expressions. This should be fairly
straightforward to add but might require some care regarding ambiguity of the
grammar (see google-research#282).
@google-cla google-cla bot added the cla: yes label Nov 23, 2020
Support for autodiff is partial, because we currently don't seem to have a way
to unpack a reference to a record or an ADT (we do have `FstRef` and `SndRef`
for pairs, though). However, this is enough to get the tests to pass.
@danieldjohnson
Copy link
Collaborator Author

I have a preliminary version of autodiff working now, I think. At least, it's enough to get our tests to pass.

As far as my limited understanding goes, autodiff works by maintaining a Ref object for every variable, and then writing into those Ref objects when computing cotangents in the backwards pass. I don't think we currently have any support for unpacking references to records or ADTs, so currently it also doesn't work to do autodiff through unpacks of records or ADTs (since the way to compute gradients for a projection is to project the reference and then update the gradients there). Looks like maybe this used to work for records because of how isUnpack is used, but I don't think that solution works anymore, because linAtomRef needs to actually return a Ref when given a ProjectElt atom.

Does this approach seem reasonable? My guess is we will want to be able to index into references to ADTs and records at some point, so once that happens we should be able to fix linAtomRef to support those as well.

(I briefly tried a different version of autodiff here that constructed a cotangent value using zeroAt, updated one of the elements with the input cotangent, then used that mostly-zeros object as the output cotangent. That worked for some of the tests but didn't work for the linAtomRef problem, which in particular seems to occur when indexing a table inside a pair. I guess another option would be to make a local reference to that mostly-zeros object, if we wanted to keep support for gradients-with-respect-to-records in the short term? But it seems like this is a hacky solution anyway, we probably don't want to materialize zeros all the time.)

Refactors autodiff to not use `isUnpack` since `Unpack` decls are no longer
part of the core IR. Adds tests for more complicated uses of dependent
projections and deeply nestesd projections.

Also fixes a bug where (,) was not respecting precedence correctly, by using
`mayPair`/`mayNotPair` for patterns as well as expressions.
Since %projectElt can show up in the type of ordinary expressions, the prefix
of % should make it clear that this isn't a user-defined ADT but instead an
internal type.
@danieldjohnson danieldjohnson marked this pull request as ready for review December 2, 2020 00:58
@danieldjohnson
Copy link
Collaborator Author

This should be ready to review now!

Copy link
Collaborator

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me just being by saying that this is extremely cool!

In any case, I've left a bunch of comments and I'd like to iterate on some of them before we merge this. For example, I'm quite confused about the AD changes and they are left untested.

Apart from that, I'm very much conflicted about the use of NonEmpty in ProjectElt instead of parametrizing each one of those as an Int and using chains of those when necessary. Here are a few arguments for why that might be better:

  • In most places in the code, you explicitly handle recursion over the list, which would happen automatically if our ADT was nested instead.
  • It makes our representation of Atoms less normal. We now have to be careful to never produce ProjectElt [0] (ProjectElt [1] something)) because it would be unequal to ProjectElt [0, 1] something (or [1, 0], I'm not sure; this ordering issue might be another argument to favor the recursive definition).

examples/adt-tests.dx Outdated Show resolved Hide resolved
examples/adt-tests.dx Show resolved Hide resolved
examples/adt-tests.dx Show resolved Hide resolved
src/lib/Syntax.hs Show resolved Hide resolved
src/lib/Syntax.hs Show resolved Hide resolved
src/lib/Embed.hs Outdated
unless (null decls) $ throw CompilerErr $ "Unexpected decls: " ++ pprint decls
return piTy
let block = wrapDecls decls ans
case reduceBlock scope block of
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering what kind of decls can we expect here? Applications? Anything else?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now at least, I think it should just be applications and bindings of names to ProjectElts, e.g.

  n = ProjectElt [0] somelist
  Fin n

src/lib/Embed.hs Outdated Show resolved Hide resolved
examples/adt-tests.dx Show resolved Hide resolved
src/lib/Autodiff.hs Outdated Show resolved Hide resolved
src/lib/Autodiff.hs Show resolved Hide resolved
prelude.dx Outdated Show resolved Hide resolved
In the IR, projections are represented with simple integers. But since
projections can show up in user expressions, we rewrite them during
printing to instead be of the form `(\pat. elt) x` where `pat` is a
pattern that does the unpacking.

Also changes the way patterns are pretty-printed to remove some
redundant visual noise.
The functions now live in Type.hs and have a `typeReduce` prefix
instead of just being called `reduceAtom`/`reduceExpr` etc.
Copy link
Collaborator

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks amazing. Thanks!

@apaszke apaszke merged commit 2d9d987 into google-research:dev Dec 15, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants