Represent transpose of Delta as rewrite rules #100

Open
Mikolaj opened this issue Apr 15, 2023 · 3 comments
@Mikolaj
Owner

Mikolaj commented Apr 15, 2023

Here's Tom's idea from #95 (comment):

the transposition can't be expressed as rewriting rules, because it's stateful.

It's only stateful because it's Cayley-transformed. The eval function:

eval :: R -> Delta -> DeltaMap -> DeltaMap

is really eval :: R -> Delta -> Endo DeltaMap: its codomain is Cayley-transformed (think DList difference lists) in order to make things more efficient. (addDelta becomes a bit more expensive this way, because it has to update a value in the map instead of being able to create a singleton map, but addition of maps is much cheaper because surely (.) is more efficient than Map.union.)
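
A minimal sketch of the Cayley transform described above, on a scalar-only toy Delta. The types R and DeltaMap and the four constructors are simplified stand-ins (assumptions), not the actual horde-ad definitions, and evalDirect, evalEndo and runEval are hypothetical names.

```haskell
import qualified Data.Map.Strict as Map
import Data.Monoid (Endo (..))

-- Hypothetical, scalar-only stand-ins for the real horde-ad types.
type R = Double
type DeltaMap = Map.Map Int R

data Delta
  = Zero
  | Scale R Delta
  | Add Delta Delta
  | Var Int

-- Direct style: each Var builds a singleton map, each Add pays for a union.
evalDirect :: R -> Delta -> DeltaMap
evalDirect _ Zero = Map.empty
evalDirect c (Scale k d) = evalDirect (k * c) d
evalDirect c (Add d e) = Map.unionWith (+) (evalDirect c d) (evalDirect c e)
evalDirect c (Var n) = Map.singleton n c

-- Cayley-transformed style: the result is a map updater and Add is
-- function composition; Var has to update the map it is given.
evalEndo :: R -> Delta -> Endo DeltaMap
evalEndo _ Zero = mempty
evalEndo c (Scale k d) = evalEndo (k * c) d
evalEndo c (Add d e) = evalEndo c d <> evalEndo c e
evalEndo c (Var n) = Endo (Map.insertWith (+) n c)

-- Recover a plain map by applying the accumulated updater to the empty map.
runEval :: R -> Delta -> DeltaMap
runEval c d = appEndo (evalEndo c d) Map.empty
```

Both versions should agree on every Delta; the difference is only where the cost lands (union per Add versus one map update per Var).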

So maybe it can be expressed using rewrite rules, but just into the language of Endo DeltaMap, not of DeltaMap?

As far as I understand so far, we'd need more Delta constructors and we'd transpose by rewriting the Delta expressions. Afterwards, to get rid of the Delta constructors, we'd interpret them straightforwardly as linear transformation syntax. That interpretation is implemented (and probably bit-rotted) in buildDerivative, which computes forward derivatives from the Delta collected in the forward pass. In the pipeline that produces gradient Ast, we'd run buildDerivative with Ast codomain.

I wonder how much the extended Delta would start resembling Ast. If we switched from Delta to Ast, we'd end up with Ast nested inside Ast, which would be fine in this case. But perhaps Delta has some advantage over Ast? I guess it's linear, just as the linear sublanguage in the YOLO paper. I wonder if it's easy to extend Delta to be enough for this term rewriting, but still syntactically guaranteeing linearity (assuming it does ATM, after all the tensor extensions).

A modest version of this rewrite is to evaluate Delta to combinators of type DeltaMap -> DeltaMap, but not reify the combinators as constructors of Delta. Then, this can be presented as a rewrite, but implemented with immediate evaluation of the rewritten terms. Relevant parts of buildDerivative are then manually inlined.

Mikolaj added the help wanted label Apr 15, 2023
@tomsmeding
Collaborator

It's a bit stranger than I thought, but this is the best I've come up with so far.

Source language: our Delta expressions.
Target language: terms in the target language described below, with a distinguished variable c in the environment. Target language terms t are typed with a judgement Γ | c : σ ||- t : τ. (Using ||- just to visually distinguish from the standard non-linear judgement |-, which is also used in some of the typing rules below.) These terms define an algebraically linear function σ -> τ with a non-linear environment Γ.

Given a source term Γ |- t : τ, transposition produces in the target language: Γ | c : τ ||- transp[t] : DMap. The DMap, Variable, Array and Shape types don't have further type indices in the typing rules below because I feared madness otherwise. There's probably a way to include those and make this type-safe.

# LINEARITY

------------------- (zero)
Γ | c : τ ||- 0 : σ

Γ | c : τ ||- s : σ      Γ | c : τ ||- t : σ
-------------------------------------------- (plus)
         Γ | c : τ ||- s + t : σ

Γ |- r : Real     Γ | c : τ ||- t : σ
------------------------------------- (scale)
       Γ | c : τ ||- r * t : σ

# LINEAR BINDING

Γ | c : τ ||- s : σ₁      Γ | c : σ₁ ||- t : σ₂
----------------------------------------------- (let)
      Γ | c : τ ||- let c = s in t : σ₂

------------------- (var)
Γ | c : τ ||- c : τ

# COTANGENT MAPS

        n Variable
------------------------------- (onehot map)
Γ | c : τ ||- Onehot n c : DMap

Γ | c : τ ||- t : DMap      n Variable
-------------------------------------- (map delete)
   Γ | c : τ ||- Delete n t : DMap

Γ | c : τ ||- t : DMap      n Variable
-------------------------------------- (map lookup)
   Γ | c : τ ||- Lookup n t : DMap

# ARRAY OPERATIONS

Γ |- sh : Shape       Γ |- f : Int -> Int
        Γ | c : τ ||- t : Array
----------------------------------------- (scatter)
   Γ | c : τ ||- Scatter sh f t : Array

Example transposition rules (rewritten from eval in the original POPL paper, and gather from simplified/HordeAd/Core/Delta.hs:buildFinMaps:evalR):

T[Zero] = 0
T[Scale y u] = let c = y * c in T[u]
T[Add u₁ u₂] = T[u₁] + T[u₂]
T[Var n] = Onehot n c
T[Let n u₁ u₂] = let c = T[u₂] in Delete n c + (let c = Lookup n c in T[u₁])
  -- Note that the first c binding above is of type DMap.
T[GatherZ sh u f sha] = let c = Scatter sha c f in T[u]
  -- I would rather write 'GatherZ sh f u sha' (or even 'GatherZ sh f u') and
  -- 'Scatter sha f c', but this is horde-ad argument order. :D
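
The scalar part of the rules above can be sketched as an ordinary Haskell function from a source AST to a target AST, plus a small interpreter for the target language. Everything here is an assumption made for illustration: the constructor names (TZero, TPlus, ...), the untyped Value domain, and the omission of arrays and GatherZ/Scatter. The interpreter also only handles 0 at type DMap, which is the only way transp emits it.

```haskell
import qualified Data.Map.Strict as Map

type Name = Int
type DMap = Map.Map Name Double

-- Source language: a scalar-only fragment of Delta (hypothetical simplification).
data Delta
  = Zero
  | Scale Double Delta
  | Add Delta Delta
  | Var Name
  | Let Name Delta Delta

-- Target language: linear terms over the distinguished variable c.
data Term
  = TZero              -- 0
  | TPlus Term Term    -- s + t
  | TScale Double Term -- r * t
  | TLet Term Term     -- let c = s in t
  | TC                 -- c
  | TOnehot Name       -- Onehot n c
  | TDelete Name Term  -- Delete n t
  | TLookup Name Term  -- Lookup n t
  deriving (Eq, Show)

-- The rewrite T[.], one equation per rule.
transp :: Delta -> Term
transp Zero = TZero
transp (Scale y u) = TLet (TScale y TC) (transp u)
transp (Add u1 u2) = TPlus (transp u1) (transp u2)
transp (Var n) = TOnehot n
transp (Let n u1 u2) =
  TLet (transp u2) (TPlus (TDelete n TC) (TLet (TLookup n TC) (transp u1)))

-- Untyped values: c is either a scalar cotangent or a cotangent map.
data Value = VReal Double | VMap DMap
  deriving (Eq, Show)

-- Interpreter; the first argument is the current value of c.
eval :: Value -> Term -> Value
eval _ TZero = VMap Map.empty
eval c (TPlus s t) = case (eval c s, eval c t) of
  (VMap a, VMap b) -> VMap (Map.unionWith (+) a b)
  (VReal a, VReal b) -> VReal (a + b)
  _ -> error "plus: mismatched types"
eval c (TScale r t) = case eval c t of
  VReal x -> VReal (r * x)
  VMap m -> VMap (Map.map (r *) m)
eval c (TLet s t) = eval (eval c s) t
eval c TC = c
eval c (TOnehot n) = case c of
  VReal x -> VMap (Map.singleton n x)
  VMap _ -> error "onehot: expected a scalar"
eval c (TDelete n t) = case eval c t of
  VMap m -> VMap (Map.delete n m)
  VReal _ -> error "delete: expected a map"
eval c (TLookup n t) = case eval c t of
  VMap m -> VReal (Map.findWithDefault 0 n m)
  VReal _ -> error "lookup: expected a map"
```

For example, `let x0 = v1 + 3*v2 in x0 + 2*x0` is `Let 0 (Add (Var 1) (Scale 3 (Var 2))) (Add (Var 0) (Scale 2 (Var 0)))`; evaluating its transp with c = 1 gives the map with 3 at variable 1 and 9 at variable 2.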

@tomsmeding
Collaborator

While the above is a fine linear language, I think, the point of this was to Cayley-transform DMap, so we'd need to specialise 0 and + to only work on DMap values, i.e.

---------------------- (zero)
Γ | c : τ ||- 0 : DMap

Γ | c : τ ||- s : DMap      Γ | c : τ ||- t : DMap
-------------------------------------------------- (plus)
           Γ | c : τ ||- s + t : DMap

after which I think we can implement DMap as Endo DMap ~= DMap -> DMap without issues, using the standard monoid operations id and (.). We'd have to double-check that Delete and Lookup continue to be implementable in the presence of let after DMap is Cayley-transformed. I think that works out, but I'm not 100% sure.
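
One possible reading of this, as a hedged sketch rather than a worked-out answer: represent a cotangent map m by the function that adds m to an accumulator. The names DMapE, zeroE, plusE, onehotE, lookupE, deleteE and fromE are hypothetical, the maps are scalar-only, and deleteE relies on the assumption that let-bound variable names are fresh.

```haskell
import qualified Data.Map.Strict as Map
import Data.Monoid (Endo (..))

type Name = Int
type DMap = Map.Map Name Double

-- A cotangent map m is represented by "add m to the accumulator".
type DMapE = Endo DMap

zeroE :: DMapE
zeroE = mempty -- id

plusE :: DMapE -> DMapE -> DMapE
plusE = (<>) -- (.)

onehotE :: Name -> Double -> DMapE
onehotE n x = Endo (Map.insertWith (+) n x)

-- Lookup has to observe the represented map, so it runs the updater
-- on the empty map first.
lookupE :: Name -> DMapE -> Double
lookupE n f = Map.findWithDefault 0 n (appEndo f Map.empty)

-- Delete removes n after the updater has run. This agrees with deleting
-- from the represented map only if n does not occur in the accumulator
-- (assumption: let-bound names are fresh).
deleteE :: Name -> DMapE -> DMapE
deleteE n f = Endo (Map.delete n . appEndo f)

-- Leave the Cayley representation.
fromE :: DMapE -> DMap
fromE f = appEndo f Map.empty
```

Note that in the Let rule both lookupE and deleteE would run the same updater, so a naive implementation evaluates T[u₂] twice. This sketch does not settle whether that can be avoided while keeping the rewrite-rule presentation.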

@Mikolaj
Owner Author

Mikolaj commented Apr 25, 2023

Hah, yes, this is unexpected.

Re 'GatherZ sh f u', IIRC the extra shape is needed to compute the forward derivative from a delta expression without traversing the delta expression to reconstruct the shape. The u in front of f may be related to simplification, vectorization, etc., recursing over u rather than over f, or to indexing taking the term first and only then the index. These are weak reasons.
