diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d75782fda..b513e61f0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -51,8 +51,12 @@ jobs: run: rm -rf ~/.stack/setup-exe-cache if: runner.os == 'macOS' + - name: Build, treating warnings as errors + run: make build-ci + if: runner.os == 'Linux' + - name: Build - run: make + run: make build - name: Run tests run: make tests diff --git a/.hlint.yaml b/.hlint.yaml new file mode 100644 index 000000000..44155cbac --- /dev/null +++ b/.hlint.yaml @@ -0,0 +1,2 @@ +- arguments: [--color] +- ignore: {name: "Use fmap"} diff --git a/README.md b/README.md index 85d4c8a12..47e3e9f12 100644 --- a/README.md +++ b/README.md @@ -12,25 +12,33 @@ To learn more, check out our or these example programs: * [Dex prelude](https://google-research.github.io/dex-lang/prelude.html) - * [Mandelbrot set](https://google-research.github.io/dex-lang/mandelbrot.html) - * [Ray tracer](https://google-research.github.io/dex-lang/raytrace.html) - * [Estimating pi](https://google-research.github.io/dex-lang/pi.html) - * [Hamiltonian Monte Carlo](https://google-research.github.io/dex-lang/mcmc.html) - * [ODE integrator](https://google-research.github.io/dex-lang/ode-integrator.html) - * [Sierpinski triangle](https://google-research.github.io/dex-lang/sierpinski.html) - * [Basis function regression](https://google-research.github.io/dex-lang/regression.html) - * [Brownian bridge](https://google-research.github.io/dex-lang/brownian_motion.html) - -Please note that Dex is an experimental research project at an early stage of -development. Contributions welcome! + * [Mandelbrot set](https://google-research.github.io/dex-lang/examples/mandelbrot.html) + * [Ray tracer](https://google-research.github.io/dex-lang/examples/raytrace.html) + * [Estimating pi](https://google-research.github.io/dex-lang/examples/pi.html) + * [Hamiltonian Monte Carlo](https://google-research.github.io/dex-lang/examples/mcmc.html) + * [ODE integrator](https://google-research.github.io/dex-lang/oexamples/de-integrator.html) + * [Sierpinski triangle](https://google-research.github.io/dex-lang/examples/sierpinski.html) + * [Basis function regression](https://google-research.github.io/dex-lang/examples/regression.html) + * [Brownian bridge](https://google-research.github.io/dex-lang/examples/brownian_motion.html) + +🚨 **Dex is an experimental research project at an early stage of +development. Expect monstrous bugs and razor-sharp edges!** + +🤝 **Contributions welcome!** See our issue tracker for [good first issues](https://github.com/google-research/dex-lang/labels/good%20first%20issue), or browse by [thematic labels](https://github.com/google-research/dex-lang/labels). ## Dependencies * Install [stack](https://www.haskellstack.org) * Install LLVM 9 - * `apt-get install llvm-9-dev` on Ubuntu/Debian, - * `brew install llvm@9` on macOS, and ensure it is on your `PATH` e.g. via `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` before building. - * Install libpng (often included by default in *nix) + * Ubuntu/Debian: `apt-get install llvm-9-dev` + * macOS: `brew install llvm@9` + * Make sure `llvm@9` is on your `PATH` before building. Example: `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` + * Install clang (may be installed together with llvm) + * Ubuntu/Debian: `apt-get install clang` + * macOS: installs with llvm + * Install libpng (often included by default in *nix platforms) + * Ubuntu/Debian: `apt-get install libpng-dev` + * macOS: `brew install libpng` ## Building diff --git a/benchmarks/dexbench.py b/benchmarks/dexbench.py index b66634c7e..6dbf7391d 100644 --- a/benchmarks/dexbench.py +++ b/benchmarks/dexbench.py @@ -55,15 +55,15 @@ def restore_machine(): def run_benches(lang, backend): if lang == "dex": if backend == "CPU": - backend_args = ["--backend", "LLVM-MC"] + backend_args = ["--backend", "llvm-mc"] env = {} elif backend == "GPU": - backend_args = ["--backend", "LLVM-CUDA"] + backend_args = ["--backend", "llvm-cuda"] env = {"CUDA_LAUNCH_BLOCKING":"1"} else: raise Exception command = (["stack", "exec", "dex", "--"] + backend_args + - ["script", "--outfmt", "JSON", dex_microbench_file]) + ["script", "--outfmt", "json", dex_microbench_file]) elif lang == "jax": if backend == "CPU": env = {"CUDA_VISIBLE_DEVICES":""} diff --git a/cabal.project b/cabal.project new file mode 100644 index 000000000..33ff576b5 --- /dev/null +++ b/cabal.project @@ -0,0 +1,13 @@ +packages: dex.cabal + +source-repository-package + type: git + location: https://github.com/llvm-hs/llvm-hs + tag: llvm-9 + subdir: llvm-hs + +source-repository-package + type: git + location: https://github.com/llvm-hs/llvm-hs + tag: llvm-9 + subdir: llvm-hs-pure diff --git a/dex.cabal b/dex.cabal index 8fe9d6835..8c4f5acd7 100644 --- a/dex.cabal +++ b/dex.cabal @@ -10,6 +10,7 @@ name: dex version: 0.1.0.0 author: Dougal Maclaurin maintainer: dougalm@google.com +license-file: LICENSE build-type: Simple flag cuda @@ -32,17 +33,22 @@ library exposed-modules: Env, Syntax, Type, Inference, JIT, LLVMExec, Parser, Util, Imp, Imp.Embed, Imp.Optimize, PPrint, Algebra, Parallelize, Optimize, Serialize - Actor, Cat, Flops, Embed, + Actor, Cat, Embed, Export, RenderHtml, LiveOutput, Simplify, TopLevel, - Autodiff, Interpreter, Logging, PipeRPC, CUDA, + Autodiff, Interpreter, Logging, CUDA, LLVM.JIT, LLVM.Shims - build-depends: base, containers, mtl, binary, bytestring, - time, tf-random, llvm-hs-pure ==9.*, llvm-hs ==9.*, - aeson, megaparsec >=8.0, warp, wai, filepath, - parser-combinators, http-types, prettyprinter, text, - blaze-html, cmark, diagrams-lib, ansi-terminal, - transformers, directory, mmap, unix, - process, primitive, store, dex-resources, temporary, + build-depends: base, containers, mtl, bytestring, + llvm-hs-pure, llvm-hs, + -- Parsing + megaparsec, parser-combinators, + -- Text output + prettyprinter, text, + -- Portable system utilities + filepath, directory, ansi-terminal, process, temporary, + -- Serialization + store, + -- Notebook support + warp, wai, blaze-html, aeson, http-types, cmark, binary if !os(darwin) exposed-modules: Resources hs-source-dirs: src/resources @@ -51,11 +57,12 @@ library build-depends: dex-resources default-language: Haskell2010 hs-source-dirs: src/lib - ghc-options: -Wall -fPIC + ghc-options: -Wall -fPIC -threaded -optP-Wno-nonportable-include-path cxx-sources: src/lib/dexrt.cpp cxx-options: -std=c++11 -fPIC default-extensions: CPP, DeriveTraversable, TypeApplications, OverloadedStrings, - TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms + TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms, + BlockArguments pkgconfig-depends: libpng if flag(cuda) include-dirs: /usr/local/cuda/include @@ -71,12 +78,14 @@ executable dex main-is: dex.hs other-extensions: OverloadedStrings build-depends: dex, base, haskeline, prettyprinter, mtl, - optparse-applicative, unix, store, bytestring, directory + optparse-applicative, ansi-wl-pprint, + unix, store, bytestring, directory if os(darwin) build-depends: dex-resources default-language: Haskell2010 hs-source-dirs: src - default-extensions: CPP, LambdaCase + default-extensions: CPP, LambdaCase, BlockArguments + ghc-options: -optP-Wno-nonportable-include-path if flag(optimized) ghc-options: -O3 else @@ -84,29 +93,18 @@ executable dex foreign-library Dex type: native-shared - other-modules: API - build-depends: base, dex, dex-resources, mtl - hs-source-dirs: src/foreign - c-sources: src/foreign/rts.c + other-modules: Dex.Foreign.API, Dex.Foreign.Util, Dex.Foreign.JIT + , Dex.Foreign.Context, Dex.Foreign.Serialize + build-depends: base, mtl, containers, llvm-hs, dex, dex-resources + if os(darwin) + build-depends: dex-resources + hs-source-dirs: src/ + c-sources: src/Dex/Foreign/rts.c cc-options: -std=c11 -fPIC - ghc-options: -Wall -fPIC - default-language: Haskell2010 - default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase - if flag(optimized) - ghc-options: -O3 - else - ghc-options: -O0 - -Test-Suite test-dex - type: exitcode-stdio-1.0 - main-is: PropTests.hs - build-depends: dex, base, prettyprinter, containers, - hedgehog, microlens-platform, mtl - other-modules: GenExpr, TestPass + ghc-options: -Wall -fPIC -optP-Wno-nonportable-include-path default-language: Haskell2010 - hs-source-dirs: tests - ghc-options: cbits/libdex.so - -Wall + default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase, + BlockArguments if flag(optimized) ghc-options: -O3 else diff --git a/examples/ad-tests-interp.dx b/examples/ad-tests-interp.dx deleted file mode 100644 index ce57d9854..000000000 --- a/examples/ad-tests-interp.dx +++ /dev/null @@ -1,56 +0,0 @@ - - -:p f : Float --o Float - f x = x - transposeLinear f 2.0 -> 2.0 - -:p f : Float --o Float - f x = y = x; y - transposeLinear f 2.0 -> 2.0 - -:p f : Float --o Float - f x = x + x - transposeLinear f 2.0 -> 4.0 - -:p f : Float --o Float - f x = y = 2.0 * x - 3.0 * y + x - transposeLinear f 1.0 -> 7.0 - -:p f : Float --o Float - f x = (2.0 + 3.0) * x - transposeLinear f 1.0 -> 5.0 - -:p f : (Float, Float) --o Float - f z = (x, y) = z - x + y * 2.0 - transposeLinear f 1.0 -> (1.0, 2.0) - -:p f : Float --o (Float, Float) - f x = (x, x * 2.0) - transposeLinear f (1.0, 3.0) -> 7.0 - -:p f x = x * x + 1.0 - jvp f 3.0 2.0 -> 12.0 - -:p f x = x * x + 1.0 - snd (vjp f 3.0) 2.0 -> 12.0 - -:p f : (Float, Float) -> Float - f (x,y) = x * y * 3.0 - jvp f (2.0, 5.0) (1.0, 100.0) -> 615.0 - -:p f : 3=>Float -> 3=>Float - f x = for i. x.i * x.i - jvp f [1.0, 1.5, 2.5] [3.0, 4.0, 1.0] -> [6.0, 12.0, 5.0] diff --git a/examples/aspirational.dx b/examples/aspirational.dx deleted file mode 100644 index ed3c1b2b2..000000000 --- a/examples/aspirational.dx +++ /dev/null @@ -1,48 +0,0 @@ --- === logistic regression === - --- features needed --- * type aliases with type variables --- * index set sum types (and generalization of inL and inR) --- * while loop construct (for fixed-point iter version) --- loop : (a -> Either b a) -> (b, E n. n=>a) --- * unpacking multiple type variables - -type ParamsIdx d = Either d () -- concrete syntax for sum types? -type Params d = (ParamsIdx d)=>Float - -logLogistic : Float -> Float -logLogistic x = log $ 1 / (1 + exp (-x)) - -bool2pm1 : Bool -> Float -bool2pm1 x = select x 1.0 -1.0 - -evalLogreg : Params d -> d=>Float -> Float -evalLogreg params x = - let w.i = x.(L i) -- can we improve this unpacking? - b = x.(R ()) - in logLogistic $ b + vdot w x - -logRegLoss : Params d -> d=>Float -> Bool -> Float -logRegLoss params x y = (evalLogReg params x) * (bool2pm1 y) - --- what about looping until convergence? Need a different looping construct -optimize : (n=>Float -> Float) -> (n=>Float) -optimize f = .. - let lr = 0.1 - scale = 0.1 - nIters = 1000 - x0.i = scale * randn (fanout 0).i - in loopN nIters x0 lam x. let dx = grad f x - in for i. x.i + lr * dx.ix' - -theData : E n d. (n=>d=>Float, n=>Bool) - -(xs, ys), N, D = unpack data - -loss : D=>R -> R -loss params = mean (for i. logRegLoss params xs.i ys.i) - --- cross-validation? minibatches? -optParams : Params D -optParams = optimize 0 logRegLoss - diff --git a/examples/bad-binary-file.dxbo b/examples/bad-binary-file.dxbo deleted file mode 100644 index 0fb628742..000000000 --- a/examples/bad-binary-file.dxbo +++ /dev/null @@ -1,3 +0,0 @@ --- dex-object-file-v0.0.1 num-header-bytes 128 --------------------------------- -type: (2=>1=>Real, Int) -bufferSizes: [8, 8, 8] diff --git a/examples/brownian_motion.dx b/examples/brownian_motion.dx index f97d0f759..9f9456291 100644 --- a/examples/brownian_motion.dx +++ b/examples/brownian_motion.dx @@ -1,22 +1,22 @@ +include "plot.dx" UnitInterval = Float -bmIter : (Key & Float & Float & UnitInterval) -> (Key & Float & Float & UnitInterval) = - \(key, y, sigma, t). - (kDraw, kL, kR) = splitKey3 key - t' = abs (t - 0.5) - y' = sigma * randn kDraw * (0.5 - t') - key' = select (t > 0.5) kL kR - (key', y + y', sigma / sqrt 2.0, t' * 2.0) +def bmIter ((key, y, sigma, t):(Key & Float & Float & UnitInterval)) : + (Key & Float & Float & UnitInterval) = + [kDraw, kL, kR] = splitKey key + t' = abs (t - 0.5) + y' = sigma * randn kDraw * (0.5 - t') + key' = select (t > 0.5) kL kR + (key', y + y', sigma / sqrt 2.0, t' * 2.0) -sampleBM : Key -> UnitInterval -> Float = - \key t. - (_, y, _, _) = fold (key, 0.0, 1.0, t) \i:(Fin 10). bmIter - y +def sampleBM (key:Key) (t:UnitInterval) : Float = + (_, y, _, _) = fold (key, 0.0, 1.0, t) \i:(Fin 10). bmIter + y xs = linspace (Fin 1000) 0.0 1.0 ys = map (sampleBM (newKey 0)) xs --- :plot zip xs ys --- > +:html showPlot $ xyPlot xs ys +> diff --git a/examples/bugs.dx b/examples/bugs.dx deleted file mode 100644 index b0bd31919..000000000 --- a/examples/bugs.dx +++ /dev/null @@ -1,35 +0,0 @@ --- we don't do let generalization on patterns, but this is a problem if --- generalization is required. This fails: -(f, g) = (lam x. x, lam x. x) - - --- printing of tuple-index tables not implemented -x = [1,2,3] -:p for (i,j). iadd x.i x.j - - --- out-of-bounds indexing - need to wrap indices -:p let litArr = [10, 5, 3] - in litArr.(asidx 4) -> 5 - --- polymorphic declarations without explicit types crash the compiler --- (should be a straightforward error message) -f x = x - --- apparently we're treating unbound type aliases as things to infer -x : N -x = 1 - --- need a type class constrain for index sets so that this is an error -:t for i:Int. 1 - --- Bad error message because we lose provenance of the constraint -:t lam x. - z = iadd x 1 - y = fadd x 1.0 - (z, y) -> Type error: -> Expected: Int -> Actual: Float -> In: From subst diff --git a/examples/chol.dx b/examples/chol.dx index 2ef573113..63473ba91 100644 --- a/examples/chol.dx +++ b/examples/chol.dx @@ -1,77 +1,65 @@ ' # Cholesky Factorization https://en.wikipedia.org/wiki/Cholesky_decomposition -' ### Matrix Math - -eye : n=>n=>Float -eye = for i j. select (i == j) 1.0 0.0 - -mmadd: (n=>m=>Float)->(n=>m=>Float)->(n=>m=>Float) -mmadd x y = for i j. x.i.j + y.i.j - ' ## Cholesky Algorithm -chol : (n=>n=>Float) -> (n=>n=>Float) -chol x = getState (for _ _. 0.0) \buf. - for i. - for j':(...i). +def chol [Eq n] (x:n=>n=>Float) : (n=>n=>Float) = + yieldState zero \buf. + for_ i. for j':(..i). j = %inject(j') - row = for k:(..n=>Float -> n=>Float -> n=>Float -trisolveL mat b = getState (for _. 0.0) \buf. - for i. - row = for j:(..n=>Float) (b:n=>Float) : n=>Float = + yieldState zero \buf. for i. + row = for j:(..n=>Float -> n=>Float -> n=>Float -trisolveU mat b = getState (for _. 0.0) \buf. - rof i. - row = for j:(i...). mat.i.%inject(j) - xPrev = for j:(i...). getAt buf %inject(j) - putAt buf i $ (b.i - vdot row xPrev) / mat.i.i +def trisolveU (mat:n=>n=>Float) (b:n=>Float) : n=>Float = + yieldState zero \buf. rof i. + row = for j:(i..). mat.i.%inject(j) + xPrev = for j:(i..). get (buf!%inject j) + buf!i := (b.i - vdot row xPrev) / mat.i.i -psdsolve : n=>n=>Float -> n=>Float -> n=>Float -psdsolve mat b = +def psdsolve [Eq n] (mat:n=>n=>Float) (b:n=>Float) : n=>Float = l = chol mat trisolveU (transpose l) $ trisolveL l b ' Test -type N = 4 -(k1, k2) = splitKey $ newKey 0 +N = Fin 4 +[k1, k2] = splitKey $ newKey 0 -psd : N=>N=>Float -psd = +psd : N=>N=>Float = a = for i:N j:N. randn $ ixkey2 k1 i j - x = mmp a (transpose a) - mmadd x eye + x = a ** transpose a + x + eye -l : N=>N=>Float -l = chol psd +l : N=>N=>Float = chol psd :p l -> [ [2.021765, 0.0, 0.0, 0.0] -> , [-1.7950183, 1.9901744, 0.0, 0.0] -> , [-0.89788574, 0.18675673, 1.9802661, 0.0] -> , [1.4457518, -0.29644823, 0.72458607, 2.2308075] ] +> [ [2.021765, 0., 0., 0.] +> , [-1.795019, 1.990174, 0., 0.] +> , [-0.897886, 0.186757, 1.980266, 0.] +> , [1.445752, -0.296448, 0.724586, 2.230807] ] -psdReconstructed = l `mmp` transpose l +psdReconstructed = l ** transpose l :p sum for (i, j). sq (psd.i.j - psdReconstructed.i.j) -> 2.4651903e-32 +> 0. -vec : N=>Float -vec = for i. randn $ ixkey k2 i +vec : N=>Float = arb k2 -:p (vec, mvp psd (psdsolve psd vec)) -> ( [1.2112769, 0.23284969, -0.74191034, 0.8833507] -> , [1.2112769, 0.23284969, -0.74191034, 0.8833507] ) +:p (vec, psd **. psdsolve psd vec) +> ( [1.211277, 0.23285, -0.741911, 0.883351] +> , [1.211277, 0.23285, -0.741911, 0.883351] ) diff --git a/examples/ctc.dx b/examples/ctc.dx index 6b63fa852..aa7b5fc77 100644 --- a/examples/ctc.dx +++ b/examples/ctc.dx @@ -48,8 +48,11 @@ def logaddexp (x:Float) (y:Float) : Float = m = max x y m + ( log ( (exp (x - m) + exp (y - m)))) -def ctc (dict: Eq vocab) ?=> (dict2: Eq position) ?=> (dict3: Eq time) ?=> (blank: vocab) - (logits: time=>vocab=>Float) (labels: position=>vocab) : Float = +def ctc [Eq vocab, Eq position, Eq time] + (blank: vocab) + (logits: time=>vocab=>Float) + (labels: position=>vocab) + : Float = -- Computes log p(labels | logits), marginalizing over possible alignments. -- Todo: remove unnecessary implicit type annotations once -- Dex starts putting implicit types in scope. @@ -103,22 +106,22 @@ def randIdxNoZero (n:Type) -> (k:Key) : n = unif = rand k fromOrdinal n $ (1 + (FToI (floor ( unif * IToF ((size n) - 1))))) -vocab = Fin 6 +Vocab = Fin 6 position = Fin 3 -blank = 0@vocab +blank = 0@Vocab -- Create random logits -time = Fin 4 -logits = for i:time j:vocab. (randn $ ixkey2 (newKey 0) i j) +Time = Fin 4 +logits : Time => Vocab => Float = arb $ newKey 0 -- Create random labels -labels = for i:position. randIdxNoZero vocab (newKey (ordinal i)) +labels = for i:position. randIdxNoZero Vocab (newKey (ordinal i)) :p labels > [(1@Fin 6), (5@Fin 6), (5@Fin 6)] -- Evaluate marginal probability of labels given logits :p exp $ ctc blank logits labels -> 1.0398488e-3 +> 0.00104 @@ -130,14 +133,14 @@ labels = for i:position. randIdxNoZero vocab (newKey (ordinal i)) -- e.g. the summed-over labels should include blanks. -:p sum for i:vocab. +:p sum for i:Vocab. exp $ ctc blank logits [i] -> 0.14146839 +> 0.141468 -:p sum for (i, j):(vocab & vocab). +:p sum for (i, j):(Vocab & Vocab). exp $ ctc blank logits [i, j] -> 0.7091234 +> 0.709123 -:p sum for (i, j, k):(vocab & vocab & vocab). +:p sum for (i, j, k):(Vocab & Vocab & Vocab). exp $ ctc blank logits [i, j, k] -> 0.9251011 +> 0.925101 diff --git a/examples/dxbo-example.dxbo b/examples/dxbo-example.dxbo deleted file mode 100644 index 656783b3e..000000000 Binary files a/examples/dxbo-example.dxbo and /dev/null differ diff --git a/examples/flop-tests.dx b/examples/flop-tests.dx deleted file mode 100644 index e500ddc75..000000000 --- a/examples/flop-tests.dx +++ /dev/null @@ -1,33 +0,0 @@ -matmul : i=>j=>Float -> j=>k=>Float -> i=>k=>Float -matmul x y = for i k. sum (for j. x.i.j * y.j.k) - -_, N = unpack range 10 -_, M = unpack range 10 - -k = newKey 0 - -mat : N=>N=>Float -mat = for i j. rand (ixkey (ixkey k i) j) - -:flops matmul mat mat -> %fadd 1 N^3 -> %fmul 1 N^3 -> copy 1 N^3 - --- This should be O(N) but we're instantiating and adding zeros -:flops transposeLinear (llam xs. for i. xs.i) (for i:N. 0.0) -> %%int_to_index_set 1 N^1 -> %eq 1 N^2 -> %fadd 1 N^2 -> %isub 2 N^1 -> %select 1 N^2 -> copy 1 + 1 N^1 - --- This should be O(NM) but we're instantiating and adding zeros -:flops transposeLinear (llam m. for i j. m.j.i) (for i:N j:M. 0.0) -> %%int_to_index_set 1 M^1 N^1 + 1 N^1 -> %eq 1 M^1 N^2 + 1 M^2 N^1 -> %fadd 1 M^1 N^2 + 1 M^2 N^2 -> %isub 2 M^1 N^1 + 2 N^1 -> %select 1 M^1 N^2 + 1 M^2 N^2 -> copy 1 + 1 M^1 N^1 + 2 N^1 diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index f2c6fb8b7..e1ea7e9ac 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -2,57 +2,48 @@ Fluid simulation code based on [Real-Time Fluid Dynamics for Games](https://www.josstam.com/publications) by Jos Stam -include "examples/plot.dx" - -def zeroedges (dict:VSpace a) ?=> (n:Type) ?-> (m:Type) ?-> (x: n=>m=>a) : n=>m=>a = - -- Todo: update in place without starting with a copy. - snd $ withState x \buf. - for i j. - edge = i==0 || j==0 || i==ordinal n || i==ordinal m - select edge (buf!(i@n)!(j@m) := zero) () +include "plot.dx" def wrapidx (n:Type) -> (i:Int) : n = -- Index wrapping around at ends. asidx $ mod i $ size n -def incwrap (n:Type) ?-> (i:n) : n = - -- Increments index, wrapping around at ends. +def incwrap (i:n) : n = -- Increment index, wrapping around at ends. asidx $ mod ((ordinal i) + 1) $ size n -def decwrap (n:Type) ?-> (i:n) : n = - -- Decrements index, wrapping around at ends. +def decwrap (i:n) : n = -- Decrement index, wrapping around at ends. asidx $ mod ((ordinal i) - 1) $ size n -def finite_difference_neighbours (n:Type) ?-> (x:n=>Float) : n=>Float = +def finite_difference_neighbours [Add a] (x:n=>a) : n=>a = for i. x.(incwrap i) - x.(decwrap i) -def add_neighbours (n:Type) ?-> (x:n=>Float) : n=>Float = +def add_neighbours [Add a] (x:n=>a) : n=>a = for i. x.(incwrap i) + x.(decwrap i) -def apply_along_axis1 (f : b=>Float -> b=>Float) (x : b=>c=>Float) : b=>c=>Float = +def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a = transpose for j. f for i. x.i.j -def apply_along_axis2 (f : c=>Float -> c=>Float) (x : b=>c=>Float) : b=>c=>Float = +def apply_along_axis2 (f:c=>a -> c=>a) (x:b=>c=>a) : b=>c=>a = for i. f x.i -def fdx (x : n=>m=>Float) : (n=>m=>Float) = +def fdx [Add a] (x:n=>m=>a) : (n=>m=>a) = apply_along_axis1 finite_difference_neighbours x -def fdy (x : n=>m=>Float) : (n=>m=>Float) = +def fdy [Add a] (x:n=>m=>a) : (n=>m=>a) = apply_along_axis2 finite_difference_neighbours x -def divergence (vx : n=>m=>Float) (vy : n=>m=>Float) : (n=>m=>Float) = +def divergence [Add a] (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) = fdx vx + fdy vy -def add_neighbours_2d (x : n=>m=>Float) : (n=>m=>Float) = +def add_neighbours_2d [Add a] (x:n=>m=>a) : (n=>m=>a) = ax1 = apply_along_axis1 add_neighbours x ax2 = apply_along_axis2 add_neighbours x ax1 + ax2 -def project (v: n=>m=>(Fin 2)=>Float) : n=>m=>(Fin 2)=>Float = +def project [VSpace a] (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = -- Project the velocity field to be approximately mass-conserving, -- using a few iterations of Gauss-Seidel. - h = 0.01 -- todo: work out units + h = 1.0 / IToF (size n) -- unpack into two scalar fields vx = for i j. v.i.j.(fromOrdinal _ 0) @@ -60,43 +51,27 @@ def project (v: n=>m=>(Fin 2)=>Float) : n=>m=>(Fin 2)=>Float = div = -0.5 .* h .* (divergence vx vy) - p_init = for i. for j. 0.0 - p = snd $ withState p_init \state. + p = yieldState zero \state. for i:(Fin 10). - p = get state - state := (1.0 / 4.0) .* (div + add_neighbours_2d p) + state := (1.0 / 4.0) .* (div + add_neighbours_2d (get state)) vx = vx - (0.5 / h) .* fdx(p) vy = vy - (0.5 / h) .* fdy(p) - for i j. [vx.i.j, vy.i.j] -- pack back into a vector field - - -- zeroedges v -- BUG: Crashes with "Not implemented Int" + for i j. [vx.i.j, vy.i.j] -- pack back into a table. -def bilinear_interp (dict:VSpace a) ?=> (right_weight:Float) --o (bottom_weight:Float) --o - (topleft: a) --o (bottomleft: a) --o (topright: a) --o (bottomright: a) --o : a = +def bilinear_interp [VSpace a] (right_weight:Float) (bottom_weight:Float) + (topleft: a) (bottomleft: a) (topright: a) (bottomright: a) : a = left = (1.0 - right_weight) .* ((1.0 - bottom_weight) .* topleft + bottom_weight .* bottomleft) right = right_weight .* ((1.0 - bottom_weight) .* topright + bottom_weight .* bottomright) left + right - -N = Fin 100 -M = Fin 100 - --- BUG: Changing the order of implicit arguments causes an error further down. --- i.e. it doesn't work to start the next line with --- (n:Type) ?-> (m:Type) ?-> (dict:VSpace a) ?=> -def advect (dict:VSpace a) ?=> (n:Type) ?-> (m:Type) ?-> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = +def advect [VSpace a] (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = -- Move field f according to x and y velocities (u and v) -- using an implicit Euler integrator. - -- Create table of cell locations. - -- BUG: using n and m below causes a crash, so I hardcoded it for now. - numrows = 100.0 -- IToF $ ordinal n - numcols = 100.0 -- IToF $ ordinal m - - cell_xs = linspace n 0.0 numrows - cell_ys = linspace m 0.0 numcols + cell_xs = linspace n 0.0 $ IToF (size n) + cell_ys = linspace m 0.0 $ IToF (size m) for i j. -- Location of source of flow for this cell. No meshgrid! @@ -108,46 +83,64 @@ def advect (dict:VSpace a) ?=> (n:Type) ?-> (m:Type) ?-> (f: n=>m=>a) (v: n=>m=> source_row = floor center_ys -- Relative weight of right-hand and bottom cells. - -- TODO: clipping shouldn't be necessary here, find out why it is. - right_weight = clip (0.0, 1.0) $ center_xs - source_col - bottom_weight = clip (0.0, 1.0) $ center_ys - source_row + right_weight = center_xs - source_col + bottom_weight = center_ys - source_row -- Cast back to indices, wrapping around edges. - source_col_int = FToI source_col - source_row_int = FToI source_row - l = wrapidx n source_col_int - r = wrapidx n (source_col_int + 1) - t = wrapidx m source_row_int - b = wrapidx m (source_row_int + 1) + l = wrapidx n (FToI source_col) + r = wrapidx n ((FToI source_col) + 1) + t = wrapidx m (FToI source_row) + b = wrapidx m ((FToI source_row) + 1) -- A convex weighting of the 4 surrounding cells. bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b -def fluidsim (dict: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) - (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = - (color_final, v) = snd $ withState (color_init, v) \state. +def fluidsim [ VSpace a] (num_steps: Int) (color_init: n=>m=>a) + (v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a = + withState (color_init, v) \state. for i:(Fin num_steps). (color, v) = get state v = advect v v -- Move velocities v = project v -- Project to be volume-preserving - color = advect color v -- Move color - state := (color, v) - color_final + color' = advect color v -- Move color + state := (color', v) + color '### Demo +N = Fin 50 +M = Fin 50 + -- Create random velocity field. def ixkey3 (k:Key) (i:n) (j:m) (k2:o) : Key = hash (hash (hash k (ordinal i)) (ordinal j)) (ordinal k2) -v = for i:N j:M k:(Fin 2). 3.0 * (randn $ ixkey3 (newKey 0) i j k) +init_velocity = for i:N j:M k:(Fin 2). + 3.0 * (randn $ ixkey3 (newKey 0) i j k) -- Create diagonally-striped color pattern. init_color = for i:N j:M. - BToF $ (sin $ (IToF $ (ordinal j) + (ordinal i)) / 8.0) > 0.0 + r = BToF $ (sin $ (IToF $ (ordinal j) + (ordinal i)) / 8.0) > 0.0 + b = BToF $ (sin $ (IToF $ (ordinal j) - (ordinal i)) / 6.0) > 0.0 + g = BToF $ (sin $ (IToF $ (ordinal j) + (ordinal i)) / 4.0) > 0.0 + [r, g, b] -- Run fluid sim and plot it. -num_steps = 50 -final_color = fluidsim num_steps init_color v +num_steps = 5 +:html imseqshow $ fluidsim num_steps init_color init_velocity +> + +'### Gradient test + +target = transpose init_color + +-- This is partial +def last (xs:n=>a) : a = xs.((size n - 1)@_) + +def objective (v:N=>M=>(Fin 2)=>Float) : Float = + final_color = last $ fluidsim num_steps init_color v + sum for (i, j, c). sq (final_color.i.j.c - target.i.j.c) + +init_vel_grad = grad objective zero -:html matshow final_color +:html imshow for i j. [0.0, init_vel_grad.i.j.(0@_), init_vel_grad.i.j.(1@_)] > diff --git a/examples/include-test.dx b/examples/include-test.dx deleted file mode 100644 index 87a74a4bd..000000000 --- a/examples/include-test.dx +++ /dev/null @@ -1,109 +0,0 @@ - -include "examples/included.dx" -> 30 -> 40 - -:p x -> 10 - -load dxo "examples/somedata.dxo" as dat - -:t dat -> (Float, 2, (2=>(3=>Float)), (2=>(Int, Bool))) - -:p dat -> (1.0, 1@2, [[2.0, 1.0, 3.0], [0.0, -10.0, 20.0]], [(1, True), (2, False)]) - -dump dxbo "test-scratch/bin-data-dump.dxbo" dat - -load dxbo "test-scratch/bin-data-dump.dxbo" as dat2 - -:t dat2 -> (Float, 2, (2=>(3=>Float)), (2=>(Int, Bool))) - -:p dat2 -> (1.0, 1@2, [[2.0, 1.0, 3.0], [0.0, -10.0, 20.0]], [(1, True), (2, False)]) - -load dxbo "not-a-file" as notData -> IO error: not-a-file: openFile: does not exist (No such file or directory) - -load dxbo "examples/bad-binary-file.dxbo" as badData -> IO error: unexpected number of buffers: [16,8] vs [8,8,8] -> Validation error -> Claimed header length: 128 -> Claimed total length: 152 -> Actual file length: 128 -> Header data: -> type: ((2=>(1=>Float)), Int) -> bufferSizes: [8,8,8] - -load dxbo "test-scratch/pydata.dxbo" as pydata - -:t pydata -> ( Float -> , Int -> , () -> , Bool -> , Bool -> , (Int, (3=>Float)) -> , (2=>(3=>Float)) -> , (3=>(2=>Float)) -> , Float -> , Float -> , (1=>(1=>(1=>Int))) -> , (4=>Int) -> , (3=>Bool) ) - -:p pydata -> ( 1.2 -> , 12 -> , () -> , True -> , False -> , (-2, [1.0, 2.0, 3.0]) -> , [[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]] -> , [[10.0, 0.1], [20.0, 0.2], [30.0, 0.3]] -> , 1.3 -> , 0.123 -> , [[[1]]] -> , [6, 5, 4, 3] -> , [True, False, True] ) - -dump dxbo "/tmp/stuff.dxbo" pydata - -load dxbo "/tmp/stuff.dxbo" as xs - -:t xs -> ( Float -> , Int -> , () -> , Bool -> , Bool -> , (Int, (3=>Float)) -> , (2=>(3=>Float)) -> , (3=>(2=>Float)) -> , Float -> , Float -> , (1=>(1=>(1=>Int))) -> , (4=>Int) -> , (3=>Bool) ) - -:p xs -> ( 1.2 -> , 12 -> , () -> , True -> , False -> , (-2, [1.0, 2.0, 3.0]) -> , [[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]] -> , [[10.0, 0.1], [20.0, 0.2], [30.0, 0.3]] -> , 1.3 -> , 0.123 -> , [[[1]]] -> , [6, 5, 4, 3] -> , [True, False, True] ) - -load dxbo "examples/dxbo-example.dxbo" as exampleData - -:p exampleData -> (1, 2, [(3, [4, 5]), (6, [7, 8]), (9, [10, 11])]) diff --git a/examples/included.dx b/examples/included.dx deleted file mode 100644 index 14dd0687d..000000000 --- a/examples/included.dx +++ /dev/null @@ -1,8 +0,0 @@ - -x = 10 - -y = 20 - -:p 30 - -:p 40 diff --git a/examples/interp-tests.dx b/examples/interp-tests.dx deleted file mode 100644 index 4b1b3503c..000000000 --- a/examples/interp-tests.dx +++ /dev/null @@ -1,18 +0,0 @@ --- language features implemented in interpreter but not yet in the compiler - -:p 1 -> 1 - -_, M = unpack range 4 - -xs : M => (E n. n=>Int) -xs = for i. _, N = unpack range (asint i) - x = for j:N. asint j - pack x, N, E n. n=>Int - -for i. x, N2 = unpack xs.i -- TODO: underscore type binder - sum x -> [0, 0, 1, 3] - -:p filter (lam x. x > 2.0) [4.0, 0.0, 10.0, 2.0] -> pack [4.0, 10.0], 2, (Ea.(a=>Float)) diff --git a/examples/isomorphisms.dx b/examples/isomorphisms.dx index b127c16a2..9668eac42 100644 --- a/examples/isomorphisms.dx +++ b/examples/isomorphisms.dx @@ -16,21 +16,21 @@ cycleThree : Iso (a & b & c) (b & c & a) = and flipped with `flipIso` :p appIso cycleThree (1, 2.0, 3) -> (2.0, (3, 1)) +> (2., (3, 1)) :p revIso cycleThree (1, 2.0, 3) -> (3, (1, 2.0)) +> (3, (1, 2.)) :p appIso (flipIso cycleThree) (1, 2.0, 3) -> (3, (1, 2.0)) +> (3, (1, 2.)) 'They can also be composed with `&>>`: :p appIso (cycleThree &>> cycleThree) (1, 2.0, 3) -> (3, (1, 2.0)) +> (3, (1, 2.)) :p appIso (cycleThree &>> cycleThree &>> cycleThree) (1, 2.0, 3) -> (1, (2.0, 3)) +> (1, (2., 3)) 'Note that we assume but do not check that the isomorphism is lawful (i.e. `appIso iso $ revIso iso x == x` for all `x`, or equivalently @@ -88,13 +88,13 @@ Record accessor isomorphisms can be passed into the helper function `getAt`: 'We can also do other types of things: :p popAt #foo {foo=1, bar=2.0} -> {bar = 2.0} +> {bar = 2.} :p pushAt #foo 3.0 {foo=1, bar=2.0} -> {bar = 2.0, foo = 3.0, foo = 1} +> {bar = 2., foo = 3., foo = 1} :p setAt #foo 2 {foo=1, bar=2.0} -> {bar = 2.0, foo = 2} +> {bar = 2., foo = 2} 'These helper functions work with any "lens-like" isomorphism. For instance, we can select everything except for a particular field: @@ -103,7 +103,7 @@ we can select everything except for a particular field: > ((a:Type) ?-> (b:Type) ?-> (c:Type) ?-> (Iso a (b & c)) -> Iso a (c & b)) :p getAt (exceptLens #foo) {foo=1, bar=2.0, baz=3} -> {bar = 2.0, baz = 3} +> {bar = 2., baz = 3} '## Variant accessors and prism-like helpers @@ -127,7 +127,7 @@ Similarly, there are prism-like helpers > {| foo = 3 |} :p matchWith (exceptPrism #?foo) $ {|bar = 1.0|}:{foo:Int | bar:Float} -> (Just {| bar = 1.0 |}) +> (Just {| bar = 1. |}) '## Record zipper isomorphisms The isomorphisms shown above are specialized for removing a single field from @@ -200,7 +200,7 @@ ordinary record accessor lens. '`splitR` can be used if you want to process multiple fields at once: :p pushAt (splitR &>> #&a &>> #&b) {a=1, b=2.0} {c=3, d=4.0} -> {a = 1, b = 2.0, c = 3, d = 4.0} +> {a = 1, b = 2., c = 3, d = 4.} '## Variant zipper isomorphisms Just as there are record zipper isomorphisms, there are also variant diff --git a/examples/jax-tests.dx b/examples/jax-tests.dx deleted file mode 100644 index 5b13f6f3a..000000000 --- a/examples/jax-tests.dx +++ /dev/null @@ -1,28 +0,0 @@ - -x = 1.0 + 2.0 - -:p x + 3.0 -> 6.0 - -:p - getAccumulator \ref. - ref += 1.0 - ref += 2.0 -> 3.0 - -:p for i:3. x + 1.0 -> [4.0, 4.0, 4.0] - -xs = for i:4. 2.0 - -:p sum for i. xs.i * xs.i -> 16.0 - -:p float 1 -> 1.0 - -:p for i:3. neg (float 2) -> [-2.0, -2.0, -2.0] - -:p 0.0 + (1.0 + (2.0 + 0.0)) -> 3.0 diff --git a/examples/latex.dx b/examples/latex.dx new file mode 100644 index 000000000..fcb0dc07b --- /dev/null +++ b/examples/latex.dx @@ -0,0 +1,41 @@ +'# $\href{https://katex.org/}{\KaTeX}$ rendering examples + +'This document demonstrates $\KaTeX$ rendering in literate Dex programs. + +'## Random examples + +'$$\text{This is a multiline equation:} \\\\ \textbf{A}\textbf{x} = \textbf{b}$$ + +'$$f(\relax{x}) = \int_{-\infty}^\infty \hat{f}(\xi)\,e^{2 \pi i \xi x} \,d\xi$$ + +'## [Environments](https://katex.org/docs/supported.html#environments) + +'$$\begin{matrix} a & b \\\\ c & d \end{matrix}$$ + +'$$\begin{pmatrix} a & b \\\\ c & d \end{pmatrix}$$ + +'$$\begin{bmatrix} a & b \\\\ c & d \end{bmatrix}$$ + +'$$\def\arraystretch{1.5} \begin{array}{c:c:c} a & b & c \\\\ \hline d & e & f \\\\ \hdashline g & h & i \end{array}$$ + +'$$\begin{aligned} a&=b+c \\\\ d+e&=f \end{aligned}$$ + +'$$\begin{alignedat}{2} 10&x+ &3&y = 2 \\\\ 3&x+&13&y = 4 \end{alignedat}$$ + +'$$\begin{gathered} a=b \\\\ e=b+c \end{gathered}$$ + +'$$x = \begin{cases} a &\text{if } b \\\\ c &\text{if } d \end{cases}$$ + +'$$\begin{rcases} a &\text{if } b \\\\ c &\text{if } d \end{rcases} \Rightarrow \dots$$ + +'## [Layout annotation](https://katex.org/docs/supported.html#annotation) + +'$$\overbrace{a+b+c}^{\text{note}}$$ + +'$$\underbrace{a+b+c}_{\text{note}}$$ + +'$$\xcancel{\text{second-order array combinators}}$$ + +'## [Logic and Set Theory](https://katex.org/docs/supported.html#logic-and-set-theory) + +'$$\begin{aligned} \forall \\; & \texttt{\textbackslash forall} & \complement \\; & \texttt{\textbackslash complement} & \therefore \\; & \texttt{\textbackslash therefore} & \emptyset \\; & \texttt{\textbackslash emptyset} \\\\ \exists \\; & \texttt{\textbackslash exists} & \subset \\; & \texttt{\textbackslash subset} & \because \\; & \texttt{\textbackslash because} & \empty \\; & \texttt{\textbackslash empty} \\\\ \exist \\; & \texttt{\textbackslash exist} & \supset \\; & \texttt{\textbackslash supset} & \mapsto \\; & \texttt{\textbackslash mapsto} & \varnothing \\; & \texttt{\textbackslash varnothing} \\\\ \nexists \\; & \texttt{\textbackslash nexists} & \mid \\; & \texttt{\textbackslash mid} & \to \\; & \texttt{\textbackslash to} & \implies \\; & \texttt{\textbackslash implies} \\\\ \in \\; & \texttt{\textbackslash in} & \land \\; & \texttt{\textbackslash land} & \gets \\; & \texttt{\textbackslash gets} & \impliedby \\; & \texttt{\textbackslash impliedby} \\\\ \isin \\; & \texttt{\textbackslash isin} & \lor \\; & \texttt{\textbackslash lor} & \leftrightarrow \\; & \texttt{\textbackslash leftrightarrow} & \iff \\; & \texttt{\textbackslash iff} \\\\ \notin \\; & \texttt{\textbackslash notin} & \ni \\; & \texttt{\textbackslash ni} & \notni \\; & \texttt{\textbackslash notni} & \neg \\; & \texttt{\textbackslash neg} \\\\ \lnot \\; & \texttt{\textbackslash lnot} \\\\ \end{aligned}$$ diff --git a/examples/linear_algebra.dx b/examples/linear_algebra.dx index 7dce43f3c..2d0cffd14 100644 --- a/examples/linear_algebra.dx +++ b/examples/linear_algebra.dx @@ -1,78 +1,88 @@ '## LU Decomposition and Matrix Inversion -def identity_matrix (_:Eq n) ?=> (_:Add a) ?=> (_:Mul a) ?=> : n=>n=>a = +def identity_matrix [Eq n, Add a, Mul a] : n=>n=>a = for i j. select (i == j) one zero - '### Triangular matrices -def LowerTriMat (n:Type) : Type = i:n=>(..i)=>Float -def UpperTriMat (n:Type) : Type = i:n=>(i..)=>Float +def LowerTriMat (n:Type) (v:Type) : Type = i:n=>(..i)=>v +def UpperTriMat (n:Type) (v:Type) : Type = i:n=>(i..)=>v + +def upperTriDiag (u:UpperTriMat n v) : n=>v = for i. u.i.(0@_) +def lowerTriDiag (l:LowerTriMat n v) : n=>v = for i. l.i.((ordinal i)@_) -def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n) (b:n=>v) : n=>v = +def forward_substitute [VSpace v] (a:LowerTriMat n Float) (b:n=>v) : n=>v = -- Solves lower triangular linear system (inverse a) **. b - snd $ withState zero \sRef. + yieldState zero \sRef. for i:n. s = sum for k:(.. (a:UpperTriMat n) (b:n=>v) : n=>v = +def backward_substitute [VSpace v] (a:UpperTriMat n Float) (b:n=>v) : n=>v = -- Solves upper triangular linear system (inverse a) **. b - snd $ withState zero \sRef. + yieldState zero \sRef. rof i:n. s = sum for k:(i..). -- dot product - a.i.((ordinal k)@_) .* (get sRef).(%inject k) + a.i.((ordinal k)@_) .* get sRef!(%inject k) sRef!i := (b.i - s) / a.i.(0@_) -- 0 is the diagonal index -- Todo: get rid of these by writing a dependent indexing (!) operator. -def lowerTriIndex (ref:Ref h (LowerTriMat n)) (i:n) : Ref h ((..i)=>Float) = +def lowerTriIndex (ref:Ref h (LowerTriMat n v)) (i:n) : Ref h ((..i)=>v) = %indexRef ref i -def upperTriIndex (ref:Ref h (UpperTriMat n)) (i:n) : Ref h ((i..)=>Float) = +def upperTriIndex (ref:Ref h (UpperTriMat n v)) (i:n) : Ref h ((i..)=>v) = %indexRef ref i '### Permutations -def Permutation (n:Type) : Type = n=>n -def apply_permutation (permutation: n=>n) (array: n=>t) : n=>t = - for i. array.(permutation.i) -def identity_permutation (n:Type) ?-> : Permutation n = - for i. i +-- The sign of the determinant of a permutation is either 1.0 or -1.0 +PermutationSign = Float +def Permutation (n:Type) : Type = (perm:n=>n & PermutationSign) + +def apply_permutation ((perm, _):Permutation n) (xs: n=>t) : n=>t = + for i. xs.(perm.i) + +def identity_permutation : Permutation n = + (for i. i, 1.0) + +def swapInPlace (pRef: Ref h (Permutation n)) (i:n) (j:n) : {State h} Unit = + (permRef, signRef) = (fstRef pRef, sndRef pRef) + tempj = get permRef!j + permRef!j := get permRef!i + permRef!i := tempj + signRef := -(get signRef) + +def perToTable ((perm, _):Permutation n) : n=>n = perm +def permSign ((_, sign):Permutation n) : PermutationSign = sign -'### LU decomposition functions -Sign = Float -- Either 1.0 or -1.0 -def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : (Permutation n & Sign) = - -- Permutes rows of a matrix to make Gaussian elimination more stable. - -- Returns permutation and the sign of its determinant. - snd $ withState (identity_permutation, 1.0) \stateRef. - (pRef, signRef) = (fstRef stateRef, sndRef stateRef) +'### LU decomposition functions + +def pivotize [Eq n] (a:n=>n=>Float) : Permutation n = + -- Gives a row permutation that makes Gaussian elimination more stable. + yieldState identity_permutation \permRef. for j:n. - row_with_largest = argmin for i:(j..). (-(abs a.(%inject i).j)) - row_with_largest = %inject row_with_largest + row_with_largest' = argmin for i:(j..). (-(abs a.(%inject i).j)) + row_with_largest = %inject row_with_largest' case (j == row_with_largest) of True -> () - False -> - tempj = get pRef!j -- Is there a refSwap? - pRef!j := get pRef!row_with_largest - pRef!row_with_largest := tempj - signRef := -(get signRef) - -def lu (_:Eq n) ?=> (a: n=>n=>Float) : - (LowerTriMat n & UpperTriMat n & Permutation n & Sign) = + False -> swapInPlace permRef j row_with_largest + +def lu [Eq n] (a: n=>n=>Float) : + (LowerTriMat n Float & UpperTriMat n Float & Permutation n) = -- Computes lower, upper, and permuntation matrices from a square matrix, -- such that apply_permutation permutation a == lower ** upper. - (permutation, swapcount) = pivotize a + permutation = pivotize a a = apply_permutation permutation a init_lower = for i:n. for j':(..i). select (i == (%inject j')) 1.0 0.0 init_upper = for i:n. for j'':(i..). 0.0 - (lower, upper) = snd $ withState (init_lower, init_upper) \stateRef. + (lower, upper) = yieldState (init_lower, init_upper) \stateRef. lRef = fstRef stateRef uRef = sndRef stateRef @@ -103,10 +113,10 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ukj = get (upperTriIndex uRef k')!(((ordinal j) - (ordinal k))@_) lik = get (lowerTriIndex lRef i')!((ordinal k)@_) ukj * lik - + uijRef = (upperTriIndex uRef i')!(((ordinal j) - (ordinal i))@_) uijRef := a.(%inject i).j - s - + for i:(j<..). i' = %inject i s = sum for k:(..j). @@ -115,40 +125,38 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ukj = get (upperTriIndex uRef k')!i'' lik = get (lowerTriIndex lRef i')!((ordinal k)@_) ukj * lik - + i'' = ((ordinal i) + (ordinal j) + 1)@_ ujj = get (upperTriIndex uRef j)!(0@_) lijRef = (lowerTriIndex lRef i'')!((ordinal j)@_) lijRef := (a.i'.j - s) / ujj - (lower, upper, permutation, swapcount) + (lower, upper, permutation) '### General linear algebra functions. -def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v = +def solve [Eq n, VSpace v] (a:n=>n=>Float) (b:n=>v) : n=>v = -- There's a small speedup possible by exploiting the fact -- that l always has ones on the diagonal. It would just require a -- custom forward_substitute routine that doesn't divide -- by the diagonal entries. - (l, u, perm, _) = lu a + (l, u, perm) = lu a b' = apply_permutation perm b y = forward_substitute l b' backward_substitute u y -def invert (_:Eq n) ?=> (a:n=>n=>Float) : n=>n=>Float = +def invert [Eq n] (a:n=>n=>Float) : n=>n=>Float = solve a identity_matrix -def determinant (_:Eq n) ?=> (a:n=>n=>Float) : Float = - (l, u, perm, permutation_sign) = lu a - -- formerly u.i.i * l.i.i - (prod for i. u.i.(0@_) * l.i.((ordinal i)@_)) * permutation_sign - -def sign_and_log_determinant (_:Eq n) ?=> (a:n=>n=>Float) : (Float & Float) = - (l, u, perm, permutation_sign) = lu a - -- formerly u.i.i * l.i.i - diags = for i. u.i.(0@_) * l.i.((ordinal i)@_) - sign = permutation_sign * prod for i. sign diags.i - sum_of_log_abs = sum for i. log (abs diags.i) +def determinant [Eq n] (a:n=>n=>Float) : Float = + (l, u, perm) = lu a + prod (for i. (upperTriDiag u).i * (lowerTriDiag l).i) * permSign perm + +def sign_and_log_determinant [Eq n] (a:n=>n=>Float) : (Float & Float) = + (l, u, perm) = lu a + diags = for i. (upperTriDiag u).i * (lowerTriDiag l).i + sign = (permSign perm) * prod for i. sign diags.i + sum_of_log_abs = sum for i. log (abs diags.i) (sign, sum_of_log_abs) diff --git a/examples/mandelbrot.dx b/examples/mandelbrot.dx index 700ede927..468ef17dd 100644 --- a/examples/mandelbrot.dx +++ b/examples/mandelbrot.dx @@ -1,20 +1,18 @@ '# Mandelbrot set -include "examples/plot.dx" +include "plot.dx" 'Escape time algorithm -update : Complex -> Complex -> Complex = - \c z. c + (z * z) +def update (c:Complex) (z:Complex) : Complex = c + (z * z) tol = 2.0 -inBounds : Complex -> Bool = - \z. complex_abs z < tol +def inBounds (z:Complex) : Bool = complex_abs z < tol -escapeTime : Complex -> Float = - \c. fst $ fold (0.0, zero) $ \i:(Fin 1000) (n, z). - z' = update c z - (n + BToF (inBounds z'), z') +def escapeTime (c:Complex) : Float = + fst $ fold (0.0, zero) $ \i:(Fin 1000) (n, z). + z' = update c z + (n + BToF (inBounds z'), z') 'Evaluate on a grid and plot the results diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 9bf7159ed..d205cdbd7 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -1,5 +1,8 @@ +'# Markov Chain Monte Carlo --- === General MCMC utilities === +'## General MCMC utilities + +include "plot.dx" LogProb : Type = Float @@ -9,8 +12,8 @@ def runChain (numSamples: Int) (k:Key) : Fin numSamples => a = - (k1, k2) = splitKey k - fst $ withState (initialize k1) \s. + [k1, k2] = splitKey k + withState (initialize k1) \s. for i:(Fin numSamples). x = step (ixkey k2 i) (get s) s := x @@ -22,18 +25,17 @@ def propose (proposal : a) (k : Key) : a = - acceptProb = exp (logDensity proposal) / exp (logDensity cur) - select (bern acceptProb k) proposal cur + accept = logDensity proposal > (logDensity cur + log (rand k)) + select accept proposal cur -def meanAndCovariance (n:Type) ?-> (d:Type) ?-> - (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = +def meanAndCovariance (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = xsMean : d=>Float = (for i. sum for j. xs.j.i) / IToF (size n) xsCov : d=>d=>Float = (for i i'. sum for j. (xs.j.i' - xsMean.i') * (xs.j.i - xsMean.i ) ) / IToF (size n - 1) (xsMean, xsCov) --- === Metropolis-Hastings implementation === +'## Metropolis-Hastings implementation MHParams : Type = Float -- step size @@ -43,16 +45,16 @@ def mhStep (k:Key) (x:d=>Float) : d=>Float = - (k1, k2) = splitKey k + [k1, k2] = splitKey k proposal = x + stepSize .* randnVec k1 propose logProb x proposal k2 --- === HMC implementation === +'## HMC implementation HMCParams : Type = (Int & Float) -- leapfrog steps, step size def leapfrogIntegrate - (_:VSpace a) ?=> + [VSpace a] ((nsteps, dt): HMCParams) (logProb: a -> LogProb) ((x, p): (a & a)) @@ -72,33 +74,40 @@ def hmcStep (x:d=>Float) : d=>Float = hamiltonian = \(x, p). logProb x - 0.5 * vdot p p - (k1, k2) = splitKey k + [k1, k2] = splitKey k p = randnVec k1 proposal = leapfrogIntegrate params logProb (x, p) fst $ propose hamiltonian (x, p) proposal k2 --- === test it out === +'## Test it out + +'Generate samples from a multivariate normal distribution N([1.5, 2.5], [[1., 0.], [0., 0.05]]). def myLogProb (x:(Fin 2)=>Float) : LogProb = x' = x - [1.5, 2.5] neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x' -hmcParams = (10, 0.1) -mhParams = 0.1 -numSamples = 500 +numSamples = if dex_test_mode () + then 1000 + else 10000 k0 = newKey 1 -hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 +mhParams = 0.1 mhSamples = runChain randnVec (mhStep mhParams myLogProb) numSamples k0 -:p meanAndCovariance hmcSamples -> ([1.4468338, 2.4944723], [[1.065676, 2.047594e-2], [2.047594e-2, 5.288498e-2]]) +:p meanAndCovariance mhSamples +> ([0.369159, 2.453517], [[0.575722, 0.08787], [0.08787, 0.125873]]) --- :plot for i. (IToF (ordinal i), hmcSamples.i.(0@_)) --- > +:html showPlot $ yPlot $ + slice (map head mhSamples) 0 (Fin 1000) +> -:p meanAndCovariance mhSamples -> ([0.64555484, 2.4140575], [[0.38236195, 0.17941256], [0.17941256, 0.22895703]]) +hmcParams = (10, 0.1) +hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 + +:p meanAndCovariance hmcSamples +> ([1.431633, 2.503093], [[0.964188, 0.005688], [0.005688, 0.049492]]) --- :plot for i. (IToF (ordinal i), mhSamples.i.(0@_)) --- > +:html showPlot $ yPlot $ + slice (map head hmcSamples) 0 (Fin 1000) +> diff --git a/examples/mnist-nearest-neighbors.dx b/examples/mnist-nearest-neighbors.dx index 161e6aa8b..4448420f5 100644 --- a/examples/mnist-nearest-neighbors.dx +++ b/examples/mnist-nearest-neighbors.dx @@ -1,3 +1,6 @@ +'# THIS FILE IS STALE + +'(But we plan to update it at some point) load dxbo "scratch/mnist.dxbo" as mnist diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index 0b87184e2..53e568d5a 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -4,13 +4,15 @@ This version is a port of the [Jax implementation](https://github.com/google/jax One difference is that it uses a lower-triangular matrix type for the Butcher tableau, and so avoids zero-padding everywhere. +include "plot.dx" + Time = Float -- Should this go in the prelude? def length (x: d=>Float) : Float = sqrt $ sum for i. sq x.i def (./) (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i -def fit_4th_order_polynomial (_:VSpace v) ?=> +def fit_4th_order_polynomial [VSpace v] (z0:v) (z1:v) (z_mid:v) (dz0:v) (dz1:v) (dt:Time) : (Fin 5)=>v = -- dz0 and dz1 are gradient evaluations. a = -2. * dt .* dz0 + 2. * dt .* dz1 - 8. .* z0 - 8. .* z1 + 16. .* z_mid @@ -24,7 +26,7 @@ dps_c_mid = [6025192743. /30085553152. /2., 0., 51252292925. /65400821598. /2., -2691868925. /45128329728. /2., 187940372067. /1594534317056. /2., -1776094331. /19743644256. /2., 11237099. /235043384. /2.] -def interp_fit_dopri (_:VSpace v) ?=> +def interp_fit_dopri [VSpace v] (z0:v) (z1:v) (k:(Fin 7)=>v) (dt:Time) : (Fin 5)=>v = -- Fit a polynomial to the results of a Runge-Kutta step. z_mid = z0 + dt .* (dot dps_c_mid k) @@ -62,13 +64,13 @@ c_error = [35. / 384. - 1951. / 21600., 0., 500. / 1113. - 22642. / 50085., 125. / 192. - 451. / 720., -2187. / 6784. + 12231. / 42400., 11. / 84. - 649. / 6300., -1. / 60.] -def runge_kutta_step (_:VSpace v) ?=> (func:v->Time->v) +def runge_kutta_step [VSpace v] (func:v->Time->v) (z0:v) (f0:v) (t0:Time) (dt:Time) : (v & v & v & (Fin 7)=>v) = - evals_init = snd $ withState zero \r. + evals_init = yieldState zero \r. r!(0@_) := f0 - evals_filled = snd $ withState evals_init \func_evals. for i:(Fin 6). + evals_filled = yieldState evals_init \func_evals. for i:(Fin 6). cur_evals = for j:(..i). get func_evals!((ordinal j)@_) ti = t0 + dt .* alpha.i zi = z0 + dt .* dot beta.i cur_evals @@ -110,15 +112,15 @@ def odeint (func: d=>Float -> Time -> d=>Float) atol = 1.4e-8 -- absolute local error tolerance for solver. max_iters = 10000 - integrate_to_next_time = \iter init_carry. - target_t = times.iter + integrate_to_next_time = \i init_carry. + target_t = times.i - stopping_condition = \(_, _, t, dt, _, _). + continue_condition = \(_, _, t, dt, _, _). -- State of solver: (next state, next f, next time, dt, t, interp coeffs) -- def State (v:Type) : Type = (v & v & Time & Time & Time & (Fin 5)=>v) -- This ended up being unnecessary to spell anywhere, but was -- useful for debugging. - (t < target_t) && (dt > 0.0) && (ordinal iter < max_iters) + (t < target_t) && (dt > 0.0) && (ordinal i < max_iters) possible_step = \(z, f, t, dt, last_t, interp_coeff). (next_z, next_f, next_z_error, k) = runge_kutta_step func z f t dt @@ -132,9 +134,14 @@ def odeint (func: d=>Float -> Time -> d=>Float) select (ratio <= 1.0) move_state stay_state -- Take steps until we pass target_t - new_state = snd $ withState init_carry \state. - while (\(). stopping_condition (get state)) \(). - state := possible_step (get state) + new_state = yieldState init_carry \stateRef. + iter \_. + state = get stateRef + if continue_condition state + then + stateRef := possible_step state + Continue + else Done () (_, _, t, _, last_t, interp_coeff) = new_state -- Interpolate to the target time. @@ -158,17 +165,15 @@ t1 = [1.0] approx_e = odeint myDyn z0 t0 t1 :p approx_e -> [[2.7201762]] +> [[2.720176]] exact_e = [[exp 1.0]] :p (approx_e - exact_e) -- amount of numerical error -> [[1.894474e-3]] +> [[0.001894]] times = linspace (Fin 100) 0.00001 1.0 ys = odeint myDyn z0 t0 times --- :plot --- ys' = for i. ys.i.(fromOrdinal _ 0) --- zip times ys' --- > +:html showPlot $ yPlot for i. ys.i.(fromOrdinal _ 0) +> diff --git a/examples/particle-filter.dx b/examples/particle-filter.dx new file mode 100644 index 000000000..10291f5e6 --- /dev/null +++ b/examples/particle-filter.dx @@ -0,0 +1,68 @@ +def Distribution (range:Type) : Type = + ( Key -> range + & range -> Float) -- log prob + +def Model (state:Type) (observation:Type) : Type = + ( Distribution state -- initial state + & state -> Distribution state -- dynamics + & state -> Distribution observation) -- observations + +def sample (d: Distribution a) (k: Key) : a = + (sampler, _) = d + sampler k + +def simulate (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) = + (init, dynamics, observe) = model + [key, subkey] = splitKey key + s0 = sample init subkey + withState s0 \s_ref . + for i. + [k1, k2] = splitKey (ixkey key i) + s = get s_ref + s_next = sample (dynamics s) k1 + v = sample (observe s) k2 + s_ref := s_next + (s, v) + +def filter + (num_particles: Int) (num_timesteps: Int) + (model: Model s v) + (summarize: (Fin num_particles => s) -> a) + (obs: Fin num_timesteps=>v) + (key: Key) + : Fin num_timesteps => a = + (init, dynamics, observe) = model + [key, init_key] = splitKey key + init_particles = for i: (Fin num_particles). sample init (ixkey init_key i) + withState init_particles \p_ref . + for t: (Fin num_timesteps). + p_prev = get p_ref + logLikelihoods = for i. snd (observe p_prev.i) obs.t + [resample_key, dynamics_key] = splitKey (ixkey key t) + resampled_idxs = categoricalBatch logLikelihoods resample_key + p_resampled = for i. p_prev.(resampled_idxs.i) + p_next = for i. fst (dynamics p_resampled.i) (ixkey dynamics_key i) + p_ref := p_next + summarize p_resampled + +def normalDistn (mean: Float) (var: Float) : Distribution Float = + ( \k. (randn k) * (sqrt var) + mean + , \v. -0.5 * (sq (v - mean)) / var - 0.5 * log (2.0 * pi * var) + ) + +gaussModel : Model Float Float = + ( normalDistn 0.1 0.1 + , \s. normalDistn s 1.0 + , \s. normalDistn s 1.0 + ) + +timesteps = 10 +num_particles = 10000 + +truth = for i:(Fin timesteps). + s = IToF (ordinal i) + (s, sample (normalDistn s 1.0) $ ixkey (newKey 0) i) + +filtered = filter num_particles _ gaussModel mean (map snd truth) (newKey 0) + +-- :p for i. (truth.i, filtered.i) diff --git a/examples/particle-swarm-optimizer.dx b/examples/particle-swarm-optimizer.dx index dc5126e5d..58227779e 100644 --- a/examples/particle-swarm-optimizer.dx +++ b/examples/particle-swarm-optimizer.dx @@ -16,16 +16,16 @@ rosenbrock2 : ((Fin 2)=>Float) -> Float = ' Min should be at 1.0, 1.0 :p rosenbrock 1.0 1.000 -> 0.0 +> 0. :p rosenbrock2 [1.0, 1.000] -> 0.0 +> 0. :p rosenbrock 1.0 1.02 -> 3.199994e-2 +> 0.032 :p rosenbrock2 [1.0, 1.02] -> 3.199994e-2 +> 0.032 ' ## Helper functions @@ -43,7 +43,7 @@ randBounded : Key -> (d=>Float)->(d=>Float)->(d=>Float) = for i. lb.i + ((rand $ ixkey key i) * (ub.i - lb.i)) :p randBounded (newKey 4) [1.0, -2.0] [-1.0, 2.0] -> [-0.35101044, 1.4935503] +> [-0.35101, 1.49355] ' ## The Optimizer itself. We have **arguments**: @@ -57,7 +57,6 @@ We have **arguments**: ' **Returns**: the optimal point found with-in the bounds on the input domain of `f`. def optimize - (d:Type) ?-> (np':Int) -- number of particles (niter:Int) -- number of iterations (key:Key) -- random seed @@ -72,7 +71,7 @@ def optimize minbyfst pbests.p (f newPositions.p, newPositions.p) newGbest:(Float & d=>Float) = minbyfst gbest (minimumbyfst newPbests) - (keyG, keyP, keyNext) = splitKey3 keyL + [keyG, keyP, keyNext] = splitKey keyL (gscore, gloc) = gbest plocs = map snd pbests gVel:(np=>d=>Float) = for p i. @@ -87,7 +86,7 @@ def optimize (keyNext,newGbest,newPbests,newPositions,newVelocities) randInit1 = \keyI1. randBounded keyI1 lb ub randInit = \keyI. for p:np. randInit1 $ ixkey keyI p - (keyPos, keyVel, keyLoop) = splitKey3 key + [keyPos, keyVel, keyLoop] = splitKey key initPositions:(np=>d=>Float) = randInit keyPos initVelocities:(np=>d=>Float) = randInit keyVel initPs:(np=>(Float & d=>Float)) = for p. (f initPositions.p, initPositions.p) @@ -103,13 +102,13 @@ Run it for more iterations and result should improve. Which it indeed does. :p optimize 50 10 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [3.7902741, 14.911411] +> [0.076986, 0.232818] :p optimize 50 20 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [1.737732, 3.1227117] +> [0.90125, 0.750447] :p optimize 50 100 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [1.0062296, 1.0128789] +> [0.999069, 0.998192] :p optimize 50 1000 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [1.0, 1.0] +> [1., 1.] diff --git a/examples/pi.dx b/examples/pi.dx index 46de500ac..ef8175b34 100644 --- a/examples/pi.dx +++ b/examples/pi.dx @@ -1,25 +1,36 @@ '# Monte Carlo estimates of pi -estimatePiArea : Key -> Float = \key. - (k1, k2) = splitKey key +'Consider the unit circle centered at the origin. + +'Consider the first quadrant: the unit circle quadrant and its $1 \times 1$ bounding unit square. + +'$$\text{Area of unit circle quadrant: } \\\\ A_{quadrant} = \frac{\pi r^2}{4} = \frac{\pi}{4}$$ + +'$$\text{Area of unit square: } \\\\ A_{square} = 1$$ + +'$$\text{Compute } \pi \text{ via ratios: } \\\\ \frac{A_{quadrant}}{A_{square}} = \frac{\pi}{4}, \\; \pi = 4 \thinspace \frac{A_{quadrant}}{A_{square}} $$ + +'To compute $\pi$, randomly sample points in the first quadrant unit square to estimate the $\frac{A_{quadrant}}{A_{square}}$ ratio. Then, multiply by $4$. + +def estimatePiArea (key:Key) : Float = + [k1, k2] = splitKey key x = rand k1 y = rand k2 inBounds = (sq x + sq y) < 1.0 4.0 * BToF inBounds -estimatePiAvgVal : Key -> Float = \key. +def estimatePiAvgVal (key:Key) : Float = x = rand key 4.0 * sqrt (1.0 - sq x) -meanAndStdDev : Int -> (Key -> Float) -> Key -> (Float & Float) = - \n f key. - samps = for i:(Fin n). many f key i - (mean samps, std samps) +def meanAndStdDev (n:Int) (f: Key -> Float) (key:Key) : (Float & Float) = + samps = for i:(Fin n). many f key i + (mean samps, std samps) numSamps = 1000000 :p meanAndStdDev numSamps estimatePiArea (newKey 0) -> (3.143452, 1.6408892) +> (3.143452, 1.640889) :p meanAndStdDev numSamps estimatePiAvgVal (newKey 0) -> (3.1437902, 0.88649935) +> (3.14379, 0.886499) diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 9b51c3d4e..051722f56 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -4,9 +4,8 @@ [JAX implementation](https://github.com/ericjang/pt-jax/blob/master/jaxpt_vmap.ipynb), described [here](https://blog.evjang.com/2019/11/jaxpt.html). -Specifically, it's based on his unrolled ```lax.scan``` version. -include "examples/plot.dx" +include "plot.dx" ' ### Generic Helper Functions Some of these should probably go in prelude. @@ -25,52 +24,23 @@ def directionAndLength (x: d=>Float) : (d=>Float & Float) = def randuniform (lower:Float) (upper:Float) (k:Key) : Float = lower + (rand k) * (upper - lower) -def reverse (x:n=>a) : n=>a = - s = size n - for i. x.((s - 1 - ordinal i)@_) - -def sampleAveraged (_:VSpace a) ?=> (sample:Key -> a) (n:Int) (k:Key) : a = - snd $ withState zero \total. +def sampleAveraged [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a = + yieldState zero \total. for i:(Fin n). total := get total + sample (ixkey k i) / IToF n def positiveProjection (x:n=>Float) (y:n=>Float) : Bool = dot x y > 0.0 -data IterResult a:Type b:Type = - Continue a - Done b - --- A little iteration combinator --- TODO: allow effects (currently there's some type inference bug preventing it) -def iter (init:a) (body: Int -> a -> IterResult a b) : b = - result = snd $ withState Nothing \resultRef. - withState init \carryRef. - withState 0 \i. - while (\(). isNothing (get resultRef)) \(). - case body (get i) (get carryRef) of - Continue carry -> - i := get i + 1 - carryRef := carry - Done result -> - resultRef := Just result - case result of - Just ans -> ans - Nothing -> todo -- should be unreachable - ' ### 3D Helper Functions --- TODO: implement table unpacking -def unpackvec3 (p:Vec 3) : (Float & Float & Float) = - (p.(0@(Fin 3)), p.(1@(Fin 3)), p.(2@(Fin 3))) - def cross (a:Vec 3) (b:Vec 3) : Vec 3 = - (a1, a2, a3) = unpackvec3 a - (b1, b2, b3) = unpackvec3 b + [a1, a2, a3] = a + [b1, b2, b3] = b [a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1] -- TODO: Use `data Color = Red | Green | Blue` and ADTs for index sets -def ColorImage (height:Int) (width:Int) : Type = Fin height => Fin width => Color -def GrayScaleImage (height:Int) (width:Int) : Type = Fin height => Fin width => Float +data Image = + MkImage height:Int width:Int (Fin height => Fin width => Color) xHat : Vec 3 = [1., 0., 0.] yHat : Vec 3 = [0., 1., 0.] @@ -81,23 +51,23 @@ Angle = Float -- angle in radians def rotateX (p:Vec 3) (angle:Angle) : Vec 3 = c = cos angle s = sin angle - (px, py, pz) = unpackvec3(p) + [px, py, pz] = p [px, c*py - s*pz, s*py + c*pz] def rotateY (p:Vec 3) (angle:Angle) : Vec 3 = c = cos angle s = sin angle - (px, py, pz) = unpackvec3(p) + [px, py, pz] = p [c*px + s*pz, py, - s*px+ c*pz] def rotateZ (p:Vec 3) (angle:Angle) : Vec 3 = c = cos angle s = sin angle - (px, py, pz) = unpackvec3(p) + [px, py, pz] = p [c*px - s*py, s*px+c*py, pz] def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = - (k1, k2) = splitKey k + [k1, k2] = splitKey k u1 = rand k1 u2 = rand k2 uu = normalize $ cross normal [0.0, 1.1, 1.1] @@ -112,7 +82,6 @@ def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = ' ### Raytracer Distance = Float -def Image (n:Int) :Type = Fin n => Fin n => Color -- TODO: hide the size Position = Vec 3 Direction = Vec 3 -- Should be normalized. TODO: use a newtype wrapper @@ -142,7 +111,9 @@ Filter = Color -- TODO: use a record -- num samples, num bounces, share seed? -Params = (Int & Int & Bool) +Params = { numSamples : Int + & maxBounces : Int + & shareSeed : Bool } -- TODO: use a list instead, once they work data Scene n:Type = MkScene (n=>Object) @@ -199,32 +170,30 @@ data RayMarchResult = HitNothing def raymarch (scene:Scene n) (ray:Ray) : RayMarchResult = - max_iters = 100 + maxIters = 100 tol = 0.01 startLength = 10.0 * tol -- trying to escape the current surface (rayOrigin, rayDir) = ray - iter (10.0 * tol) \i rayLength. - case i >= max_iters of - True -> Done HitNothing - False -> - rayPos = rayOrigin + rayLength .* rayDir - (obj, d) = sdScene scene $ rayPos - -- 0.9 ensures we come close to the surface but don't touch it - dNew = rayLength + 0.9 * d - case d < tol of - False -> Continue $ dNew - True -> - surfNorm = calcNormal obj rayPos - case positiveProjection rayDir surfNorm of - True -> - -- Oops, we didn't escape the surface we're leaving.. - -- (Is there a more standard way to do this?) - Continue dNew - False -> - -- We made it! - Done $ case obj of - PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf) - Light _ _ radiance -> HitLight radiance + withState (10.0 * tol) \rayLength. + boundedIter maxIters HitNothing \_. + rayPos = rayOrigin + get rayLength .* rayDir + (obj, d) = sdScene scene $ rayPos + -- 0.9 ensures we come close to the surface but don't touch it + rayLength := get rayLength + 0.9 * d + case d < tol of + False -> Continue + True -> + surfNorm = calcNormal obj rayPos + case positiveProjection rayDir surfNorm of + True -> + -- Oops, we didn't escape the surface we're leaving.. + -- (Is there a more standard way to do this?) + Continue + False -> + -- We made it! + Done $ case obj of + PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf) + Light _ _ radiance -> HitLight radiance def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance = case raymarch scene ray of @@ -233,7 +202,7 @@ def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance = HitObj _ _ -> zero def sampleSquare (hw:Float) (k:Key) : Position = - (kx, kz) = splitKey k + [kx, kz] = splitKey k x = randuniform (- hw) hw kx z = randuniform (- hw) hw kz [x, 0.0, z] @@ -243,73 +212,71 @@ def sampleLightRadiance (surfNor, surf) = osurf (rayPos, _) = inRay (MkScene objs) = scene - snd $ withAccum \radiance. + yieldAccum \radiance. for i. case objs.i of PassiveObject _ _ -> () Light lightPos hw _ -> (dirToLight, distToLight) = directionAndLength $ lightPos + sampleSquare hw k - rayPos - case positiveProjection dirToLight surfNor of - False -> () -- light on the far side of current surface - True -> - fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) - outRay = (rayPos, dirToLight) - coeff = fracSolidAngle * probReflection osurf inRay outRay - radiance += coeff .* rayDirectRadiance scene outRay - -def trace (params:Params) (scene:Scene n) (init_ray:Ray) (k:Key) : Color = - (_, max_bounces, _) = params - -- TODO: we ought to be able to use an accumulator here, but there's a bug + if positiveProjection dirToLight surfNor then + -- light on this far side of current surface + fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) + outRay = (rayPos, dirToLight) + coeff = fracSolidAngle * probReflection osurf inRay outRay + radiance += coeff .* rayDirectRadiance scene outRay + +def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = noFilter = [1.0, 1.0, 1.0] - iter (noFilter, zero, init_ray) $ - \i (filter, radiance, ray). - case i >= max_bounces of - True -> Done radiance - False -> case raymarch scene ray of - HitNothing -> Done radiance - HitLight intensity -> case i == 0 of - True -> Done intensity -- TODO: scale etc - False -> Done radiance + yieldAccum \radiance. + runState noFilter \filter. + runState initRay \ray. + boundedIter (getAt #maxBounces params) () \i. + case raymarch scene $ get ray of + HitNothing -> Done () + HitLight intensity -> + if i == 0 then radiance += intensity -- TODO: scale etc + Done () HitObj incidentRay osurf -> - (k1, k2) = splitKey $ hash k i - lightRadiance = sampleLightRadiance scene osurf incidentRay k1 - outRayHemisphere = sampleReflection osurf incidentRay k2 - newFilter = surfaceFilter filter (snd osurf) - newRadiance = radiance + applyFilter newFilter lightRadiance - Continue (newFilter, newRadiance, outRayHemisphere) - --- TODO: add number of pixels once we can hide sizes --- sensor half-width, pinhole-sensor distance, pinhole position --- (Assumes we're looking towards -z.) -Camera = (Position & Float & Float) - + [k1, k2] = splitKey $ hash k i + lightRadiance = sampleLightRadiance scene osurf incidentRay k1 + ray := sampleReflection osurf incidentRay k2 + filter := surfaceFilter (get filter) (snd osurf) + radiance += applyFilter (get filter) lightRadiance + Continue + +-- Assumes we're looking towards -z. +Camera = + { numPix : Int + & pos : Position -- pinhole position + & halfWidth : Float -- sensor half-width + & sensorDist : Float } -- pinhole-sensor distance + +-- TODO: might be better with an anonymous dependent pair for the result def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = -- images indexed from top-left - (pos, halfWidth, sensorDist) = camera + halfWidth = getAt #halfWidth camera pixHalfWidth = halfWidth / IToF n ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth xs = linspace (Fin n) (neg halfWidth) halfWidth for i j. \key. - (kx, ky) = splitKey key + [kx, ky] = splitKey key x = xs.j + randuniform (-pixHalfWidth) pixHalfWidth kx y = ys.i + randuniform (-pixHalfWidth) pixHalfWidth ky - (pos, normalize [x, y, neg sensorDist]) + (getAt #pos camera, normalize [x, y, neg (getAt #sensorDist camera)]) -def takePicture - (params:Params) (scene:Scene m) (n:Int) (camera:Camera) - : ColorImage n n = - (numSamples, _, shareSeed) = params +def takePicture (params:Params) (scene:Scene m) (camera:Camera) : Image = + n = getAt #numPix camera rays = cameraRays n camera rootKey = newKey 0 image = for i j. - pixKey = case shareSeed of - True -> rootKey - False -> ixkey (ixkey rootKey i) j + pixKey = if getAt #shareSeed params + then rootKey + else ixkey (ixkey rootKey i) j sampleRayColor : Key -> Color = \k. - (k1, k2) = splitKey k + [k1, k2] = splitKey k trace params scene (rays.i.j k1) k2 - sampleAveraged sampleRayColor numSamples pixKey - image / mean (for (i,j,k). image.i.j.k) + sampleAveraged sampleRayColor (getAt #numSamples params) pixKey + MkImage _ _ $ image / mean (for (i,j,k). image.i.j.k) ' ### Define the scene and render it @@ -331,25 +298,31 @@ theScene = MkScene $ , PassiveObject (Sphere [ 2.0, 2.0, -2.0] 1.5) (Mirror) ] -camera = (10.0 .* zHat, 0.3, 1.0) +defaultParams = { numSamples = 50 + , maxBounces = 10 + , shareSeed = True } --- num_pix = 250 -num_pix = 10 -num_samples = 50 -num_bounces = 10 -share_prng = True -params = (num_samples, num_bounces, share_prng) +defaultCamera = { numPix = 250 + , pos = 10.0 .* zHat + , halfWidth = 0.3 + , sensorDist = 1.0 } +-- We change to a small num pix here to reduce the compute needed for tests +params = defaultParams +camera = if dex_test_mode () + then defaultCamera |> setAt #numPix 10 + else defaultCamera -- %time -image = takePicture params theScene num_pix camera - +(MkImage _ _ image) = takePicture params theScene camera :html imshow image > 'Just for fun, here's what we get with a single sample (sharing the PRNG key among pixels) -:html imshow $ - takePicture (1, num_bounces, share_prng) theScene num_pix camera +params2 = defaultParams |> setAt #numSamples 1 +(MkImage _ _ image2) = takePicture params2 theScene camera + +:html imshow image2 > diff --git a/examples/regression.dx b/examples/regression.dx index f33a84aec..ca8ce2731 100644 --- a/examples/regression.dx +++ b/examples/regression.dx @@ -1,68 +1,68 @@ '# Basis function regression -include "examples/plot.dx" +include "plot.dx" -- Conjugate gradients solver -def solve (m:Type)?-> : m=>m=>Float -> m=>Float -> m=>Float = - \mat b. - x0 = for i:m. 0.0 - ax = mat **. x0 - r0 = b - ax - (xOut, _, _) = fold (x0, r0, r0) $ - \s:m (x, r, p). - ap = mat **. p - alpha = vdot r r / vdot p ap - x' = x + alpha .* p - r' = r - alpha .* ap - beta = vdot r' r' / (vdot r r + 0.000001) - p' = r' + beta .* p - (x', r', p') - xOut +def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float = + x0 = for i:m. 0.0 + ax = mat **. x0 + r0 = b - ax + (xOut, _, _) = fold (x0, r0, r0) $ + \s:m (x, r, p). + ap = mat **. p + alpha = vdot r r / vdot p ap + x' = x + alpha .* p + r' = r - alpha .* ap + beta = vdot r' r' / (vdot r r + 0.000001) + p' = r' + beta .* p + (x', r', p') + xOut 'Make some synthetic data Nx = Fin 100 noise = 0.1 -(k1, k2) = splitKey (newKey 0) +[k1, k2] = splitKey (newKey 0) -trueFun : Float -> Float = - \x. x + sin (5.0 * x) +def trueFun (x:Float) : Float = + x + sin (5.0 * x) xs : Nx=>Float = for i. rand (ixkey k1 i) ys : Nx=>Float = for i. trueFun xs.i + noise * randn (ixkey k2 i) --- :html showPlot $ xyPlot xs ys +:html showPlot $ xyPlot xs ys +> 'Implement basis function regression -regress : (Float -> d=>Float) -> n=>Float -> n=>Float -> d=>Float = - \featurize xRaw y. - x = map featurize xRaw - xT = transpose x - solve (xT ** x) (xT **. y) +def regress (featurize: Float -> d=>Float) (xRaw:n=>Float) (y:n=>Float) : d=>Float = + x = map featurize xRaw + xT = transpose x + solve (xT ** x) (xT **. y) 'Fit a third-order polynomial -poly : Float -> d=>Float = - \x. for i. pow x (IToF (ordinal i)) +def poly (x:Float) : d=>Float = + for i. pow x (IToF (ordinal i)) params : (Fin 4)=>Float = regress poly xs ys -predict : Float -> Float = - \x. vdot params (poly x) +def predict (x:Float) : Float = + vdot params (poly x) xsTest = linspace (Fin 200) 0.0 1.0 --- :html showPlot $ xyPlot xsTest (map predict xsTest) +:html showPlot $ xyPlot xsTest (map predict xsTest) +> 'RMS error -rmsErr : n=>Float -> n=>Float -> Float = - \truth pred. sqrt $ mean for i. sq (pred.i - truth.i) +def rmsErr (truth:n=>Float) (pred:n=>Float) : Float = + sqrt $ mean for i. sq (pred.i - truth.i) :p rmsErr ys (map predict xs) -> 0.25269496 +> 0.252695 def tabCat (xs:n=>a) (ys:m=>a) : ({left:n|right:m})=>a = @@ -73,7 +73,8 @@ def tabCat (xs:n=>a) (ys:m=>a) : ({left:n|right:m})=>a = xsPlot = tabCat xs xsTest ysPlot = tabCat ys $ map predict xsTest --- :html showPlot $ xycPlot xsPlot ysPlot $ --- for i. case i of --- {| left = _ |} -> 0.0 --- {| right = _ |} -> 1.0 +:html showPlot $ xycPlot xsPlot ysPlot $ + for i. case i of + {| left = _ |} -> 0.0 + {| right = _ |} -> 1.0 +> diff --git a/examples/rejection-sampler.dx b/examples/rejection-sampler.dx index 503dcfb69..b62ede490 100644 --- a/examples/rejection-sampler.dx +++ b/examples/rejection-sampler.dx @@ -1,34 +1,84 @@ +'# Rejection sampler of a Binomial distribution + +'We implement rejection sampling from a Binomial distribution using a uniform proposal. def rejectionSample (try: Key -> Maybe a) (k:Key) : a = - ans = fst $ withState 0 \i. - snd $ withState Nothing \sample. - while (\(). isNothing (get sample)) \(). - i := get i + 1 - sample := try $ hash k (get i) - case ans of Just sample -> sample + iter \i. case try $ hash k i of + Nothing -> Continue + Just x -> Done x Prob = Float LogProb = Float -def binomialSample (n:Int) (p:Prob) (k:Key) : Int = todo - +-- log probability density of a Binomial distribution def logBinomialProb (n:Int) (p:Prob) (counts:Int) : LogProb = pSuccess = log p * IToF counts pFailure = log1p (-p) * IToF (n - counts) normConst = (lbeta (1. + IToF counts) (1. + IToF n - IToF counts) + - log (1. + IToF n)) + log1p (IToF n)) pSuccess + pFailure - normConst -def binomialProb (n:Int) (p:Prob) (count:Int) : Prob = - exp $ logBinomialProb n p count - def trySampleBinomial (n:Int) (p:Prob) (k:Key) : Maybe Int = - (k1, k2) = splitKey k + [k1, k2] = splitKey k proposal = FToI $ floor $ rand k1 * IToF (n + 1) - acceptance = rand k2 < binomialProb n p proposal - case proposal < (n + 1) && acceptance of - True -> Just proposal - False -> Nothing + if proposal > n + then Nothing + else + acceptance = log (rand k2) < logBinomialProb n p proposal + if acceptance + then Just proposal + else Nothing + +'## Example + +'We test the implementation by sampling from a Binomial distribution with 10 trials and success probability 0.4. + +-- parameters +n = 10 +p = 0.4 +numSamples = 5000 +k0 = newKey 0 + +rejectionSamples = randVec numSamples (rejectionSample $ trySampleBinomial n p) k0 + +:p slice rejectionSamples 0 $ Fin 10 +> [4, 2, 5, 4, 6, 7, 3, 6, 4, 3] + +'The Binomial distribution has mean 4 and variance 2.4. + +def meanAndVariance (xs:n=>Float) : (Float&Float) = (mean xs, sq $ std xs) + +:p meanAndVariance $ map IToF rejectionSamples +> (3.9984, 2.361596) + +'## Alternative: Inversion sampling + +'Alternatively, we can use inversion sampling. + +def binomialSample (n:Int) (p:Prob) (k:Key) : Int = + m = n + 1 + logprobs = for i:(Fin m). logBinomialProb n p $ ordinal i + ordinal $ categorical logprobs k + +inversionSamples = randVec numSamples (binomialSample n p) k0 + +:p slice inversionSamples 0 $ Fin 10 +> [6, 7, 6, 5, 3, 2, 4, 4, 3, 4] + +:p meanAndVariance $ map IToF inversionSamples +> (3.9978, 2.409796) + +'The following variant is guaranteed to evaluate the CDF only once. + +def binomialBatch (n:Int) (p:Prob) (k:Key) : a => Int = + m = n + 1 + logprobs = for i:(Fin m). logBinomialProb n p $ ordinal i + map ordinal $ categoricalBatch logprobs k + +inversionBatchSamples = (binomialBatch n p k0) : Fin numSamples => Int + +:p slice inversionBatchSamples 0 $ Fin 10 +> [6, 7, 6, 5, 3, 2, 4, 4, 3, 4] -:p randVec 10 (rejectionSample (trySampleBinomial 10 0.5)) (newKey 0) -> [4, 2, 5, 4, 6, 7, 3, 6, 6, 3] +:p meanAndVariance $ map IToF inversionBatchSamples +> (3.9978, 2.409796) diff --git a/examples/sgd.dx b/examples/sgd.dx index 754d1f550..bc1a0cb29 100644 --- a/examples/sgd.dx +++ b/examples/sgd.dx @@ -1,16 +1,16 @@ '## Stochastic Gradient Descent with Momentum -def sgd_step (dict: VSpace a) ?=> (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) = +def sgd_step [VSpace a] (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) = g = gradfunc x iter new_m = decay .* m + g new_x = x - step_size .* new_m (new_x, new_m) -- In-place optimization loop. -def sgd (dict: VSpace a) ?=> (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = +def sgd [VSpace a] (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = m0 = zero - (x_final, m_final) = snd $ withState (x0, m0) \state. + (x_final, m_final) = yieldState (x0, m0) \state. for i:(Fin num_steps). (x, m) = get state state := sgd_step step_size decay gradient x m (ordinal i) @@ -31,5 +31,7 @@ stepsize = 0.01 decay = 0.9 num_iters = 1000 :p sgd stepsize decay num_iters gradfunc x_init +> [1.1, 1.1, 1.1, 1.1] :p optimum +> [1.1, 1.1, 1.1, 1.1] diff --git a/examples/sierpinski.dx b/examples/sierpinski.dx index cdfde70c1..64a8f8aea 100644 --- a/examples/sierpinski.dx +++ b/examples/sierpinski.dx @@ -1,19 +1,17 @@ '# Sierpinski triangle ("chaos game") -include "examples/plot.dx" +include "plot.dx" -update : n=>Point -> Key -> Point -> Point = - \points key (x,y). - (x', y') = points.(randIdx key) - (0.5 * (x + x'), 0.5 * (y + y')) +def update (points:n=>Point) (key:Key) ((x,y):Point) : Point = + (x', y') = points.(randIdx key) + (0.5 * (x + x'), 0.5 * (y + y')) -runChain : n:Int -> (Key -> a -> a) -> Key -> a -> (Fin n)=>a = - \n f key x0. scan' x0 (many f key) +def runChain (n:Int) (f:Key -> a -> a) (key:Key) (x0:a) : Fin n => a = + scan' x0 (many f key) trianglePoints : (Fin 3)=>Point = [(0.0, 0.0), (1.0, 0.0), (0.5, sqrt 0.75)] (xs, ys) = unzip $ runChain 3000 (update trianglePoints) (newKey 0) (0.0, 0.0) --- Disabling this for now because the plotting function is too slow to compile --- :html showPlot $ xyPlot xs ys --- > +:html showPlot $ xyPlot xs ys +> diff --git a/examples/simple-include-test.dx b/examples/simple-include-test.dx deleted file mode 100644 index b251c0378..000000000 --- a/examples/simple-include-test.dx +++ /dev/null @@ -1,7 +0,0 @@ - -include "examples/included.dx" -> 30 -> 40 - -:p x -> 10 diff --git a/examples/somedata.dxo b/examples/somedata.dxo deleted file mode 100644 index a9be9bb4e..000000000 --- a/examples/somedata.dxo +++ /dev/null @@ -1 +0,0 @@ -(1.0, 1@2, [[2.0, 1.0, 3.0], [0.0, -10.0, 20.0]], [(1, True), (2, False)]) diff --git a/examples/tiled-matmul.dx b/examples/tiled-matmul.dx index 42e3b5753..7238d671e 100644 --- a/examples/tiled-matmul.dx +++ b/examples/tiled-matmul.dx @@ -16,7 +16,7 @@ def matmul (k : Type) ?-> (n : Type) ?-> (m : Type) ?-> vectorTile = Fin VectorWidth colTile = (colVectors & vectorTile) (tile2d (\nt:(Tile n rowTile). \mt:(Tile m colTile). - ct = snd $ withAccum \acc. + ct = yieldAccum \acc. for l:k. for i:rowTile. ail = broadcastVector a.(nt +> i).l diff --git a/examples/tutorial-old.dx b/examples/tutorial-old.dx index ecd2ee80a..3822158c9 100644 --- a/examples/tutorial-old.dx +++ b/examples/tutorial-old.dx @@ -1,3 +1,7 @@ +'# THIS FILE IS STALE + +'(But we plan to update it at some point) + '# Introduction to the Dex language 'Dex is a functional, statically typed language for array processing. diff --git a/examples/tutorial.dx b/examples/tutorial.dx index 06622370d..2c5b54042 100644 --- a/examples/tutorial.dx +++ b/examples/tutorial.dx @@ -1,104 +1,100 @@ '# Dex Tutorial +' Dex is a functional, statically typed language for array processing. There are + many tools for array processing, from high-level libraries like NumPy and + MATLAB to low-level languages like CUDA. Dex gives you many of the safety and + simplicity benefits of high-level array processing languages, without + requiring that users give up low-level control. -'Dex is a functional, statically typed language for array processing. -There are many tools for array processing from high-level libraries -like NumPy / MATLAB to low-level languages like CUDA. Dex gives you -many benefit of the safety and simplicity benefits of high-level array -processing languages, without requiring that you give up low-level -control. +'## Array comprehensions -'## Array Comprehensions - - -' Before getting into the details of the language, let us begin with -the most useful component of dex, the `for` builder. The best analogy -for this construct is list comprehensions in Python. For instance, in -Python we might write a list comprehension like: +' Before getting into language details, let us begin with the most useful + component of Dex, the `for` builder. The best analogy for this construct is + list comprehensions in Python. For instance, in Python, we might write a + list comprehension like: ' `x = [[1 for j in range(10)] for i in range(5)]` -' In Dex, this construct would be written as, +' In Dex, this construct would be written as: x = for i:(Fin 5). for j:(Fin 10). 1 - -' Once we have an variable we can print it `:p` +' Once we have an variable, we can print it `:p` :p x -' More interestingly, we can also see its type with `:t`. This type -signature tells us that `x` is a two-dimensional array, with first -dimension of size 5 and the second of size 10. +' More interestingly, we can also see its type with `:t`. This type tells us + that `x` is a two-dimensional array, whose first dimension has size 5 and + second dimension has size 10. :t x -' Once we have an array we can use it in new comprehensions. For example, - if say we want to add `5` to each element of the array. In Python, - you might write this as, +' Once we have an array, we can use it in new comprehensions. For example, + let's try to add `5` to each array element. In Python, one might write this as: ' `y = [[x[i][j] for j in range(10)] for i in range(5)]` ' Dex can do something similar. The main superficial difference is the - indexing syntax which uses `.` instead of brackets. + array indexing syntax, which uses `array.i` instead of square brackets for + subscripting. -y = for i:(Fin 5). for j:(Fin 10). x.i.j + 5 +y = for i:(Fin 5). for j:(Fin 10). x.i.j + 5 :p y -' However, we can make this expression nicer. Because `x` has a known type +' However, we can make this expression nicer. Because `x` has a known array type and `i` and `j` index into that type, Dex can infer the range of the loop. - That means that we can safely remove `Fin` statements and get the same result. + That means that we can safely remove the explicit `Fin` type annotations and + get the same result. -y' = for i. for j. x.i.j + 5 +y' = for i. for j. x.i.j + 5 +' We can further reduce this array by applying array reduction functions like + `sum`: -' We can further reduce this array by applying array functions such as `sum`. +z = for i. sum x.i -z = for i. sum x.i +' This style of using `for` to construct type-inferred arrays is central to what + makes Dex powerful. Let's consider another example. This one produces a list of + length 50 in Python. -' This style of using the `for` construct to infer the loop range is - central to what makes Dex powerful. Let's consider another example. - This one produces a list of length 50 in Python. - ' `x = [1 for j in range(10) for i in range(5)]` -' The analogous array construct in Dex is written in - the following form. This produces a one dimension - array of 50 elements. +' The analogous array construct in Dex is written in the following form. It + produces a one-dimensional array of 50 elements. -x2 = for (i, j): (Fin 5 & Fin 10). 1 +x2 = for (i, j): (Fin 5 & Fin 10). 1 +' As before, we can implement "adding 5" to this array using a `for` constructor, + enumerating over each of its elements: -' As before, we can modify this array through another `for` constructor, - which enumerates over each element of `x2`. Or by applying a function. - +y2 = for i. x2.i + 5 -y2 = for i. x2.i + 5 +' And we can apply array functions to the array: :p sum x2 -' But things start to get interesting when we consider the type of this array. - Unlike the Python example that produces a list of length 50. The - Dex array maintains the index type of its construction. In particular - the type of `x2` remembers the original ranges. +' But things start to get interesting when we consider the type of the array. + Unlike the Python example, which produces a list of length 50, the Dex array + Jmaintains the index type of its construction. In particular, the type of the + array remembers the original ranges. :t x2 +'## Typed indexing -'## Typed Indexing - -' The use of typed indices lets you do some really neat things, but it - also breaks some things in counterintuitive ways. Dex use the `.` - syntax for indexing. Critically though, cannot simply index with a - raw integer. +' The use of typed indices lets you do really neat things, but it also breaks + other things in counterintuitive ways. Dex uses the `.` syntax for array + indexing. Critically though, one cannot simply index an array with an integer + literal. r = x.3 -' Instead you need to cast your integer into the index type of the current - shape. This is done with the `@` operator. (If it helps, you can think of `a.i` - as passing index `i` to array `a` the same way `f x` passes arg `x` to function - `f`.) +' Instead, it is necessary to cast the integer into the index type of the + current shape. This type annotation is done with the `@` operator. (If it + helps, you can think of array indexing as function application: `a.i` applies + array `a` with index `i` just like how `f x` applies function `f` with + argument `x`.) :t x @@ -108,63 +104,58 @@ row = x.(3 @ Fin 5) :t row.(5 @ Fin 10) -' This is bit verbose, but you rarely need to do it in practice. Most of the - time, you index with the `for` construct which is able to infer the right indices. +' This explicit annotation is a bit verbose, but it is rarely necessary in + practice. Most of the time, the `for` construct can infer index types. That's why we didn't have to use any `@` symbols when constructing `y2` above. -' Similarly you can't use indices as integers as you might be used to. You need to - cast them out explicitly. - +' Similarly, you cannot use indices as integers as you might be used to. It is + necessary to explicitly annotate index types. x4 = for i:(Fin 5). for j:(Fin 10). i + j +x4 = for i:(Fin 5). for j:(Fin 10). (ordinal i) + (ordinal j) -x4 = for i:(Fin 5). for j:(Fin 10). (ordinal i) + (ordinal j) - - -' As we have seen though, indices do not need to just be integers. We can index with - many different Dex type. For instance `x2` was indexed with a pair of integers (`&` means tuple) - so we need to build a tuple in order to index. +' As we have seen, indices are not limited to only integers. Many different Dex + types are valid index types. For example, we declared array `x2` as having a + pair of integers as its index type (`a & b` means tuple type), so indexing + into `x2` requires creating a tuple value (via `(x, y)`). :t x2 :t x2.(3@Fin 5, 5@Fin 10) -' A lot of algorithms in Dex come down to being able to pack and - unpack these indices. For example, we have seen that it is easy to - sum over one dimension of a 2D array. However, if we have a 1D - array indexed by a pair, we can easily turn it into a 2D array by - constructing it. +' Many algorithms in Dex come down to being able to pack and unpack these + indices. For example, we have seen that it is easy to sum over one dimension + of a 2D array. However, if we have a 1D array indexed by a pair, we can + easily turn it into a 2D array using two `for` constructors. x3 = for i. for j. x2.(i, j) :t x3 -' Again we rely on type inference in order to avoid explicitly giving -the ranges. +' Again, we rely on type inference in order to avoid explicitly spelling the + ranges. -' ## Functions over Arrays +' ## Defining functions over arrays -' One use case of packing and unpacking array indices is that - it allows us to change the order of the axes. This is useful for - applying functions on arrays. +' One use case of packing and unpacking array indices is that it allows us to + change the order of the axes. This is useful for applying functions on arrays. -' For instance, we saw the `sum` function above which sums over an - axes. We can apply `sum` to `x2` to produce the sum over 50 elements. +' For instance, we saw the `sum` function above which sums over the first axis + of an array. We can apply `sum` to `x2` to produce the sum over 50 elements: :t x2 :p sum x2 -' Alternatively we can apply sum over `x3` to produce the sum over rows. +' Alternatively, we can apply sum over `x3` to produce the sum over rows: :t x3 :p sum x3 -' How do we sum over the columns? In systems like NumPy you would - do this by passing an axis argument to `sum`. Dex doesn't work this - way. To sum over columns, you need to move columns to the front - of the line. Luckily, you already know how to do this. - +' How do we sum over the columns of `x3`? In systems like NumPy, you would do + this by passing an axis argument to `sum`. Dex doesn't work this way. To sum + over columns, you need to move columns to the front of the line. Luckily, we + already know how to do this: using `for` constructors! :t x3 @@ -174,12 +165,11 @@ trans = for j. for i. x3.i.j :p sum trans -' The `sum` function seems to work independently of the index type of the - array. +' The `sum` function works independently of the index type of the array. -' Let's see how we can do this with our own functions. To define a function in - Dex we use the following syntax (there are other ways to do it, but this - one looks pretty close to Python.) +' Let's see how we can define our own array functions. Defining a function in + Dex uses the following syntax. (There are other ways to do it, but this one + looks closest to Python.) def add5 (x:Int32) : Int32 = x + 5 @@ -187,28 +177,25 @@ def add5 (x:Int32) : Int32 = x + 5 :t for i. add5 x2.i - -' We can also write functions with type variables over their inputs. For instance - we if we want to be able to `Add5` to any array. This function binds the type - variable `n` to the index type of the array. - +' We can also write functions with type variables over their inputs. For + instance, we may want to be able to write a function that applies "adds 5" + to arrays with _any_ index type. This is possible by declaring an `n => Int32` + array argument type: this declares the type variable `n` as the index type of + the array argument. def arrAdd5 (x : n => Int32) : n => Int32 = for i. x.i + 5 - -:t arrAdd5 x2 +:t arrAdd5 x2 -' But the function types can help you out even more. - For instance, because index types are sized, you - can use type inference to ensure the arguments passed in - are valid. +' But function types can help you out even more. For instance, since index types + are statically known, type checking can ensure that array arguments have valid + dimensions. This is "shape safety". -' For instance, let's say we want to add two array together. +' For instance, let's write a function adding two 2D arrays with the same shape: :t x :t y - def arrAdd (x : m => n => Int32) (y : m => n => Int32) : m => n => Int32 = for i. for j. x.i.j + y.i.j @@ -216,192 +203,207 @@ def arrAdd (x : m => n => Int32) (y : m => n => Int32) : m => n => Int32 = :t arrAdd x (trans y) -' Here the system type checked for us that they are the same size. +' The type system checked for us that input arrays indeed have the same shape. +'## Writing loops -'## Writing Loops +' Dex is a functional language - but when writing mathematical algorithms, + it is often convenient to temporarily put aside immutability and write + imperative code using mutation. -' Dex is a functional language, but when writing mathematical algorithm - it is often convenient to ignore that fact and write imperative code. +' For example, let's say we want to actually implement the `sum` function + ourselves by accumulating summed values in-place. In Python, implementing this + is not directly possible solely via list comprehensions, so we would write a + loop. -' For example, lets say we now want to actually write the `sum` function - ourselves by accumulating summed values. In Python, We can't do this directly - with list comprehensions, so we would write a loop. - ' `acc = 0` ' `for i in range(10):` -' `acc = acc + x[i]` +' `acc = acc + x[i]` -' Variables are immutable in Dex, so we cannot do this directly. But we can - write very similar code using the `state` effect. Here's what it looks like - with the corresponding Python code. - +' In Dex, values are immutable, so we cannot directly perform mutation. But Dex + includes algebraic effects, which are a purely-functional way to modeling + side-effects like mutation. We can write code that looks like mutation using + the `State` effect, which provides getter and setter functionality (via `get` + and `:=` assignment). Here's what it looks like: -def arrSum (x : a => Int32) : Int32 = +def arrSum (x : n => Int32) : Int32 = -- acc = 0 - initAcc = 0 + init = 0 -- (ignore for now) - snd $ withState initAcc $ \acc. + snd $ withState init $ \acc. -- for i in range for i. -- acc = acc + x[i] acc := (get acc) + x.i - -:p arrSum x2 +:p arrSum x2 -' So even though we are functional, the loop looks quite - similar to the imperative style. However there is one - line which is quite new and a bit scary. Let us look - into that line in a bit more detail. +' So, even though Dex is a functional language, it is possible to write loops + that look similar to ones that truly perform mutation. However, there is one + line which is quite new and a bit scary. Let us look into that line in a bit + more detail. -' First `$`. This symbol is used in Dex the same way it is - used in Haskell, but if you have haven't seen it before it - is a bit strange. It basically takes the place of parens `( )` - when you are too lazy to write them. For example, these two are the same: +' First: `$`. This symbol is used in Dex just like it is used in Haskell, but + if you haven't seen it before, it seems a bit strange. `$` is the function + application operator: it basically replaces of expression-grouping parentheses + `(f x)` when it is inconvenient to write them. For example, the following two + expressions are identical: :t arrSum (x2 + x2) :t arrSum $ x2 + x2 -' Next `\`. This symbol is the lambda operator in Dex. It makes a function - that you can use right away, and behaves like `lambda` in python. - Here the function takes an argument `acc` and returns the expression below (a `for` constructor). +' Next: `\`. This symbol is the lambda sigil in Dex. It is analogous to the + `lambda` keyword in Python, and starts the definition of a function value + (i.e. closure). In `arrSum` above: the lambda takes an argument named `acc` + and returns the body, which is the expression following the `.` (a `for` + constructor in this case). -' Finally, the function `snd` is from the prelude. It returns the second of a pair, nothing fancy. +' Finally, the function `snd` is from the Dex Prelude. It simply returns the + second element of a pair - there is also `fst` for extracting the first + element. :p fst (1, 2) :p snd (1, 2) - -' That leaves `withState`. This function allows you to introduce imperative variables into the computation. - It takes a intial values `initAcc` and a function of a reference to that value `\acc.` It then returns - a pair of the result of that function and the final value. Here's a simple example +' That leaves: `withState`. This function uses the `State` effect, enabling us + to introduce imperative variables into the computation. + `withState` takes an initial value `init` and a body function taking a + "mutable value" reference (`acc` here), and returns a pair of the body + function's result and the final value. Here's a simple example: :p withState 10 $ \ state. state := 30 20 -' The first element in the pair is the function return (`20`) and the second is the final value of the variable (`30`). - -' Finally this is a good point to talk a bit about some of the other operators in Dex. - Here we see two types of equal signs `=` and `:=`. The first is the `let` operator that makes an - immutable assignment. This one is built into the language and can be used anywhere you want. +' The first element of the returned pair is the body function's result (`20`). + The second element is the final value of the variable (`30`). +' Finally: this is a good point to talk a bit about some other operators in Dex. + In the examples above, we see two types of equal sign operators: `=` and `:=`. + The first is the `let` operator that creates an immutable assignment (a + "let-binding"). This one is built into the language and can be used anywhere. q = for i:(Fin 10). - -- Bind a temp variable for some reason + -- Bind a temporary variable `temp`, as an example. temp = (ordinal i) + 10 for j:(Fin 5). temp - -' The other is `:=` which can only be used inside of a `withState` block. It assigns - a value to a mutable reference. To read that value you need to use the `get` function. - or wait until the `withState` returns. +' The other is `:=`, which is an assignment operator that can only be used + when a `State` effect is available (e.g. inside of a body function passed to + `withState`. `ref := x` assigns the value `x` to the mutable reference `ref`. + Reading the value in `ref` is possible via the `get` function. or via using + the final result returned by `withState`. -'## Type Classes +'## Typeclasses -' Our arrSum function is pretty neat. It lets us work with any type index - to compute the sum. However, it annoyingly only works for integers. +' Our `arrSum` function is pretty neat. It lets us work with arrays with any + index type and computes the sum. However, `arrSum` explicitly takes only + integer arrays (of type `n => Int32`). :t arrSum -' If we apply it to floats we get the following error. +' If we try to apply `arrSum` to float arrays, we get the following error: arrSum for i : (Fin 5). 10.0 -' We can compare the type of our sum to the built-in Dex `sum`. +' We can compare the type of our `arrSum` function to the `sum` function found + in the Dex Prelude. :t sum -' It has another type variable `v` for the output. It also has the extra annotation - `(Add v) ?=>`. This is a constraint that tells us that `v` can be any type in the - `Add` type class. +' The Prelude-defined `sum` function also has an additional argument, spelled + like: `(Add v) ?=> ...`. This is a constraint telling us that the function + expects an `Add v` typeclass instance, where `v` is any type that implements + the `Add` typeclass. -' If we wanted to, we could look in the Dex prelude to see what this looks like. But we can - probably guess what it means. `v` needs to be something where `add` works on it. - We can do that! Let's define our own type class. +' We could look in the Dex Prelude to see exactly how `sum` is defined and what + `Add v` means. But we can guess what the `Add v` constraint means: `v` needs + to be a type that works with `add`. We can do that! Let's define our own + typeclass. interface MyAdd a:Type where myAdd : a -> a -> a myZero : a -' This tells us that to be in the `MyAdd` type class, a type `a` needs to have - a function `myAdd` and `myZero`. A type can then join the class like this. - +' This declares a typeclass (i.e. interface or trait) called `MyAdd` with some + typeclass methods (interface requirements). To implement the `MyAdd` + typeclass, a type `a` needs to define functions `myAdd` and `myZero` in a + "typeclass instance", like so: instance int32MyAdd : MyAdd Int32 where myAdd = \x y. x + y myZero = 0 instance float32MyAdd : MyAdd Float32 where - myAdd = \x y. (x + y) + myAdd = \x y. (x + y) myZero = 0.0 -' Once we have these two definitions, we can revisit our sum function. Here is how we modify - the type. +' Once we have these two instance definitions (`MyAdd Int32` and + `MyAdd Float32`), we can revisit our array sum function and add a typeclass + constraint: -def arrSum2 (_:MyAdd v) ?=> (x : a => v) : v = +def arrSumGeneric (_:MyAdd v) ?=> (x : a => v) : v = snd $ withState myZero $ \acc. for i. acc := myAdd (get acc) x.i -arrSum2 for i : (Fin 5). 10 -arrSum2 for i : (Fin 5). 10.0 +arrSumGeneric for i : (Fin 5). 10 +arrSumGeneric for i : (Fin 5). 10.0 -arrSum2 $ for i : (Fin 5). - for j : (Fin 10). 10.0 +arrSumGeneric $ for i : (Fin 5). + for j : (Fin 10). 10.0 -' So it works for ints and it works for floats. But it failed when we tried to - pass in a 2D array. What went wrong? The error tells us that it can't produce - a class dictionary for `MyAdd ((Fin 10) => Float32)`. This makes sense as - we have have not written one. We need to tell the system how to add columns. - -' If we want, we can take the type checker literally and make this instance :). +' This sum function works for any type that implements `MyAdd`, like `Int32` and + `Float32`. But it failed when we tried to pass in a 2D array. What went wrong? + The error tells us that the function could not find a `MyAdd` instance for + `MyAdd ((Fin 10) => Float32)`. This makes sense because we have have not + written one. We need to tell the system "how to add array columns". +' One option is to directly satisfy the type checker and provide a specific + `MyAdd ((Fin 10) => Float32)` instance: instance specMyAdd : MyAdd ((Fin 10) => Float32) where - myAdd = \x y. for i: (Fin 10). (x.i + y.i) + myAdd = \x y. for i: (Fin 10). (x.i + y.i) myZero = for i: (Fin 10). 0.0 -arrSum2 $ for i : (Fin 5). - for j : (Fin 10). 10.0 - +arrSumGeneric $ for i : (Fin 5). + for j : (Fin 10). 10.0 -' Or we can treat it a more generally and extend to all 1D arrays. +' To be more general, we can instead define a `MyAdd` instance for all array + types. This instance requires that the array element type `v` also has an + `MyAdd` instance; this requirement is represented as a `(MyAdd v) ?=> ...` + constraint. instance arrMyAdd : (MyAdd v) ?=> MyAdd (a => v) where - myAdd = \x y. for i. (myAdd x.i y.i) + myAdd = \x y. for i. (myAdd x.i y.i) myZero = for i. myZero -arrSum2 $ for i : (Fin 5). - for j : (Fin 9). 10.0 +arrSumGeneric $ for i : (Fin 5). + for j : (Fin 9). 10.0 +' This instance not only works for 2D arrays, but also 3D and higher-dimensional + arrays: -' This now works for 3D arrays too. - -arrSum2 $ for i : (Fin 5). +arrSumGeneric $ for i : (Fin 5). for j : (Fin 9). for k : (Fin 9). 10.0 - - -' ## Prelude Practice +' ## Learn the Prelude -' There are a bunch of goodies implemented in the prelude - that are worth knowing. It's good practice just to - infer what these functions do from their type. - -' Here are a couple that come up a lot. +' The Prelude contains many handy functions. Since Dex types contain so much + information, it is possible to infer what many of these functions do just by + reading and understanding their type. +' Here are a few used, commonly-used Prelude functions. ' * `select` for filtering -:t select +:t select select True 1 2 select False 1 2 @@ -413,23 +415,21 @@ select False 1 2 myzero1 : (Fin 20 & Fin 10) => Float32 = zero myzero2 : (Fin 20) => (Fin 10) => Float32 = zero -' * `zip` for creating tables of pairs +' * `zip` for creating arrays of pairs :t zip :t zip x x :t for i. zip x.i x.i - -' * `iota` for create aranges +' * `iota` for creating "aranges" :t iota :p (iota (Fin 10)) :p for i. for j. (iota (Fin 4 & Fin 4)).(i, j) - -' * Random numbers +' * Pseudorandom number generation :t newKey :t splitKey @@ -440,51 +440,43 @@ key = newKey 0 :p randn key1 -' * `randVec` creates a random vector +' * `randVec` for creating a vector of random numbers - -randv = randVec 20 randn key2 +randv = randVec 20 randn key2 :t randv randv2 = randVec 20 randInt key3 :t randv2 +'## Worked examples: Project Euler -'## Worked Examples: Project Euler - -' To demonstrate Dex in practice, here are some example - functions solving problems on https://projecteuler.net/ +' To demonstrate Dex in practice, below are some examples of solving problems + from [Project Euler](https://projecteuler.net). - def ignore (y:a) (x : Maybe a) : a = case x of Just x -> x Nothing -> y - + instance maybeAdd : (Add v) ?=> Add (Maybe v) where add = \x y. Just $ ignore zero x + ignore zero y sub = \x y. Just $ ignore zero x - ignore zero y zero = Just zero - ' ### Problem 1: Find the sum of all the multiples of 3 or 5 below 1000. - - prob1 = for i : (Fin 1000). i' = ordinal i case ((i' `mod` 3) == 0 || (i' `mod` 5) == 0) of True -> Just i' False -> Nothing - + :p fromJust $ sum prob1 ' ### Problem 2: By considering the terms in the Fibonacci sequence whose values do not exceed four million, find the sum of the even-valued terms. - ... - -- def maybeList (x : Maybe a) : List a = -- case x of -- Just a -> AsList 1 $ for i : (Fin 1). a diff --git a/examples/typeclass-tests.dx b/examples/typeclass-tests.dx deleted file mode 100644 index 5061d6ece..000000000 --- a/examples/typeclass-tests.dx +++ /dev/null @@ -1,38 +0,0 @@ -interface InterfaceTest1 a:Type where - InterfaceTest1 : a -> Error: variable already defined: InterfaceTest1 - -interface InterfaceTest2 typeName:Type where - typeName : typeName -> typeName - -interface InterfaceTest3 _:Type where - foo : Int32 - -> Parse error:8:26: -> | -> 8 | interface InterfaceTest3 _:Type where -> | ^^^^^ -> unexpected "_:Typ" -> expecting "where" or named annoted binder -interface InterfaceTest4 where - foo : Int - -instance instanceTest4 : InterfaceTest4 where - foo = 1 - -instance instanceTest4 : InterfaceTest4 x -> InterfaceTest4 (n=>a) where - foo = 1 - -> Parse error:23:68: -> | -> 23 | instance instanceTest4 : InterfaceTest4 x -> InterfaceTest4 (n=>a) where -> | ^ -> Met invalid arrow '->' in type annotation of instance. Only class arrows and implicit arrows are allowed. -instance instanceTest5 : (..i) where - bar = bar - -> Parse error:31:32: -> | -> 31 | instance instanceTest5 : (..i) where -> | ^ -> Could not extract interface name from type annotation. diff --git a/examples/web-tests.dx b/examples/web-tests.dx deleted file mode 100644 index e50511fba..000000000 --- a/examples/web-tests.dx +++ /dev/null @@ -1,12 +0,0 @@ - -_, N = unpack range 7 - -xs = for i:N. float iota.i - -:p 1 + 1.0 - -:p 1.0 + 200.0 - -:plot for i. (xs.i, xs.i * xs.i) - -:plotmat for i:N j:N. rand $ hash (hash 0 iota.i) iota.j diff --git a/examples/diagram.dx b/lib/diagram.dx similarity index 52% rename from examples/diagram.dx rename to lib/diagram.dx index e19fd325b..bbfff2ef4 100644 --- a/examples/diagram.dx +++ b/lib/diagram.dx @@ -7,24 +7,19 @@ data Geom = Circle Float Rectangle Float Float -- width, height Line Point + Text String --- HTML color (no alpha) --- TODO: replace with `Fin 3 => Word8` when we fix #348 -HtmlColor : Type = (Word8 & Word8 & Word8) +HtmlColor : Type = Fin 3 => Word8 -def showHex (x:Int32) : String = - (n, ptr) = %ffi showHex (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) +def showHex (x:Word8) : String = unsafeIO \(). + (n, ptr) = %ffi showHex (Int32 & RawPtr) x + stringFromCharPtr n (MkPtr ptr) --- TODO: we should add overloaded string literals so we don't need this -def str (n:Int) ?-> (s:(Fin n=>Char)) : String = AsList _ s - -black : HtmlColor = (IToW8 0, IToW8 0, IToW8 0) -white : HtmlColor = (IToW8 255, IToW8 255, IToW8 255) -red : HtmlColor = (IToW8 255, IToW8 0, IToW8 0) -green : HtmlColor = (IToW8 0, IToW8 255, IToW8 0) -blue : HtmlColor = (IToW8 0, IToW8 0, IToW8 255) +black : HtmlColor = [IToW8 0, IToW8 0, IToW8 0] +white : HtmlColor = [IToW8 255, IToW8 255, IToW8 255] +red : HtmlColor = [IToW8 255, IToW8 0, IToW8 0] +green : HtmlColor = [IToW8 0, IToW8 255, IToW8 0] +blue : HtmlColor = [IToW8 0, IToW8 0, IToW8 255] GeomStyle : Type = { fillColor : Maybe HtmlColor @@ -39,7 +34,7 @@ defaultGeomStyle : GeomStyle = -- TODO: consider sharing attributes among a set of objects for efficiency data Diagram = MkDiagram (List (GeomStyle & Point & Geom)) -instance monoidDiagram : Monoid Diagram where +instance Monoid Diagram mempty = MkDiagram mempty mcombine = \(MkDiagram d1) (MkDiagram d2). MkDiagram $ d1 <> d2 @@ -56,9 +51,9 @@ def applyTransformation (transformGeom: Geom -> Geom) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d - (MkDiagram $ AsList _ for i. + MkDiagram $ toList for i. (attr, p, geom) = objs.i - (attr, transformPoint p, transformGeom geom)) + (attr, transformPoint p, transformGeom geom) flipY : Diagram -> Diagram = applyTransformation (\(x,y). (x, -y)) \geom. case geom of @@ -66,6 +61,7 @@ flipY : Diagram -> Diagram = Circle r -> Circle r Rectangle w h -> Rectangle w h Line (x, y) -> Line (x, -y) + Text x -> Text x def scale (s:Float) : (Diagram -> Diagram) = applyTransformation ( \(x,y). (s * x, s * y) ) \geom. case geom of @@ -73,21 +69,23 @@ def scale (s:Float) : (Diagram -> Diagram) = Circle r -> Circle (s * r) Rectangle w h -> Rectangle (s * w) (s * h) Line (x, y) -> Line (s * x, s * y) + Text x -> Text x def moveXY ((offX, offY) : Point) : (Diagram -> Diagram) = applyTransformation (\(x,y). (x + offX, y + offY) ) id def singletonDefault (geom:Geom) : Diagram = - MkDiagram $ AsList _ [(defaultGeomStyle, (0.0, 0.0), geom)] + MkDiagram $ toList [(defaultGeomStyle, (0.0, 0.0), geom)] def pointDiagram : Diagram = singletonDefault PointGeom def circle (r:Float) : Diagram = singletonDefault $ Circle r def rect (w:Float) (h:Float) : Diagram = singletonDefault $ Rectangle w h def line (p:Point) : Diagram = singletonDefault $ Line p +def text (x:String) : Diagram = singletonDefault $ Text x def updateGeom (update: GeomStyle -> GeomStyle) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d - MkDiagram $ AsList _ for i. + MkDiagram $ toList for i. (attr , geoms) = objs.i (update attr, geoms) @@ -106,67 +104,78 @@ def strCatUncurried ((xs,ys):(String & String)) : String = xs <> ys def (<.>) (xs:String) (ys:String) : String = strCatUncurried (xs, ys) -def quote (s:String) : String = str "\"" <.> s <.> str "\"" +def quote (s:String) : String = "\"" <.> s <.> "\"" @noinline def strSpaceCatUncurried ((s1,s2):(String & String)) : String = - s1 <.> str " " <.> s2 + s1 <.> " " <.> s2 -def (<+>) (s1:String) (s2:String) : String = - strSpaceCatUncurried (s1, s2) +def (<+>) [Show a, Show b] (s1:a) (s2:b) : String = + strSpaceCatUncurried ((show s1), (show s2)) -def selfClosingBrackets (s:String) : String = str "<" <.> s <.> str "/>" +def selfClosingBrackets (s:String) : String = "<" <.> s <.> "/>" def tagBrackets (tag:String) (s:String) : String = - str "<" <.> tag <.> str ">" <.> s <.> str " tag <.> str ">" + "<" <.> tag <.> ">" <.> s <.> " tag <.> ">" @noinline def tagBracketsAttrUncurried ((tag, attr, s):(String & String & String)) : String = - str "<" <.> tag <+> attr <.> str ">" <.> s <.> str " tag <.> str ">" + "<" <.> tag <+> attr <.> ">" <.> s <.> " tag <.> ">" def tagBracketsAttr (tag:String) (attr:String) (s:String) : String = tagBracketsAttrUncurried (tag, attr, s) -def makeAttr (attr:String) (val:String) : String = - attr <.> str "=" <.> quote val +def (<=>) [Show b] (attr:String) (val:b) : String = + attr <.> "=" <.> quote (show val) -def htmlColorStr (cs:HtmlColor) : String = - (r, g, b) = cs - toList "#" <> (showHex $ W8ToI r) <> (showHex $ W8ToI g) <> (showHex $ W8ToI b) +def htmlColor(cs:HtmlColor) : String = + "#" <> (concat $ for i. showHex cs.i) -def optionalHtmlColorStr (c: Maybe HtmlColor) : String = +def optionalHtmlColor(c: Maybe HtmlColor) : String = case c of - Nothing -> str "none" - Just c' -> htmlColorStr c' + Nothing -> "none" + Just c' -> htmlColor c' @noinline def attrString (attr:GeomStyle) : String = - ( -- makeAttr (str "stroke") (optionalHtmlColorStr $ getAt #strokeColor attr) - makeAttr (str "fill") (optionalHtmlColorStr $ getAt #fillColor attr) - <+> makeAttr (str "stroke-width") (show $ getAt #strokeWidth attr)) + ( ("stroke" <=> (optionalHtmlColor $ getAt #strokeColor attr)) + <+> ("fill" <=> (optionalHtmlColor $ getAt #fillColor attr)) + <+> ("stroke-width" <=> (getAt #strokeWidth attr))) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = + -- For things that are solid. SVG says they have fill=stroke. + solidAttr = setAt #fillColor (getAt #strokeColor attr) attr + + groupEle = \attr. tagBracketsAttr "g" (attrString attr) case geom of PointGeom -> - pointAttr = setAt #fillColor (getAt #strokeColor attr) attr - tagBracketsAttr (str "g") (attrString pointAttr) $ selfClosingBrackets $ - (str "circle" <+> - str "cx=" <.> quote (show x) <.> - str "cy=" <.> quote (show y) <.> - str "r=\"1\"") + pointAttr = setAt #fillColor (getAt #strokeColor attr) attr + groupEle solidAttr $ selfClosingBrackets $ + ("circle" <+> + "cx" <=> x <.> + "cy" <=> y <.> + "r=\"1\"") Circle r -> - tagBracketsAttr (str "g") (attrString attr) $ selfClosingBrackets $ - (str "circle" <+> - str "cx=" <.> quote (show x) <.> - str "cy=" <.> quote (show y) <.> - str "r=" <.> quote (show r)) + groupEle attr $ selfClosingBrackets $ + ("circle" <+> + "cx" <=> x <.> + "cy" <=> y <.> + "r" <=> r) Rectangle w h -> - tagBracketsAttr (str "g") (attrString attr) $ selfClosingBrackets $ - (str "rect" <+> - str "width=" <.> quote (show w) <.> - str "height=" <.> quote (show h) <.> - str "x=" <.> quote (show (x - (w/2.0))) <.> - str "y=" <.> quote (show (y - (h/2.0)))) + groupEle attr $ selfClosingBrackets $ + ("rect" <+> + "width" <=> w <.> + "height" <=> h <.> + "x" <=> (x - (w/2.0)) <.> + "y" <=> (y - (h/2.0))) + Text content -> + textEle = tagBracketsAttr "text" $ + ("x" <=> x <.> + "y" <=> y <.> + "text-anchor" <=> "middle" <.> -- horizontal center + "dominant-baseline" <=> "middle" -- vertical center + ) + groupEle solidAttr $ textEle content BoundingBox : Type = (Point & Point) @@ -175,14 +184,13 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = imgWidth = 400.0 scaleFactor = imgWidth / (xmax - xmin) imgHeight = (ymax - ymin) * scaleFactor + imgXMin = xmin * scaleFactor + imgYMin = -ymax * scaleFactor (MkDiagram (AsList _ objs)) = d |> flipY |> scale scaleFactor - viewBoxStr = makeAttr (str "viewBox") $ - (show (xmin * scaleFactor) <+> show (-(ymax * scaleFactor)) <+> - show imgWidth <+> show imgHeight) - svgAttrStr = ( makeAttr (str "width" ) (show imgWidth) - <+> makeAttr (str "height") (show imgHeight) - <+> viewBoxStr) - tagBracketsAttr (str "svg") svgAttrStr $ + svgAttrStr = ( "width" <=> imgWidth + <+> "height" <=> imgHeight + <+> "viewBox" <=> (imgXMin <+> imgYMin <+> imgWidth <+> imgHeight)) + tagBracketsAttr "svg" svgAttrStr $ concat for i. (attr, pos, geom) = objs.i renderGeom attr pos geom @@ -192,11 +200,24 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = moveX : Float -> Diagram -> Diagram = \x. moveXY (x, 0.0) moveY : Float -> Diagram -> Diagram = \y. moveXY (0.0, y) --- mydiagram : Diagram = --- ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) --- <> (circle 5.0 |> moveXY (40.0, 40.0)) --- <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) --- <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) --- ) - --- :html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) +' A Demo showing all kind of features +``` +mydiagram : Diagram = + ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) + <> (circle 5.0 |> moveXY (40.0, 40.0)) + <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) + <> (text "DexLang" |> moveXY (30.0, 10.0) |> setStrokeColor green) + <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) + ) +:html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) +``` + +' Another demo that shows things are all center aligned: +``` +concentricDiagram : Diagram = ( + (rect 2.0 2.0 |> setFillColor red) + <> (circle 1.0 |> setFillColor blue) + <> (text "DexLang" |> setStrokeColor white) +) |> moveXY (5.0, 5.0) +:html renderSVG concentricDiagram ((0.0, 0.0), (10.0, 10.0)) +``` diff --git a/lib/parser.dx b/lib/parser.dx new file mode 100644 index 000000000..1190fe84e --- /dev/null +++ b/lib/parser.dx @@ -0,0 +1,123 @@ + + +'Utilities unrelated to parsing + +def fromOrdinalExc (n:Type) (i:Int) : {Except} n = + if (0 <= i) && (i < size n) + then unsafeFromOrdinal _ i + else throw () + +-- TODO: allow this to happen in-place +-- TODO: if it takes too long to make that possible, start with a bounded version +def push (ref:Ref h (List a)) (x:a) : {State h} Unit = + l = get ref + ref := l <> AsList _ [x] + +def indexList (l:List a) (i:Int) : {Except} a = + (AsList n xs) = l + xs.(fromOrdinalExc _ i) + +'The Parser type + +def ParserHandle (h:Type) : Type = (String & Ref h Int) + +data Parser a:Type = + MkParser (h:Type ?-> ParserHandle h -> {Except, State h} a) + +def parse (handle:ParserHandle h) (parser:Parser a) : {Except, State h} a = + (MkParser f) = parser + f handle + +def runParserPartial (s:String) (parser:Parser a) : Maybe a = + (MkParser f) = parser + withState 0 \pos. + catch $ do + f (s, pos) + +'Primitive combinators + +def pChar (c:Char) : Parser Unit = MkParser \(s, posRef). + i = get posRef + c' = indexList s i + assert (c == c') + posRef := i + 1 + +def pEOF : Parser Unit = MkParser \(s, posRef). + assert $ get posRef >= listLength s + +def (<|>) (p1:Parser a) (p2:Parser a) : Parser a = MkParser \h. + (s, posRef) = h + curPos = get posRef + case catch do parse h p1 of + Nothing -> + assert $ curPos == get posRef + parse h p2 + Just ans -> ans + +def return (x:a) : Parser a = MkParser \_. x + +def runParser (s:String) (parser:Parser a) : Maybe a = + runParserPartial s $ MkParser \h. + ans = parse h parser + _ = parse h pEOF + ans + +def parseAny : Parser Char = MkParser \h. + (s, posRef) = h + i = get posRef + c' = indexList s i + posRef := i + 1 + c' + +def try (parser:Parser a) : Parser a = MkParser \h. + (s, posRef) = h + savedPos = get posRef + ans = catch do parse h parser + case ans of + Nothing -> + posRef := savedPos + throw () + Just x -> x + +'Derived combinators + +def parseDigit : Parser Int = try $ MkParser \h. + c = parse h $ parseAny + i = W8ToI c - 48 + assert $ 0 <= i && i < 10 + i + +def optional (p:Parser a) : Parser (Maybe a) = + (MkParser \h. Just (parse h p)) <|> return Nothing + +def parseMany (parser:Parser a) : Parser (List a) = MkParser \h. + yieldState (AsList _ []) \results. + iter \_. + maybeVal = parse h $ optional parser + case maybeVal of + Nothing -> Done () + Just x -> + push results x + Continue + +def parseUnsignedInt : Parser Int = MkParser \h. + (AsList _ digits) = parse h $ parseMany parseDigit + yieldState 0 \ref. + for i. ref := 10 * get ref + digits.i + +def parseInt : Parser Int = MkParser \h. + negSign = parse h $ optional $ pChar '-' + x = parse h $ parseUnsignedInt + case negSign of + Nothing -> x + Just () -> (-1) * x + +def bracketed (l:Parser Unit) (r:Parser Unit) (body:Parser a) : Parser a = + MkParser \h. + _ = parse h l + ans = parse h body + _ = parse h r + ans + +def parens (parser:Parser a) : Parser a = + bracketed (pChar '(') (pChar ')') parser diff --git a/examples/plot.dx b/lib/plot.dx similarity index 93% rename from examples/plot.dx rename to lib/plot.dx index fde6d9754..56c36f647 100644 --- a/examples/plot.dx +++ b/lib/plot.dx @@ -1,7 +1,7 @@ '# Plotting library -include "examples/diagram.dx" -include "examples/png.dx" +include "diagram.dx" +include "png.dx" data CompactSet a:Type = Interval a a @@ -49,13 +49,12 @@ def getScaled (sd:ScaledData n a) (i:n) : Maybe Float = lowColor = [1.0, 0.5, 0.0] highColor = [0.0, 0.5, 1.0] -def interpolate (_:VSpace a) ?=> (low:a) (high:a) (x:Float) : a = +def interpolate [VSpace a] (low:a) (high:a) (x:Float) : a = x' = clip (0.0, 1.0) x (x' .* low) + ((1.0 - x') .* high) def makeRGBColor (c : Color) : HtmlColor = - [r, g, b] = for i. IToW8 $ FToI $ floor (255.0 * c.i) - (r, g, b) + for i. IToW8 $ FToI $ floor (255.0 * c.i) def colorScale (x:Float) : HtmlColor = makeRGBColor $ interpolate lowColor highColor x @@ -109,6 +108,10 @@ def xycPlot (xs:n=>Float) (ys:n=>Float) (cs:n=>Float) : Plot n Float Float Float setYData (autoScale ys) |> setCData (autoScale cs) +def yPlot (ys:n=>Float) : Plot n Float Float Unit = + xs = for i. IToF $ ordinal i + xyPlot xs ys + -- xs = linspace (Fin 100) 0. 1.0 -- :html showPlot $ xycPlot xs xs xs @@ -118,6 +121,6 @@ def xycPlot (xs:n=>Float) (ys:n=>Float) (cs:n=>Float) : Plot n Float Float Float def matshow (img:n=>m=>Float) : Html = low = minimum $ for (i,j). img.i.j high = maximum $ for (i,j). img.i.j - pngToHtml $ makePNG for i j. + imgToHtml $ makePNG for i j. x = floatTo8Bit $ (img.i.j - low) / (high - low) [x, x, x] diff --git a/examples/png.dx b/lib/png.dx similarity index 76% rename from examples/png.dx rename to lib/png.dx index 8fe4dbd2f..b542449d5 100644 --- a/examples/png.dx +++ b/lib/png.dx @@ -72,7 +72,7 @@ def decodeChunk (chunk : Fin 4 => Char) : Maybe (Fin 3 => Char) = Just base64s -> Just $ base64sToBytes base64s -- TODO: put this in prelude? -def replace (_:Eq a) ?=> ((old,new):(a&a)) (x:a) : a = +def replace [Eq a] ((old,new):(a&a)) (x:a) : a = case x == old of True -> new False -> x @@ -91,23 +91,42 @@ def base64Decode (s:String) : Maybe String = '## PNG FFI -Html : Type = List Char +Html : Type = String +Png : Type = String +Gif : Type = String -def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = +def makePNG (img:n=>m=>(Fin 3)=>Word8) : Png = unsafeIO \(). (AsList _ imgFlat) = toList for (i,j,k). img.i.j.k - (n, ptr) = (%ffi encodePNG (Int & CharPtr) (%getPtr imgFlat) - (size m) (size n)) - AsList n $ for i. %ptrLoad (%ptrOffset ptr (ordinal i)) - -def pngToHtml (png:List Byte) : List Char = - (toList " base64Encode png - <> toList "\">") - -'## API entry point + withTabPtr imgFlat \ptr. + (MkPtr rawPtr) = ptr + (n, ptr') = %ffi encodePNG (Int & RawPtr) rawPtr (size m) (size n) + toList $ tabFromPtr (Fin n) $ MkPtr ptr' + +def pngsToGif (delay:Int) (pngs:t=>Png) : Gif = unsafeIO \(). + withTempFiles \pngFiles. + for i. writeFile pngFiles.i pngs.i + withTempFile \gifFile. + shellOut $ + "convert" <> " -delay " <> show delay <> " " <> + concat (for i. "png:" <> pngFiles.i <> " ") <> + "gif:" <> gifFile + readFile gifFile + +def imgToHtml (png:String) : Html = + (" base64Encode png + <> "\">") def floatTo8Bit (x:Float) : Word8 = IToW8 $ FToI $ 255.0 * clip (0.0, 1.0) x +def imgToPng (img:n=>m=>(Fin 3)=>Float) : Png = + makePNG for i j k. floatTo8Bit img.i.j.k + +'## API entry point + def imshow (img:n=>m=>(Fin 3)=>Float) : Html = - pngToHtml $ makePNG for i j k. floatTo8Bit img.i.j.k + imgToHtml $ imgToPng img + +def imseqshow (imgs:t=>n=>m=>(Fin 3)=>Float) : Html = + imgToHtml $ pngsToGif 50 $ map imgToPng imgs diff --git a/prelude.dx b/lib/prelude.dx similarity index 51% rename from prelude.dx rename to lib/prelude.dx index 1c11c57fc..596f9ad3f 100644 --- a/prelude.dx +++ b/lib/prelude.dx @@ -1,9 +1,8 @@ - -'## Dex prelude +'# Dex prelude 'Runs before every Dex program unless an alternative is provided with `--prelude`. -'Wrappers around primitives +'## Wrappers around primitives Unit = %UnitType Type = %TyKind @@ -19,6 +18,8 @@ Word8 = %Word8 Byte = Word8 Char = Byte +RawPtr : Type = %Word8Ptr + Int = Int32 Float = Float32 @@ -36,108 +37,110 @@ def IToI32 (x : Int) : Int32 = internalCast _ x def IToW8 (x : Int) : Word8 = internalCast _ x def IToF (x:Int) : Float = internalCast _ x def FToI (x:Float) : Int = internalCast _ x +def I64ToRawPtr (x:Int64 ) : RawPtr = internalCast _ x +def RawPtrToI64 (x:RawPtr) : Int64 = internalCast _ x -interface Add a:Type where +interface Add a add : a -> a -> a sub : a -> a -> a zero : a -def (+) (d:Add a) ?=> : a -> a -> a = add -def (-) (d:Add a) ?=> : a -> a -> a = sub +def (+) [Add a] : a -> a -> a = add +def (-) [Add a] : a -> a -> a = sub -instance float64Add : Add Float64 where - add = \x:Float64 y:Float64. %fadd x y - sub = \x:Float64 y:Float64. %fsub x y +instance Add Float64 + add = \x y. %fadd x y + sub = \x y. %fsub x y zero = FToF64 0.0 -instance float32Add : Add Float32 where - add = \x:Float32 y:Float32. %fadd x y - sub = \x:Float32 y:Float32. %fsub x y +instance Add Float32 + add = \x y. %fadd x y + sub = \x y. %fsub x y zero = FToF32 0.0 -instance int64Add : Add Int64 where - add = \x:Int64 y:Int64. %iadd x y - sub = \x:Int64 y:Int64. %isub x y +instance Add Int64 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToI64 0 -instance int32Add : Add Int32 where - add = \x:Int32 y:Int32. %iadd x y - sub = \x:Int32 y:Int32. %isub x y +instance Add Int32 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToI32 0 -instance word8Add : Add Word8 where - add = \x:Word8 y:Word8. %iadd x y - sub = \x:Word8 y:Word8. %isub x y +instance Add Word8 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToW8 0 -instance unitAdd : Add Unit where +instance Add Unit add = \x y. () sub = \x y. () zero = () -instance tabAdd : Add a ?=> Add (n=>a) where +instance [Add a] Add (n=>a) add = \xs ys. for i. xs.i + ys.i sub = \xs ys. for i. xs.i - ys.i zero = for _. zero -interface Mul a:Type where +interface Mul a mul : a -> a -> a one : a -def (*) (d:Mul a) ?=> : a -> a -> a = mul +def (*) [Mul a] : a -> a -> a = mul -instance float64Mul : Mul Float64 where - mul = \x:Float64 y:Float64. %fmul x y +instance Mul Float64 + mul = \x y. %fmul x y one = FToF64 1.0 -instance float32Mul : Mul Float32 where - mul = \x:Float32 y:Float32. %fmul x y +instance Mul Float32 + mul = \x y. %fmul x y one = FToF32 1.0 -instance int64Mul : Mul Int64 where - mul = \x:Int64 y:Int64. %imul x y +instance Mul Int64 + mul = \x y. %imul x y one = IToI64 1 -instance int32Mul : Mul Int32 where - mul = \x:Int32 y:Int32. %imul x y +instance Mul Int32 + mul = \x y. %imul x y one = IToI32 1 -instance word8Mul : Mul Word8 where - mul = \x:Word8 y:Word8. %imul x y +instance Mul Word8 + mul = \x y. %imul x y one = IToW8 1 -instance unitMul : Mul Unit where +instance Mul Unit mul = \x y. () one = () -interface Integral a:Type where - idiv: a->a->a - rem: a->a->a +interface Integral a + idiv : a->a->a + rem : a->a->a -instance int64Integral : Integral Int64 where - idiv = \x:Int64 y:Int64. %idiv x y - rem = \x:Int64 y:Int64. %irem x y +instance Integral Int64 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -instance int32Integral : Integral Int32 where - idiv = \x:Int32 y:Int32. %idiv x y - rem = \x:Int32 y:Int32. %irem x y +instance Integral Int32 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -instance word8Integral : Integral Word8 where - idiv = \x:Word8 y:Word8. %idiv x y - rem = \x:Word8 y:Word8. %irem x y +instance Integral Word8 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -interface Fractional a:Type where +interface Fractional a divide : a -> a -> a -instance float64Fractional : Fractional Float64 where - divide = \x:Float64 y:Float64. %fdiv x y +instance Fractional Float64 + divide = \x y. %fdiv x y -instance float32Fractional : Fractional Float32 where - divide = \x:Float32 y:Float32. %fdiv x y +instance Fractional Float32 + divide = \x y. %fdiv x y -'Basic polymorphic functions and types +'## Basic polymorphic functions and types def (&) (a:Type) (b:Type) : Type = %PairType a b def (,) (x:a) (y:b) : (a & b) = %pair x y @@ -152,23 +155,26 @@ flip : (a -> b -> c) -> (b -> a -> c) = \f x y. f y x uncurry : (a -> b -> c) -> (a & b) -> c = \f (x,y). f x y const : a -> b -> a = \x _. x -'Vector spaces +'## Vector spaces -data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) +interface [Add a] VSpace a + scaleVec : Float -> a -> a -@superclass -def addFromVSpace (d:VSpace a) : Add a = case d of MkVSpace addDict _ -> addDict +def (.*) [VSpace a] : Float -> a -> a = scaleVec +def (*.) [VSpace a] : a -> Float -> a = flip scaleVec +def (/) [VSpace a] (v:a) (s:Float) : a = divide 1.0 s .* v +def neg [VSpace a] (v:a) : a = (-1.0) .* v -def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale -(*.) : VSpace a ?=> a -> Float -> a = flip (.*) -def (/) (_:VSpace a) ?=> (v:a) (s:Float) : a = (divide 1.0 s) .* v -def neg (_:VSpace a) ?=> (v:a) : a = (-1.0) .* v +instance VSpace Float + scaleVec = \x y. x * y -@instance floatVS : VSpace Float = MkVSpace float32Add (*) -@instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i -@instance unitVS : VSpace Unit = MkVSpace unitAdd \s u. () +instance [VSpace a] VSpace (n=>a) + scaleVec = \s xs. for i. s .* xs.i -'Bool type +instance VSpace Unit + scaleVec = \_ _. () + +'## Boolean type data Bool = False @@ -192,9 +198,9 @@ def not (x:Bool) : Bool = x' = BToW8 x W8ToB $ %not x' -'Sum types +'## Sum types -data Maybe a:Type = +data Maybe a = Nothing Just a @@ -202,7 +208,9 @@ def isNothing (x:Maybe a) : Bool = case x of Nothing -> True Just _ -> False -data (|) a:Type b:Type = +def isJust (x:Maybe a) : Bool = not $ isNothing x + +data (|) a b = Left a Right b @@ -211,12 +219,9 @@ def select (p:Bool) (x:a) (y:a) : a = case p of False -> y def BToI (x:Bool) : Int = W8ToI $ BToW8 x - def BToF (x:Bool) : Float = IToF (BToI x) -def todo (a:Type) ?-> : a = %throwError a -def throw (a:Type) ?-> : a = %throwError a -'Effects +'## Effects def Ref (r:Type) (a:Type) : Type = %Ref r a def get (ref:Ref h s) : {State h} s = %get ref @@ -227,87 +232,133 @@ def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref -def withReader - (eff:Effects) ?-> (a:Type) ?-> (r:Type) ?-> - (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) +def runReader + (init:r) + (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = def explicitAction (h':Type) (ref:Ref h' r) : {Read h'|eff} a = action ref %runReader init explicitAction -def withAccum - (eff:Effects) ?-> (a:Type) ?-> (w:Type) ?-> +def withReader + (init:r) + (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) + : {|eff} a = + runReader init action + +def runAccum (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref %runWriter explicitAction -def withState - (eff:Effects) ?-> (a:Type) ?-> (s:Type) ?-> +def yieldAccum + (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + : {|eff} w = + snd $ runAccum action + +def runState (init:s) - (action: (h:Type ?-> Ref h s -> {State h |eff} a)) + (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} (a & s) = - def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref - %runState init explicitAction - -'Type classes - -data Eq a:Type = MkEq (a -> a -> Bool) -data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt - -@superclass -def eqFromOrd (d:Ord a) : Eq a = case d of MkOrd eq _ _ -> eq - -def (==) (d:Eq a) ?=> (x:a) (y:a) : Bool = case d of MkEq eq -> eq x y -def (/=) (d:Eq a) ?=> (x:a) (y:a) : Bool = not $ x == y - -def (>) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ gt _ -> gt x y -def (<) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ _ lt -> lt x y -def (<=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y - -@instance float64Eq : Eq Float64 = MkEq \x:Float64 y:Float64. W8ToB $ %feq x y -@instance float32Eq : Eq Float32 = MkEq \x:Float32 y:Float32. W8ToB $ %feq x y -@instance int64Eq : Eq Int64 = MkEq \x:Int64 y:Int64. W8ToB $ %ieq x y -@instance int32Eq : Eq Int32 = MkEq \x:Int32 y:Int32. W8ToB $ %ieq x y -@instance word8Eq : Eq Word8 = MkEq \x:Word8 y:Word8. W8ToB $ %ieq x y -@instance boolEq : Eq Bool = MkEq \x y. BToW8 x == BToW8 y -@instance unitEq : Eq Unit = MkEq \x y. True - -@instance float64Ord : Ord Float64 = (MkOrd float64Eq (\x y. W8ToB $ %fgt x y) - (\x y. W8ToB $ %flt x y)) -@instance float32Ord : Ord Float32 = (MkOrd float32Eq (\x y. W8ToB $ %fgt x y) - (\x y. W8ToB $ %flt x y)) -@instance int64Ord : Ord Int64 = (MkOrd int64Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance int32Ord : Ord Int32 = (MkOrd int32Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance word8Ord : Ord Word8 = (MkOrd word8Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance unitOrd : Ord Unit = (MkOrd unitEq (\x y. False) (\x y. False)) - -@instance -def pairEq (eqA: Eq a)?=> (eqB: Eq b)?=> : Eq (a & b) = MkEq $ - \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 - -@instance -def pairOrd (ordA: Ord a)?=> (ordB: Ord b)?=> : Ord (a & b) = - pairGt = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) - pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) - MkOrd pairEq pairGt pairLt + def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref + %runState init explicitAction + +def withState + (init:s) + (action: h:Type ?-> Ref h s -> {State h |eff} a) + : {|eff} a = fst $ runState init action + +def yieldState + (init:s) + (action: h:Type ?-> Ref h s -> {State h |eff} a) + : {|eff} s = snd $ runState init action + +def unsafeIO (f: Unit -> {IO|eff} a) : {|eff} a = + %runIO f + +def unreachable (():Unit) : a = unsafeIO do + %throwError a + +'## Type classes + +interface Eq a + (==) : a -> a -> Bool + +def (/=) [Eq a] (x:a) (y:a) : Bool = not $ x == y + +interface [Eq a] Ord a + (>) : a -> a -> Bool + (<) : a -> a -> Bool + +def (<=) [Ord a] (x:a) (y:a) : Bool = x=) [Ord a] (x:a) (y:a) : Bool = x>y || x==y + +instance Eq Float64 + (==) = \x y. W8ToB $ %feq x y + +instance Eq Float32 + (==) = \x y. W8ToB $ %feq x y +instance Eq Int64 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Int32 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Word8 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Bool + (==) = \x y. BToW8 x == BToW8 y + +instance Eq Unit + (==) = \x y. True + +instance Eq RawPtr + (==) = \x y. RawPtrToI64 x == RawPtrToI64 y + +instance Ord Float64 + (>) = \x y. W8ToB $ %fgt x y + (<) = \x y. W8ToB $ %flt x y + +instance Ord Float32 + (>) = \x y. W8ToB $ %fgt x y + (<) = \x y. W8ToB $ %flt x y + +instance Ord Int64 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Int32 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Word8 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Unit + (>) = \x y. False + (<) = \x y. False + +instance [Eq a, Eq b] Eq (a & b) + (==) = \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 + +instance [Ord a, Ord b] Ord (a & b) + (>) = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) + (<) = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) -- TODO: accumulate using the True/&& monoid -@instance -def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $ - \xs ys. +instance [Eq a] Eq (n=>a) + (==) = \xs ys. numDifferent : Float = - snd $ withAccum \ref. for i. + yieldAccum \ref. for i. ref += (IToF (BToI (xs.i /= ys.i))) numDifferent == 0.0 -'Transcencendental functions +'## Transcencendental functions -interface Floating a:Type where +interface Floating a exp : a -> a exp2 : a -> a log : a -> a @@ -327,7 +378,7 @@ interface Floating a:Type where pow : a -> a -> a lgamma : a -> a -def lbeta (_ : Add a) ?=> (_ : Floating a) ?=> : a -> a -> a = \x y. lgamma x + lgamma y - lgamma (x + y) +def lbeta [Add a, Floating a] : a -> a -> a = \x y. lgamma x + lgamma y - lgamma (x + y) -- Todo: better numerics for very large and small values. -- Using %exp here to avoid circular definition problems. @@ -341,92 +392,155 @@ def float64_cosh (x:Float64) : Float64 = %fdiv ((%exp x) + (%exp (%fsub (FToF64 def float64_tanh (x:Float64) : Float64 = %fdiv (%fsub (%exp x) (%exp (%fsub (FToF64 0.0) x))) ((%exp x) + (%exp (%fsub (FToF64 0.0) x))) -instance float64Floating : Floating Float64 where - exp = \x:Float64. %exp x - exp2 = \x:Float64. %exp2 x - log = \x:Float64. %log x - log2 = \x:Float64. %log2 x - log10 = \x:Float64. %log10 x - log1p = \x:Float64. %log1p x - sin = \x:Float64. %sin x - cos = \x:Float64. %cos x - tan = \x:Float64. %tan x +instance Floating Float64 + exp = \x. %exp x + exp2 = \x. %exp2 x + log = \x. %log x + log2 = \x. %log2 x + log10 = \x. %log10 x + log1p = \x. %log1p x + sin = \x. %sin x + cos = \x. %cos x + tan = \x. %tan x sinh = float64_sinh cosh = float64_cosh tanh = float64_tanh - floor = \x:Float64. %floor x - ceil = \x:Float64. %ceil x - round = \x:Float64. %round x - sqrt = \x:Float64. %sqrt x - pow = \x:Float64 y:Float64. %fpow x y - lgamma = \x:Float64. %lgamma x - -instance float32Floating : Floating Float32 where - exp = \x:Float32. %exp x - exp2 = \x:Float32. %exp2 x - log = \x:Float32. %log x - log2 = \x:Float32. %log2 x - log10 = \x:Float32. %log10 x - log1p = \x:Float32. %log1p x - sin = \x:Float32. %sin x - cos = \x:Float32. %cos x - tan = \x:Float32. %tan x + floor = \x. %floor x + ceil = \x. %ceil x + round = \x. %round x + sqrt = \x. %sqrt x + pow = \x y. %fpow x y + lgamma = \x. %lgamma x + +instance Floating Float32 + exp = \x. %exp x + exp2 = \x. %exp2 x + log = \x. %log x + log2 = \x. %log2 x + log10 = \x. %log10 x + log1p = \x. %log1p x + sin = \x. %sin x + cos = \x. %cos x + tan = \x. %tan x sinh = float32_sinh cosh = float32_cosh tanh = float32_tanh - floor = \x:Float32. %floor x - ceil = \x:Float32. %ceil x - round = \x:Float32. %round x - sqrt = \x:Float32. %sqrt x - pow = \x:Float32 y:Float32. %fpow x y - lgamma = \x:Float32. %lgamma x + floor = \x. %floor x + ceil = \x. %ceil x + round = \x. %round x + sqrt = \x. %sqrt x + pow = \x y. %fpow x y + lgamma = \x. %lgamma x -'Working with index sets +'## Index set utilities def Range (low:Int) (high:Int) : Type = %IntRange low high def Fin (n:Int) : Type = Range 0 n def ordinal (i:a) : Int = %toOrdinal i def size (n:Type) : Int = %idxSetSize n def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i +def iota (n:Type) : n=>Int = view i. ordinal i -def fromOrdinal (n:Type) (i:Int) : n = - case (0 <= i) && (i < size n) of - True -> unsafeFromOrdinal _ i - False -> throw +-- TODO: we want Eq and Ord for all index sets, not just `Fin n` +instance Eq (Fin n) + (==) = \x y. ordinal x == ordinal y -def asidx (n:Type) ?-> (i:Int) : n = fromOrdinal n i -def (@) (i:Int) (n:Type) : n = fromOrdinal n i -def iota (n:Type) : n=>Int = for i. ordinal i +instance Ord (Fin n) + (>) = \x y. ordinal x > ordinal y + (<) = \x y. ordinal x < ordinal y --- TODO: we want Eq and Ord for all index sets, not just `Fin n` -@instance -def finEq (n:Int) ?-> : Eq (Fin n) = MkEq \x y. ordinal x == ordinal y +'## Raw pointer operations + +data Ptr a = MkPtr RawPtr + +-- Is there a better way to select the right instance for `storageSize`?? +data TypeVehicle a = MkTypeVehicle +def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle + +interface Storable a + store : Ptr a -> a -> {IO} Unit + load : Ptr a -> {IO} a + storageSize_ : TypeVehicle a -> Int + +def storageSize (a:Type) -> (d:Storable a) ?=> : Int = + tv : TypeVehicle a = MkTypeVehicle + storageSize_ tv + +instance Storable Word8 + store = \(MkPtr ptr) x. %ptrStore ptr x + load = \(MkPtr ptr) . %ptrLoad ptr + storageSize_ = const 1 + +instance Storable Int32 + store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x + load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr) + storageSize_ = const 4 -@instance -def finOrd (n:Int) ?-> : Ord (Fin n) = - MkOrd finEq (\x y. ordinal x > ordinal y) (\x y. ordinal x < ordinal y) +instance Storable (Ptr a) + store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x + load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) + storageSize_ = const 8 -- TODO: something more portable? -'Misc +-- TODO: Storable instances for other types + +def malloc [Storable a] (n:Int) : {IO} (Ptr a) = + numBytes = storageSize a * n + MkPtr $ %alloc numBytes + +def free (ptr:Ptr a) : {IO} Unit = + (MkPtr ptr') = ptr + %free ptr' + +def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = + (MkPtr ptr') = ptr + i' = i * storageSize a + MkPtr $ %ptrOffset ptr' i' + +-- TODO: consider making a Storable instance for tables instead +def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {IO} Unit = + for_ i. store (ptr +>> ordinal i) tab.i + +def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {IO} Unit = + for_ i:(Fin n). + i' = ordinal i + store (dest +>> i') (load $ src +>> i') + +-- TODO: generalize these brackets to allow other effects +-- TODO: make sure that freeing happens even if there are run-time errors +def withAlloc [Storable a] + (n:Int) (action: Ptr a -> {IO} b) : {IO} b = + ptr = malloc n + result = action ptr + free ptr + result + +def withTabPtr [Storable a] + (xs:n=>a) (action : Ptr a -> {IO} b) : {IO} b = + withAlloc (size n) \ptr. + for i. store (ptr +>> ordinal i) xs.i + action ptr + +def tabFromPtr [Storable a] (n:Type) -> (ptr:Ptr a) : {IO} n=>a = + for i. load $ ptr +>> ordinal i + +'## Miscellaneous common utilities pi : Float = 3.141592653589793 def id (x:a) : a = x def dup (x:a) : (a & a) = (x, x) def map (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for i. f xs.i -def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = for i. (xs.i, ys.i) +def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = view i. (xs.i, ys.i) def unzip (xys:n=>(a&b)) : (n=>a & n=>b) = (map fst xys, map snd xys) -def fanout (n:Type) (x:a) : n=>a = for i. x -def sq (d:Mul a) ?=> (x:a) : a = x * x -def abs (_:Add a) ?=> (_:Ord a) ?=> (x:a) : a = select (x > zero) x (zero - x) +def fanout (n:Type) (x:a) : n=>a = view i. x +def sq [Mul a] (x:a) : a = x * x +def abs [Add a, Ord a] (x:a) : a = select (x > zero) x (zero - x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y -def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = - for i. xs.(fromOrdinal _ (ordinal i + start)) - def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for i. tab.(ixr i) def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) = - swap $ withState init \s. for i. + swap $ runState init \s. for i. c = get s (c', y) = body i c s := c' @@ -442,57 +556,53 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- TODO: call this `scan` and call the current `scan` something else def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n->Float) : Float = snd $ withAccum \ref. for i. ref += xs i -def sum (_: Add v) ?=> (xs:n=>v) : v = reduce zero (+) xs -def prod (_: Mul v) ?=> (xs:n=>v) : v = reduce one (*) xs -def mean (n:Type) ?-> (xs:n=>Float) : Float = sum xs / IToF (size n) +def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i +def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs +def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs +def mean (xs:n=>Float) : Float = sum xs / IToF (size n) def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs - def applyN (n:Int) (x:a) (f:a -> a) : a = - snd $ withState x \ref. for _:(Fin n). + yieldState x \ref. for _:(Fin n). ref := f (get ref) def linspace (n:Type) (low:Float) (high:Float) : n=>Float = dx = (high - low) / IToF (size n) for i:n. low + IToF (ordinal i) * dx -def transpose (x:n=>m=>a) : m=>n=>a = for i j. x.j.i -def vdot (x:n=>Float) (y:n=>Float) : Float = fsum \i. x.i * y.i -def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j +def transpose (x:n=>m=>a) : m=>n=>a = view i j. x.j.i +def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i +def dot [VSpace v] (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j -- matmul. Better symbol to use? `@`? -(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. - y' = transpose y - for i k. fsum \j. x.i.j * y'.k.j +(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. + for i k. fsum view j. x.i.j * y.j.k (**.) : (n=>m=>Float) -> (m=>Float) -> (n=>Float) = \mat v. for i. vdot mat.i v (.**) : (m=>Float) -> (n=>m=>Float) -> (n=>Float) = flip (**.) def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = - fsum \(i,j). x.i * mat.i.j * y.j + fsum view (i,j). x.i * mat.i.j * y.j + +def eye [Eq n] : n=>n=>Float = + for i j. select (i == j) 1.0 0.0 -'Functions for working with the pseudorandom number generator +'## Pseudorandom number generator utilities -- TODO: newtype Key = Int64 -def hash (x:Key) (y:Int32) : Key = +def hash (x:Key) (y:Int32) : Key = unsafeIO do y64 = IToI64 y %ffi threefry2x32 Int64 x y64 def newKey (x:Int) : Key = hash (IToI64 0) x -def splitKey (k:Key) : (Key & Key) = (hash k 0, hash k 1) -def splitKey3 (k:Key) : (Key & Key & Key) = - (k1, k') = splitKey k - (k2, k3) = splitKey k' - (k1, k2, k3) - def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) -def rand (k:Key) : Float = F64ToF $ %ffi randunif Float64 k +def splitKey (k:Key) : Fin n => Key = for i. ixkey k i +def rand (k:Key) : Float = unsafeIO do F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -500,50 +610,27 @@ def randMat (n:Int) (m:Int) (f: Key -> a) (k: Key) : Fin n => Fin m => a = for i j. f (ixkey2 k i j) def randn (k:Key) : Float = - (k1, k2) = splitKey k + [k1, k2] = splitKey k u1 = rand k1 u2 = rand k2 sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) -def randIdx (n:Type) ?-> (k:Key) : n = - unif = rand k - fromOrdinal n $ FToI $ floor $ unif * IToF (size n) - -- TODO: Make this better... def randInt (k:Key) : Int = (I64ToI k) `mod` 2147483647 def bern (p:Float) (k:Key) : Bool = rand k < p -def randnVec (n:Type) ?-> (k:Key) : n=>Float = +def randnVec (k:Key) : n=>Float = for i. randn (ixkey k i) -'min / max etc - -def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y -def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y - -def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 -def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 - -def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (minBy f) xs -def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (maxBy f) xs - -def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs -def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs +def cumSum (xs: n=>Float) : n=>Float = + withState 0.0 \total. + for i. + newTotal = get total + xs.i + total := newTotal + newTotal -def argmin (_:Ord o) ?=> (xs:n=>o) : n = - zeroth = (0@_, xs.(0@_)) - compare = \(idx1, x1) (idx2, x2). - select (x1 < x2) (idx1, x1) (idx2, x2) - zipped = for i. (i, xs.i) - fst $ reduce zeroth compare zipped - -def clip (_:Ord a) ?=> ((low,high):(a&a)) (x:a) : a = - min high $ max low x - -'Automatic differentiation +'## Automatic differentiation -- TODO: add vector space constraints def linearize (f:a->b) (x:a) : (b & a --o b) = %linearize f x @@ -560,33 +647,33 @@ def deriv (f:Float->Float) (x:Float) : Float = jvp f x 1.0 def derivRev (f:Float->Float) (x:Float) : Float = snd (vjp f x) 1.0 -interface HasAllClose a:Type where +interface HasAllClose a allclose : a -> a -> a -> a -> Bool -interface HasDefaultTolerance a:Type where +interface HasDefaultTolerance a atol : a rtol : a -def (~~) (_:HasAllClose a) ?=> (d:HasDefaultTolerance a) ?=> : a -> a -> Bool = allclose atol rtol +def (~~) [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol -instance allCloseF32 : HasAllClose Float32 where +instance HasAllClose Float32 allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) -instance allCloseF64 : HasAllClose Float64 where +instance HasAllClose Float64 allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) -instance defaultToleranceF32 : HasDefaultTolerance Float32 where +instance HasDefaultTolerance Float32 atol = FToF32 0.00001 rtol = FToF32 0.0001 -instance defaultToleranceF64 : HasDefaultTolerance Float64 where +instance HasDefaultTolerance Float64 atol = FToF64 0.00000001 rtol = FToF64 0.00001 -instance allCloseTable : HasAllClose t ?=> HasDefaultTolerance t ?=> HasAllClose (n=>t) where +instance [HasAllClose t, HasDefaultTolerance t] HasAllClose (n=>t) allclose = \atol rtol a b. all for i:n. (a.i ~~ b.i) -instance defaultToleranceTable : (HasDefaultTolerance t) ?=> HasDefaultTolerance (n=>t) where +instance [HasDefaultTolerance t] HasDefaultTolerance (n=>t) atol = for i. atol rtol = for i. rtol @@ -601,17 +688,7 @@ def checkDerivBase (f:Float->Float) (x:Float) : Bool = def checkDeriv (f:Float->Float) (x:Float) : Bool = checkDerivBase f x && checkDerivBase (deriv f) x -'Control flow - -def while - (eff:Effects) ?-> - (cond: Unit -> {|eff} Bool) - (body: Unit -> {|eff} Unit) - : {|eff} Unit = - cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () - %while cond' body - -'Vector support +'## Vector support -- TODO: Reenable vector suport once fixed-width types are supported. -- def UNSAFEFromOrdinal (n : Type) (i : Int) : n = %unsafeAsIndex n i @@ -647,7 +724,7 @@ def while -- @instance vectorFloatVSpace : VSpace VectorFloat = -- MkVSpace vectorFloatAdd \x v. broadcastVector x * v -'Tiling +'## Tiling functions def Tile (n : Type) (m : Type) : Type = %IndexSlice n m @@ -656,14 +733,12 @@ def Tile (n : Type) (m : Type) : Type = %IndexSlice n m -- elements of n. In this view (+>) is just function application, while ++> -- is currying followed by function application. We cannot represent currying -- in isolation, because `Tile n (Tile u v)` does not make sense, unlike `Tile n (u & v)`. -def (+>) (l : Type) ?-> (t:Tile n l) (i : l) : n = %sliceOffset t i +def (+>) (t:Tile n l) (i : l) : n = %sliceOffset t i def (++>) (t : Tile n (u & v)) (i : u) : Tile n v = %sliceCurry t i -def tile (l : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} l=>a)) +def tile (fTile : (t:(Tile n l) -> {|eff} l=>a)) (fScalar : n -> {|eff} a) : {|eff} n=>a = %tiled fTile fScalar -def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) +def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) (fScalar : n -> {|eff} m=>a) : {|eff} m=>n=>a = %tiledd fTile fScalar -- TODO: This should become just `loadVector $ for i. arr.(t +> i)` @@ -675,41 +750,27 @@ def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> -- arr.(t +> UNSAFEFromOrdinal idx 2) -- arr.(t +> UNSAFEFromOrdinal idx 3)) -'Numerical utilities - -def logsumexp (x: n=>Float) : Float = - m = maximum x - m + (log $ sum for i. exp (x.i - m)) - -def logsoftmax (x: n=>Float) : n=>Float = - lse = logsumexp x - for i. x.i - lse - -def softmax (x: n=>Float) : n=>Float = - m = maximum x - e = for i. exp (x.i - m) - s = sum e - for i. e.i / s - -def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = - -- Evaluate a polynomial at x. Same as Numpy's polyval. - fold zero \i c. coefficients.i + x .* c - +'## Monoid typeclass -'Monoid - -interface Monoid a:Type where +interface Monoid a mempty : a mcombine : a -> a -> a -- can't use `<>` just for parser reasons? -(<>) : Monoid a ?=> a -> a -> a = mcombine +def (<>) [Monoid a] : a -> a -> a = mcombine -'Length-erased lists +'## Length-erased lists -data List a:Type = +data List a = AsList n:Int foo:(Fin n => a) -instance monoidList : Monoid (List a) where +def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = + for i. xs.(unsafeFromOrdinal _ (ordinal i)) + +def toList (xs:n=>a) : List a = + n' = size n + AsList _ $ unsafeCastTable (Fin n') xs + +instance Monoid (List a) mempty = AsList _ [] mcombine = \x y. (AsList nx xs) = x @@ -721,9 +782,9 @@ instance monoidList : Monoid (List a) where True -> xs.(unsafeFromOrdinal _ i') False -> ys.(unsafeFromOrdinal _ (i' - nx)) -'Isomorphisms +'## Isomorphisms -data Iso a:Type b:Type = MkIso { fwd: a -> b & bwd: b -> a } +data Iso a b = MkIso { fwd: a -> b & bwd: b -> a } def appIso (iso: Iso a b) (x:a) : b = (MkIso {fwd, bwd}) = iso @@ -801,46 +862,107 @@ splitV : Iso a ({|} | a) = def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = reindex (buildWith $ splitV &>> iso) tab +'Dynamic buffer + +-- TODO: would be nice to be able to use records here +data DynBuffer a = + MkDynBuffer { size : Ptr Int + & maxSize : Ptr Int + & buffer : Ptr (Ptr a) } + +def withDynamicBuffer [Storable a] + (action: DynBuffer a -> {IO} b) : {IO} b = + initMaxSize = 256 + withAlloc 1 \sizePtr. withAlloc 1 \maxSizePtr. withAlloc 1 \bufferPtr. + store sizePtr 0 + store maxSizePtr initMaxSize + store bufferPtr $ malloc initMaxSize + result = action $ MkDynBuffer { size = sizePtr + , maxSize = maxSizePtr + , buffer = bufferPtr } + + free $ load bufferPtr + result + +def maybeIncreaseBufferSize [Storable a] + ((MkDynBuffer db): DynBuffer a) (sizeDelta:Int) : {IO} Unit = + size = load $ getAt #size db + maxSize = load $ getAt #maxSize db + bufPtr = load $ getAt #buffer db + newSize = sizeDelta + size + if newSize > maxSize then + -- TODO: maybe this should use integer arithmetic? + newMaxSize = FToI $ pow 2.0 (ceil $ log2 $ IToF newSize) + newBufPtr = malloc newMaxSize + memcpy newBufPtr bufPtr size + free bufPtr + store (getAt #maxSize db) newMaxSize + store (getAt #buffer db) newBufPtr + +def addAtIntPtr (ptr: Ptr Int) (n:Int) : {IO} Unit = + store ptr (load ptr + n) + +def extendDynBuffer [Storable a] + (buf: DynBuffer a) (new:List a) : {IO} Unit = + (AsList n xs) = new + maybeIncreaseBufferSize buf n + (MkDynBuffer db) = buf + bufPtr = load $ getAt #buffer db + size = load $ getAt #size db + storeTab (bufPtr +>> size) xs + addAtIntPtr (getAt #size db) n + +def loadDynBuffer [Storable a] + (buf: DynBuffer a) : {IO} (List a) = + (MkDynBuffer db) = buf + bufPtr = load $ getAt #buffer db + size = load $ getAt #size db + AsList size $ tabFromPtr _ bufPtr + +def pushDynBuffer [Storable a] + (buf: DynBuffer a) (x:a) : {IO} Unit = + extendDynBuffer buf $ AsList _ [x] + '## Strings and Characters String : Type = List Char -CharPtr : Type = %CharPtr +def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {IO} String = + AsList n $ tabFromPtr _ ptr -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c -interface Show a:Type where +interface Show a show : a -> String -instance showInt32 : Show Int32 where - show = \x: Int32. - (n, ptr) = %ffi showInt32 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) - -instance showInt64 : Show Int64 where - show = \x: Int64. - (n, ptr) = %ffi showInt64 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) - -instance showFloat32 : Show Float32 where - show = \x: Float32. - (n, ptr) = %ffi showFloat32 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) - -instance showFloat64 : Show Float64 where - show = \x: Float64. - (n, ptr) = %ffi showFloat64 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) +instance Show String + show = id + +instance Show Int32 + show = \x. unsafeIO do + (n, ptr) = %ffi showInt32 (Int32 & RawPtr) x + stringFromCharPtr n $ MkPtr ptr + +instance Show Int64 + show = \x. unsafeIO do + (n, ptr) = %ffi showInt64 (Int32 & RawPtr) x + stringFromCharPtr n $ MkPtr ptr + +instance Show Float32 + show = \x. unsafeIO do + (n, ptr) = %ffi showFloat32 (Int32 & RawPtr) x + stringFromCharPtr n $ MkPtr ptr + +instance Show Float64 + show = \x. unsafeIO do + (n, ptr) = %ffi showFloat64 (Int32 & RawPtr) x + stringFromCharPtr n $ MkPtr ptr -- pipe-like reverse function application def (|>) (x:a) (f: a -> b) : b = f x -'## Floating point helper functions +'## Floating-point helper functions def sign (x:Float) : Float = case x > 0.0 of @@ -867,8 +989,287 @@ def isnan (x:Float) : Bool = not (x >= x && x <= x) -- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. def either_is_nan (x:Float) (y:Float) : Bool = (isnan x) || (isnan y) +'File system operations + +FilePath : Type = String +data CString = MkCString RawPtr + +def nullRawPtr : RawPtr = I64ToRawPtr $ IToI64 0 + +def fromNullableRawPtr (ptr:RawPtr) : Maybe (Ptr a) = + if ptr == nullRawPtr + then Nothing + else Just $ MkPtr ptr + +def cStringPtr (s:CString) : Maybe (Ptr Char) = + (MkCString ptr) = s + fromNullableRawPtr ptr + +data StreamMode = + ReadMode + WriteMode + +data Stream mode:StreamMode = MkStream RawPtr + +-- TODO: check the string contains no nulls +def withCString (s:String) (action: CString -> {IO} a) : {IO} a = + (AsList n s') = s <> "\NUL" + withTabPtr s' \(MkPtr ptr). action $ MkCString ptr + +def fopen (path:String) (mode:StreamMode) : {IO} (Stream mode) = + modeStr = case mode of + ReadMode -> "r" + WriteMode -> "w" + withCString path \(MkCString pathPtr). + withCString modeStr \(MkCString modePtr). + MkStream $ %ffi fopen RawPtr pathPtr modePtr + +def fclose (stream:Stream mode) : {IO} Unit = + (MkStream stream') = stream + %ffi fclose Int64 stream' + () + +def fwrite (stream:Stream WriteMode) (s:String) : {IO} Unit = + (MkStream stream') = stream + (AsList n s') = s + withTabPtr s' \(MkPtr ptr). + %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' + %ffi fflush Int64 stream' + () + +def while (body: Unit -> {|eff} Bool) : {|eff} Unit = + body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () + %while body' + +data IterResult a = + Continue + Done a + +-- TODO: can we improve effect inference so we don't need this? +def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = + f x + +-- A little iteration combinator +def iter (body: Int -> {|eff} IterResult a) : {|eff} a = + result = yieldState Nothing \resultRef. withState 0 \i. + while do + continue = isNothing $ get resultRef + if continue then + case liftState resultRef (liftState i body) (get i) of + Continue -> i := get i + 1 + Done result -> resultRef := Just result + continue + + case result of + Just ans -> ans + Nothing -> unreachable () + +-- XXX: used internally by compiler for exceptional while +def whileMaybe (eff:Effects) -> (body: Unit -> {|eff} (Maybe Word8)) : {|eff} Maybe Unit = + hadError = yieldState False \ref. + while do + ans = liftState ref body () + case ans of + Nothing -> + ref := True + False + Just cond -> W8ToB cond + if hadError + then Nothing + else Just () + +def boundedIter (maxIters:Int) (fallback:a) + (body: Int -> {|eff} IterResult a) : {|eff} a = + iter \i. + if i >= maxIters + then Done fallback + else body i + +def fromCString (s:CString) : {IO} (Maybe String) = + case cStringPtr s of + Nothing -> Nothing + Just ptr -> + Just $ withDynamicBuffer \buf. iter \i. + c = load $ ptr +>> i + if c == '\NUL' + then Done $ loadDynBuffer buf + else + pushDynBuffer buf c + Continue + +def getEnv (name:String) : {IO} Maybe String = + withCString name \(MkCString ptr). + fromCString $ MkCString $ %ffi getenv RawPtr ptr + +def checkEnv (name:String) : {IO} Bool = + isJust $ getEnv name + +def fread (stream:Stream ReadMode) : {IO} String = + (MkStream stream') = stream + -- TODO: allow reading longer files! + n = 4096 + withAlloc n \ptr:(Ptr Char). + withDynamicBuffer \buf. + iter \_. + (MkPtr rawPtr) = ptr + numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' + extendDynBuffer buf $ stringFromCharPtr numRead ptr + if numRead == n + then Continue + else Done () + loadDynBuffer buf + +def deleteFile (f:FilePath) : {IO} Unit = + withCString f \(MkCString ptr). + %ffi remove Int64 ptr + () + +def withFile (f:FilePath) (mode:StreamMode) + (action: Stream mode -> {IO} a) + : {IO} a = + stream = fopen f mode + result = action stream + fclose stream + result + +def writeFile (f:FilePath) (s:String) : {IO} Unit = + withFile f WriteMode \stream. fwrite stream s + +def readFile (f:FilePath) : {IO} String = + withFile f ReadMode \stream. fread stream + +def newTempFile (_:Unit) : {IO} FilePath = + withCString "/tmp/dex-XXXXXX" \(MkCString ptr). + fd = %ffi mkstemp Int32 ptr + %ffi close Int32 fd + stringFromCharPtr 15 (MkPtr ptr) + +def withTempFile (action: FilePath -> {IO} a) : {IO} a = + tmpFile = newTempFile () + result = action tmpFile + deleteFile tmpFile + result + +def withTempFiles (action: (n=>FilePath) -> {IO} a) : {IO} a = + tmpFiles = for i. newTempFile () + result = action tmpFiles + for i. deleteFile tmpFiles.i + result + +def getOutputStream (_:Unit) : {IO} Stream WriteMode = + MkStream $ %ptrLoad OUT_STREAM_PTR + +def print (s:String) : {IO} Unit = + fwrite (getOutputStream ()) (s <> "\n") + +def shellOut (command:String) : {IO} String = + modeStr = "r" + withCString command \(MkCString commandPtr). + withCString modeStr \(MkCString modePtr). + pipe = MkStream %ffi popen RawPtr commandPtr modePtr + fread pipe + +'Partial functions + +def error (s:String) : a = unsafeIO do + print s + %throwError a + +def todo : a = error "TODO: implement it!" + +def fromOrdinal (n:Type) (i:Int) : n = + case (0 <= i) && (i < size n) of + True -> unsafeFromOrdinal _ i + False -> error $ + "Ordinal index out of range:" <> show i <> " >= " <> show (size n) + +-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy +-- TODO: safe (runtime-checked) and unsafe versions +def castTable (m:Type) (xs:n=>a) : m=>a = + case size m == size n of + True -> unsafeCastTable _ xs + False -> error $ + "Table size mismatch in cast: " <> show (size m) <> " vs " <> show (size n) + +def asidx (i:Int) : n = fromOrdinal n i +def (@) (i:Int) (n:Type) : n = fromOrdinal n i + +def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = + for i. xs.(fromOrdinal _ (ordinal i + start)) + +def head (xs:n=>a) : a = xs.(0@_) + +def tail (xs:n=>a) (start:Int) : List a = + numElts = size n - start + toList $ slice xs start (Fin numElts) + +def randIdx (k:Key) : n = + unif = rand k + fromOrdinal n $ FToI $ floor $ unif * IToF (size n) + +'Type class for generating example values + +interface Arbitrary a + arb : Key -> a + +instance Arbitrary Float32 + arb = randn + +instance Arbitrary Int32 + arb = \key. FToI $ randn key * 5.0 + +instance [Arbitrary a] Arbitrary (n=>a) + arb = \key. for i. arb $ ixkey key i + +instance Arbitrary (Fin n) + arb = randIdx + +'Control flow + +-- returns the highest index `i` such that `xs.i <= x` +def searchSorted [Ord a] (xs:n=>a) (x:a) : Maybe n = + if size n == 0 + then Nothing + else if x < xs.(fromOrdinal _ 0) + then Nothing + else withState 0 \low. withState (size n) \high. iter \_. + numLeft = get high - get low + if numLeft == 1 + then Done $ Just $ fromOrdinal _ $ get low + else + centerIx = get low + idiv numLeft 2 + if x < xs.(fromOrdinal _ centerIx) + then high := centerIx + else low := centerIx + Continue + +'min / max etc + +def minBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y +def maxBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y + +def min [Ord o] (x1: o) -> (x2: o) : o = minBy id x1 x2 +def max [Ord o] (x1: o) -> (x2: o) : o = maxBy id x1 x2 + +def minimumBy [Ord o] (f:a->o) (xs:n=>a) : a = + reduce xs.(0@_) (minBy f) xs +def maximumBy [Ord o] (f:a->o) (xs:n=>a) : a = + reduce xs.(0@_) (maxBy f) xs + +def minimum [Ord o] (xs:n=>o) : o = minimumBy id xs +def maximum [Ord o] (xs:n=>o) : o = maximumBy id xs + +def argmin [Ord o] (xs:n=>o) : n = + zeroth = (0@_, xs.(0@_)) + compare = \(idx1, x1) (idx2, x2). + select (x1 < x2) (idx1, x1) (idx2, x2) + zipped = for i. (i, xs.i) + fst $ reduce zeroth compare zipped + +def clip [Ord a] ((low,high):(a&a)) (x:a) : a = + min high $ max low x -'## Trigonometric functions. +'## Trigonometric functions def atan_inner (x:Float) : Float = -- From "Computing accurate Horner form approximations to @@ -887,7 +1288,7 @@ def atan_inner (x:Float) : Float = r = r * s r * x + x -def min_and_max (_: Ord a) ?=> (x:a) (y:a) : (a & a) = +def min_and_max [Ord a] (x:a) (y:a) : (a & a) = select (x < y) (x, y) (y, x) -- get both with one comparison. def atan2 (y:Float) (x:Float) : Float = @@ -900,7 +1301,7 @@ def atan2 (y:Float) (x:Float) : Float = (min_abs_x_y, max_abs_x_y) = min_and_max abs_x abs_y a = atan_inner (min_abs_x_y / max_abs_x_y) a = select (abs_x <= abs_y) ((pi / 2.0) -a) a - a = select (x < 0.0) pi a + a = select (x < 0.0) (pi - a) a t = select (x < 0.0) pi 0.0 a = select (y == 0.0) t a t = select (x < 0.0) (3.0 * pi / 4.0) (pi / 4.0) @@ -910,33 +1311,32 @@ def atan2 (y:Float) (x:Float) : Float = def atan (x:Float) : Float = atan2 x 1.0 - '## Complex numbers data Complex = MkComplex Float Float -- real, imaginary -instance allCloseComplex : HasAllClose Complex where +instance HasAllClose Complex allclose = \atol rtol (MkComplex a b) (MkComplex c d). (a ~~ c) && (b ~~ d) -instance defaultToleranceComplex : HasDefaultTolerance Complex where +instance HasDefaultTolerance Complex atol = MkComplex atol atol rtol = MkComplex rtol rtol -@instance ComplexEq : Eq Complex = - MkEq \(MkComplex a b) (MkComplex c d). (a == c) && (b == d) +instance Eq Complex + (==) = \(MkComplex a b) (MkComplex c d). (a == c) && (b == d) -instance ComplexAdd : Add Complex where +instance Add Complex add = \(MkComplex a b) (MkComplex c d). MkComplex (a + c) (b + d) sub = \(MkComplex a b) (MkComplex c d). MkComplex (a - c) (b - d) zero = MkComplex 0.0 0.0 -instance ComplexMul : Mul Complex where +instance Mul Complex mul = \(MkComplex a b) (MkComplex c d). MkComplex (a * c - b * d) (a * d + b * c) one = MkComplex 1.0 0.0 -@instance complexVS : VSpace Complex = - MkVSpace ComplexAdd \a:Float (MkComplex c d):Complex. MkComplex (a * c) (a * d) +instance VSpace Complex + scaleVec = \a:Float (MkComplex c d):Complex. MkComplex (a * c) (a * d) -- Todo: Hook up to (/) operator. Might require two-parameter VSpace. def complex_division (MkComplex a b:Complex) (MkComplex c d:Complex): Complex = @@ -975,7 +1375,7 @@ def complex_tanh (MkComplex a b:Complex) : Complex = den = MkComplex (cosh a * cos b) (sinh a * sin b) complex_division num den -instance ComplexFractional : Fractional Complex where +instance Fractional Complex divide = complex_division def complex_floor (MkComplex re im:Complex) : Complex = @@ -1008,7 +1408,7 @@ def complex_log1p (x:Complex) : Complex = True -> complex_log u False -> divide ((complex_log u) * x) x -instance complexFloating : Floating Complex where +instance Floating Complex exp = complex_exp exp2 = complex_exp2 log = complex_log @@ -1036,34 +1436,13 @@ def (>>) (x:Byte) (y:Int) : Byte = %shr x (IToW8 y) def (.|.) (x:Byte) (y:Byte) : Byte = %or x y def (.&.) (x:Byte) (y:Byte) : Byte = %and x y -'## Raw pointer operations - -def Ptr (ty:Type) : Type = %makePtrType ty - -def tabToPtr (n:Int) ?-> (xs:(Fin n)=>Float) : Ptr Float = - %getPtr xs - -def ptrToTab (n:Int) (ptr:Ptr Float) : Fin n => Float = - for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) - -'## Misc +'## Miscellaneous utilities --- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy --- TODO: safe (runtime-checked) and unsafe versions -def castTable (m:Type) (xs:n=>a) : m=>a = - case size m == size n of - True -> for i. xs.(unsafeFromOrdinal _ (ordinal i)) - False -> throw +def reverse (x:n=>a) : n=>a = + s = size n + for i. x.((s - 1 - ordinal i)@_) -def toList (n:Type) ?-> (xs:n=>a) : List a = - n' = size n - AsList _ $ castTable (Fin n') xs - -def tail (n:Type) ?-> (xs:n=>a) (start:Int) : List a = - numElts = size n - start - toList $ slice xs start (Fin numElts) - -def padTo (n:Type) ?-> (m:Type) (x:a) (xs:n=>a) : (m=>a) = +def padTo (m:Type) (x:a) (xs:n=>a) : (m=>a) = n' = size n for i. i' = ordinal i @@ -1077,17 +1456,16 @@ def fromJust (x:Maybe a) : a = case x of Just x' -> x' def anySat (f:a -> Bool) (xs:n=>a) : Bool = any (map f xs) --- In Haskell this would just be `mapM`. The equivalent for us would be having --- an exception effect. -def seqMaybes (xs : n=>Maybe a) : Maybe (n => a) = +-- XXX: we use this internally so it's important to make the type args explicit +def seqMaybes (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = -- is it possible to implement this safely? (i.e. without using partial -- functions) case anySat isNothing xs of True -> Nothing False -> Just $ map fromJust xs -def linearSearch (_:Eq a) ?=> (xs:n=>a) (query:a) : Maybe n = - snd $ withState Nothing \ref. for i. +def linearSearch [Eq a] (xs:n=>a) (query:a) : Maybe n = + yieldState Nothing \ref. for i. case xs.i == query of True -> ref := Just i False -> () @@ -1098,13 +1476,79 @@ def listLength ((AsList n _):List a) : Int = n -- TODO: we want this for any monoid but this implementation won't work. def concat (lists:n=>(List a)) : List a = totalSize = sum for i. listLength lists.i - AsList _ $ fst $ withState 0 \listIdx. - fst $ withState 0 \eltIdx. + AsList _ $ withState 0 \listIdx. + withState 0 \eltIdx. for i:(Fin totalSize). - while (\(). get eltIdx >= listLength (lists.((get listIdx)@_))) \(). - eltIdx := 0 - listIdx := get listIdx + 1 + while do + continue = get eltIdx >= listLength (lists.((get listIdx)@_)) + if continue + then + eltIdx := 0 + listIdx := get listIdx + 1 + else () + continue (AsList _ xs) = lists.((get listIdx)@_) eltIdxVal = get eltIdx eltIdx := eltIdxVal + 1 xs.(eltIdxVal@_) + +def cumSumLow (xs: n=>Float) : n=>Float = + withState 0.0 \total. + for i. + oldTotal = get total + total := oldTotal + xs.i + oldTotal + +-- cdf should include 0.0 but not 1.0 +def categoricalFromCDF (cdf: n=>Float) (key: Key) : n = + r = rand key + case searchSorted cdf r of + Just i -> i + +def normalizePdf (xs: d=>Float) : d=>Float = xs / sum xs + +def cdfForCategorical (logprobs: n=>Float) : n=>Float = + maxLogProb = maximum logprobs + cumSumLow $ normalizePdf $ map exp $ for i. logprobs.i - maxLogProb + +def categorical (logprobs: n=>Float) (key: Key) : n = + categoricalFromCDF (cdfForCategorical logprobs) key + +-- batch variant to share the work of forming the cumsum +-- (alternatively we could rely on hoisting of loop constants) +def categoricalBatch (logprobs: n=>Float) (key: Key) : m=>n = + cdf = cdfForCategorical logprobs + for i. categoricalFromCDF cdf $ ixkey key i + +'Numerical utilities + +def logsumexp (x: n=>Float) : Float = + m = maximum x + m + (log $ sum for i. exp (x.i - m)) + +def logsoftmax (x: n=>Float) : n=>Float = + lse = logsumexp x + for i. x.i - lse + +def softmax (x: n=>Float) : n=>Float = + m = maximum x + e = for i. exp (x.i - m) + s = sum e + for i. e.i / s + +def evalpoly [VSpace v] (coefficients:n=>v) (x:Float) : v = + -- Evaluate a polynomial at x. Same as Numpy's polyval. + fold zero \i c. coefficients.i + x .* c + +def dex_test_mode (():Unit) : Bool = unsafeIO do checkEnv "DEX_TEST_MODE" + +'## Exception effect + +def catch (f:Unit -> {Except|eff} a) : {|eff} Maybe a = + %catchException f + +def throw (_:Unit) : {Except} a = + %throwException a + +def assert (b:Bool) : {Except} Unit = + if not b then throw () diff --git a/makefile b/makefile index 3fb4fb955..d1b74519d 100644 --- a/makefile +++ b/makefile @@ -60,88 +60,100 @@ tc: dexrt-llvm build: dexrt-llvm $(STACK) build $(STACK_FLAGS) +watch: dexrt-llvm + $(STACK) build $(STACK_FLAGS) --file-watch + install: dexrt-llvm $(STACK) install $(STACK_BIN_PATH) --flag dex:optimized $(STACK_FLAGS) build-prof: dexrt-llvm $(STACK) build $(PROF) -dexrt-llvm: src/lib/dexrt.bc - # For some reason stack fails to detect modifications to foreign library files -build-python: build +build-python: dexrt-llvm $(STACK) build $(STACK_FLAGS) --force-dirty $(eval STACK_INSTALL_DIR=$(shell stack path --local-install-root)) cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/ +build-ci: dexrt-llvm + $(STACK) build $(STACK_FLAGS) --force-dirty --ghc-options "-Werror -fforce-recomp" + +dexrt-llvm: src/lib/dexrt.bc + %.bc: %.cpp clang++ $(CXXFLAGS) -c -emit-llvm $^ -o $@ # --- running tests --- -# TODO: re-enable linear-tests ad-tests include-test chol -example-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ - shadow-tests monad-tests \ - ad-tests mandelbrot pi sierpinski \ +example-names = mandelbrot pi sierpinski rejection-sampler \ regression brownian_motion particle-swarm-optimizer \ - ode-integrator parser-tests serialize-tests \ - mcmc record-variant-tests simple-include-test ctc raytrace \ - isomorphisms typeclass-tests complex-tests trig-tests \ - ode-integrator linear_algebra fluidsim + ode-integrator mcmc ctc raytrace particle-filter \ + isomorphisms ode-integrator linear_algebra fluidsim \ + sgd chol -quine-test-targets = $(example-names:%=run-%) +test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ + shadow-tests monad-tests io-tests exception-tests \ + ad-tests parser-tests serialize-tests parser-combinator-tests \ + record-variant-tests typeclass-tests complex-tests trig-tests -update-targets = $(example-names:%=update-%) +lib-names = diagram plot png -doc-names = $(example-names:%=doc/%.html) +all-names = $(test-names:%=tests/%) $(example-names:%=examples/%) -tests: quine-tests repl-test export-tests +quine-test-targets = $(all-names:%=run-%) -quine-tests: $(quine-test-targets) +update-test-targets = $(test-names:%=update-tests-%) +update-example-targets = $(example-names:%=update-examples-%) + +doc-example-names = $(example-names:%=doc/examples/%.html) + +doc-lib-names = $(lib-names:%=doc/lib/%.html) -quine-tests-interp: runinterp-eval-tests runinterp-ad-tests-interp runinterp-interp-tests +tests: quine-tests repl-test + +quine-tests: $(quine-test-targets) run-%: export DEX_ALLOW_CONTRACTIONS=0 -run-%: examples/%.dx build +run-%: export DEX_TEST_MODE=t + +run-tests/%: tests/%.dx build + misc/check-quine $< $(dex) script --allow-errors +run-examples/%: examples/%.dx build misc/check-quine $< $(dex) script --allow-errors # Run these with profiling on while they're catching lots of crashes prop-tests: cbits/libdex.so $(STACK) test $(PROF) -update-all: $(update-targets) - update-%: export DEX_ALLOW_CONTRACTIONS=0 -update-%: examples/%.dx build +update-%: export DEX_TEST_MODE=t + +update-all: $(update-test-targets) $(update-example-targets) + +update-tests-%: tests/%.dx build + $(dex) script --allow-errors $< > $<.tmp + mv $<.tmp $< + +update-examples-%: examples/%.dx build $(dex) script --allow-errors $< > $<.tmp mv $<.tmp $< run-gpu-tests: export DEX_ALLOC_CONTRACTIONS=0 -run-gpu-tests: examples/gpu-tests.dx build - misc/check-quine $< $(dex) --backend LLVM-CUDA script --allow-errors +run-gpu-tests: tests/gpu-tests.dx build + misc/check-quine $< $(dex) --backend llvm-cuda script --allow-errors update-gpu-tests: export DEX_ALLOW_CONTRACTIONS=0 -update-gpu-tests: examples/gpu-tests.dx build - $(dex) --backend LLVM-CUDA script --allow-errors $< > $<.tmp +update-gpu-tests: tests/gpu-tests.dx build + $(dex) --backend llvm-cuda script --allow-errors $< > $<.tmp mv $<.tmp $< -export-tests: export-test-scalar export-test-array - -export-test-%: build - $(dex) export examples/export/$*.dx examples/export/$*.o - $(CXX) -std=c++11 examples/export/$*.o examples/export/$*.cpp -o examples/export/$* - examples/export/$* - -jax-tests: build - misc/check-quine examples/jax-tests.dx $(dex) --backend JAX script - uexpr-tests: misc/check-quine examples/uexpr-tests.dx $(dex) script repl-test: misc/check-no-diff \ - examples/repl-multiline-test-expected-output \ - <($(dex) repl < examples/repl-multiline-test.dx) + tests/repl-multiline-test-expected-output \ + <($(dex) repl < tests/repl-multiline-test.dx) # --- running and querying benchmarks --- @@ -160,16 +172,21 @@ bench-summary: # --- building docs --- -slow-docs = doc/mnist-nearest-neighbors.html +slow-docs = doc/examples/mnist-nearest-neighbors.html + +docs: doc-prelude $(doc-example-names) $(doc-lib-names) $(slow-docs) -docs: doc/style.css $(doc-names) $(slow-docs) - $(dex) --prelude /dev/null script prelude.dx --html > doc/prelude.html +doc-prelude: lib/prelude.dx + mkdir -p doc + $(dex) --prelude /dev/null script lib/prelude.dx --outfmt html > doc/prelude.html -doc/%.html: examples/%.dx - $(dex) script $^ --outfmt HTML > $@ +doc/examples/%.html: examples/%.dx + mkdir -p doc/examples + $(dex) script $^ --outfmt html > $@ -doc/%.css: static/%.css - cp $^ $@ +doc/lib/%.html: lib/%.dx + mkdir -p doc/lib + $(dex) script $^ --outfmt html > $@ clean: $(STACK) clean diff --git a/misc/check-quine b/misc/check-quine index c592d32fd..0bde916df 100755 --- a/misc/check-quine +++ b/misc/check-quine @@ -26,6 +26,7 @@ if ${@:2} $1 > $tmpout 2> $errout ; then misc/check-no-diff $1 $tmpout status=$? else + status=$? cat $tmpout fi diff --git a/misc/dex.el b/misc/dex.el index 85f0acad4..c7bf86dad 100644 --- a/misc/dex.el +++ b/misc/dex.el @@ -10,7 +10,7 @@ ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) ("\\w+:" . font-lock-comment-face) ("^:\\w*" . font-lock-preprocessor-face) - ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b" . + ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b\\|\\bdo\\b\\|\\bview\\b" . font-lock-keyword-face) ("--o" . font-lock-variable-name-face) ("[-.,!$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) diff --git a/misc/dex.tex b/misc/dex.tex deleted file mode 100644 index 839328fbf..000000000 --- a/misc/dex.tex +++ /dev/null @@ -1,56 +0,0 @@ -\documentclass[12pt]{article} -\usepackage{amsmath} -\usepackage{geometry} -\geometry{legalpaper, landscape, margin=0in} - -\newcommand{\annot}[1]{\texttt{::} #1} -\newcommand{\ttt}[1]{~\texttt{#1}~} - -\begin{document} - -\vspace{-0.5cm} - -\begin{huge} -\begin{align*} -\text{Terms } \quad t ::&=~ -l \mid x - && \text{Literal / variable} \\ -&\mid \ttt{let} x \annot{\tau} = t \ttt{in} t - && \text{Let expression} \\ -&\mid \ttt{lam} x \annot{\tau} \ttt {.} t - \mid t ~ t - && \text{Lambda abstraction / application} \\ -&\mid \ttt{tlam} a \, . ~ t - \mid t ~ \tau - && \text{Type-lambda abstraction / application} \\ -&\mid \ttt{for} i \annot{\iota} \ttt {.} t - \mid t.t - && \text{Index comprehension / indexing} \\ -&\mid \ttt{pack} t, \iota \ttt{::} \exists n. ~ \tau - && \text{Existential packing} \\ -&\mid \ttt{let} x, n = \ttt{unpack} t \ttt{in} t - && \text{Existential unpacking} -\\ \\ -\text{Types} \quad \tau, \iota ::&= - \ttt{Int} | \ttt{ Real} \mid \ttt{Bool} \mid a - && \text{Base types and type variable} \\ -&\mid \tau \ttt{->} \tau && \text{Arrow type} \\ -&\mid \forall a. ~ \tau - && \text{Universal quantification} \\ -&\mid \iota \ttt{=>} \tau && \text{Table type} \\ -&\mid \ttt{\{<}l\ttt{\}} && \text{Index set literal} \\ -&\mid \exists n. ~ \tau - && \text{Existential quantification} -\end{align*} -% -\begin{align*} -\text{Term variables} \quad x, i \qquad -\text{Type variables} \quad a, n \qquad -\text{Literals} \quad l -\end{align*} -\end{huge} - -\end{document} - -%% to convert to slides-friendly png: -%% convert -alpha remove -density 300 -quality 85 dex.pdf -transparent white dex.png diff --git a/misc/py/dex.py b/misc/py/dex.py deleted file mode 100644 index c90cb21f9..000000000 --- a/misc/py/dex.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import print_function -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import ctypes -import json - -libname = "./dex2jax.so" - -lib = ctypes.cdll.LoadLibrary(libname) -lib.hs_init(0, 0) # TODO should call lib.hs_exit() when done - -def setup_f(fname): - f = getattr(lib, fname) - f.argtypes = [ctypes.c_char_p] - f.restype = ctypes.c_char_p - return lambda x: json.loads(f(json.dumps(x))) - -loadSource, = map(setup_f, ["loadSource"]) - -class DexModule(object): - def __init__(self, functions): - for fname, definition in functions: - self.__dict__[fname] = definition - -def load(fname): - with open(fname) as f: - s = f.read() - top_level_functions = loadSource(s) - print(top_level_functions) - return DexModule(top_level_functions) diff --git a/misc/py/dex_binary_object.py b/misc/py/dex_binary_object.py deleted file mode 100644 index 9fee60af8..000000000 --- a/misc/py/dex_binary_object.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import itertools as it -from collections import namedtuple -import numpy as np - -TabType = namedtuple('TabType', ['index_set', 'element_type']) - -preheader_length = 81 -preheader_start = "-- dex-object-file-v0.0.1 num-header-bytes " - -def dump(obj, f): - ty = get_dex_ty(obj) - buffers = flatten_to_buffers(obj) - ty_str = "type: {}\n".format(pprint_ty(ty)) - sizes_str = "bufferSizes: [{}]\n".format(", ".join([str(get_buffer_size(x)) - for x in buffers])) - header_size = preheader_length + len(ty_str) + len(sizes_str) - pre_header_str = make_preheader(header_size) - header = pre_header_str + ty_str + sizes_str - assert header_size == len(header) - f.write(header) - f.flush() - for b in buffers: - buf_bytes = b.tobytes() - assert len(buf_bytes) == get_buffer_size(b), \ - "{} {} != {}".format(b, len(buf_bytes), get_buffer_size(b)) - f.buffer.write(buf_bytes) - f.flush() - -def get_dex_ty(obj): - if isinstance(obj, tuple): - return tuple(get_dex_ty(x) for x in obj) - elif isinstance(obj, np.ndarray): - base_ty = dtype_to_dex_ty(obj.dtype) - return make_tab_type(base_ty, obj.shape) - elif isinstance(obj, float): - return float - elif isinstance(obj, bool): - return bool - elif isinstance(obj, int): - return int - else: - raise Exception("No corresponding Dex type for {}".format(type(obj))) - -def flatten_to_buffers(obj): - if isinstance(obj, tuple): - return tuple(it.chain(*(flatten_to_buffers(x) for x in obj))) - elif isinstance(obj, np.ndarray): - flat_array = obj.ravel() - if obj.dtype == np.bool: - return [np.asarray(flat_array, dtype=np.int64)] - else: - return [flat_array] - elif isinstance(obj, float): - return [np.array(obj, dtype=np.float64)] - elif isinstance(obj, bool): - return [np.array(obj, dtype=np.int64)] - elif isinstance(obj, int): - return [np.array(obj, dtype=np.int64)] - else: - raise Exception("No corresponding Dex type for {}".format(type(obj))) - -def dtype_to_dex_ty(dtype): - if dtype == np.float64: - return float - elif dtype == np.int64: - return int - elif dtype == np.bool: - return bool - else: - raise Exception("Unrecognized dtype: " + str(dtype)) - -def make_tab_type(base_ty, shape): - shape = tuple(shape) - if shape == (): - return base_ty - else: - (n, *rest) = shape - return TabType(n, make_tab_type(base_ty, rest)) - -def get_buffer_size(array): - return array.size * 8 - -def pprint_ty(ty): - if isinstance(ty, TabType): - return "{}=>{}".format(str(ty.index_set), pprint_ty(ty.element_type)) - elif isinstance(ty, tuple): - return "({})".format(", ".join(map(pprint_ty, ty))) - if ty is int: - return "Int" - elif ty is float: - return "Real" - elif ty is bool: - return "Bool" - else: - raise Exception("Can't print type: {}".format(ty)) - -def make_preheader(n): - preheader_prefix = preheader_start + str(n) + " " - padding = '-' * (preheader_length - len(preheader_prefix) - 1) + "\n" - return preheader_prefix + padding diff --git a/misc/py/foo.dx b/misc/py/foo.dx deleted file mode 100644 index 3dc99813e..000000000 --- a/misc/py/foo.dx +++ /dev/null @@ -1,6 +0,0 @@ - - -addFloats :: Float -> Float -> Float -addFloats x y = x + y - - diff --git a/misc/py/generate-dex-data.py b/misc/py/generate-dex-data.py deleted file mode 100644 index 7981aefcd..000000000 --- a/misc/py/generate-dex-data.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -from collections import namedtuple -import numpy as np -import dex_binary_object as dbo - -data = (1.2, - 12, - (), - True, - False, - (-2, np.array([1.0, 2.0, 3.0])), - np.array([[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]]) , - np.array([[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]]).T, - 1.3, - np.array(0.123), - np.array([[[1]]]), - np.array([6,5,4,3]), - np.array([True, False, True])) - -with open("test-scratch/pydata.dxbo", "w") as f: - dbo.dump(data, f) diff --git a/misc/py/jax_call.py b/misc/py/jax_call.py deleted file mode 100644 index d92690076..000000000 --- a/misc/py/jax_call.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import collections -import json -import pprint -import sys -import pprint as pp -import traceback -import numpy as np -import jax.numpy as jnp -from jax import jit, make_jaxpr, xla_computation -from jax import random -from jax import lax - -scary_map = map - -def map(f, *args): - return list(scary_map(f, *args)) - -class JaxFunction(object): - def __init__(self, binders, decls, results): - for b in binders: assert isinstance(b, Var) - for b, op in decls: - assert isinstance(b, Var) - assert isinstance(op, Operation) - for r in results: assert isinstance(r, Atom) - self.binders = binders - self.decls = decls - self.results = results - - def ser(self): - assert False - - @staticmethod - def des(obj): - binders_ser, (decls_ser, results_ser) = obj - binders = map(Var.des, binders_ser) - results = map(Atom.des, results_ser) - decls = [(Var.des(b), Operation.des(op)) for (b, op) in decls_ser] - return JaxFunction(binders, decls, results) - -class Name(object): - def __init__(self, namespace, root, i): - assert isinstance(i, int) - assert isinstance(namespace, str) - assert isinstance(root, str) - self._name = (namespace, root, i) - - @staticmethod - def des(obj): - namespace, root, i = obj - return Name(namespace, root, i) - - def ser(self): - return {"tag":"Name", "contents": list(self._name)} - - def __repr__(self): return str(self) - def __str__(self): - (_, root, i) = self._name - if i == 0: - return root - else: - return root + str(i) - - def __eq__(self, other): - assert isinstance(other, Name) - return self._name == other._name - - def __hash__(self): - return hash(self._name) - -class IdxVar(object): - def __init__(self, name, size): - assert isinstance(name, Name) - assert isinstance(size, int) - self.name = name - self.size = size - - def __repr__(self): return str(self) - def __str__(self): - return str(self.name) + ":" + str(self.size) - - def __eq__(self, other): - assert isinstance(other, IdxVar) - return self.name == other.name - - def __hash__(self): - return hash(self.name) - - @staticmethod - def des(obj): - name, idxSize = obj - assert name["tag"] == "Name" - return IdxVar(Name.des(name["contents"]), idxSize) - -class Var(object): - def __init__(self, name, ty): - assert isinstance(ty, Ty) - assert isinstance(name, Name) - self.name = name - self.ty = ty - - def __repr__(self): return str(self) - def __str__(self): - return str(self.name) + ":" + str(self.ty) - - def __eq__(self, other): - assert isinstance(other, Var) - return self.name == other.name - - def __hash__(self): - return hash(self.name) - - def ser(self): - return [self.name.ser(), self.ty.ser()] - - @staticmethod - def des(obj): - name, (shape, basetype) = obj - assert name["tag"] == "Name" - return Var(Name.des(name["contents"]), Ty(shape, basetype)) - -class Atom(object): - def __init__(self, case, data): - self.case = case - if case == "Var": - assert isinstance(data, Var) - self.var = data - elif case == "Lit": - assert isinstance(data, arrayish_types), type(data) - self.val = data - else: - assert False - - def __repr__(self): return str(self) - def __str__(self): - if self.case == "Var": - return str(self.var) - elif self.case == "Lit": - return str(self.val) - else: - assert False - - @property - def ty(self): - if self.case == "Var": - return self.var.ty - elif self.case == "Lit": - x = self.val - return array_ty(x) - else: - assert False - - @staticmethod - def des(obj): - if obj["tag"] == "JVar": - val = obj["contents"] - return Atom("Var", Var.des(val)) - elif obj["tag"] == "JLit": - shape, vec = obj["contents"] - val = np.array(vec["contents"], dtype=vec_dtype(vec)).reshape(shape) - return Atom("Lit", val) - -class IndexedAtom(object): - def __init__(self, atom, idxs): - assert isinstance(atom, Atom) - for i in idxs: assert isinstance(i, IdxVar) - self.atom = atom - self.idxs = idxs - - @property - def ty(self): - atom_ty = self.atom.ty - return Ty(atom_ty.shape[:len(self.idxs)], atom_ty.basetype) - - @staticmethod - def des(obj): - atom, idxs = obj - return IndexedAtom(Atom.des(atom), map(IdxVar.des, idxs)) - - def __repr__(self): return str(self) - def __str__(self): - return str(self.atom) + "".join("." + str(i) for i in self.idxs) - -class Ty(object): - def __init__(self, shape, basetype): - for n in shape: assert isinstance(n, int) - assert basetype in ["IntType", "BoolType", "RealType"] - self.basetype = basetype - self.shape = tuple(shape) - - def ser(self): - return [self.shape, self.basetype] - - def __eq__(self, other): - assert isinstance(other, Ty) - return self.basetype == other.basetype and self.shape == other.shape - - @staticmethod - def des(obj): - assert False - - def __repr__(self): return str(self) - def __str__(self): - return self.basetype + str(self.shape) - -MapIdx = "MapIdx" -SumIdx = "SumIdx" -class Operation(object): - def __init__(self, binders, op_name, size_args, args): - for (i, flavor) in binders: - assert isinstance(i, IdxVar) - assert flavor in (MapIdx, SumIdx) - - assert isinstance(op_name, str) - for size in size_args: assert isinstance(size, int) - for arg in args: assert isinstance(arg, IndexedAtom) - self.binders = binders - self.op_name = op_name - self.size_args = size_args - self.args = args - - @property - def all_idxs(self): - return [i for i, _ in self.binders] - - def ser(self): - assert False - - @staticmethod - def des(obj): - binders_ser, op_and_args_ser = obj - binders = [(IdxVar.des(i), fl) for i, fl in binders_ser] - op_name, size_args, args = des_op_and_args(op_and_args_ser) - return Operation(binders, op_name, size_args, args) - - def __repr__(self): return str(self) - def __str__(self): - return "for {} . {} {}".format( - self.binders, self.op_name, tuple(self.args)) - -def array_ty(x): - return Ty(x.shape, dtype_basetype(x.dtype)) - -def ser_array(arr): - assert isinstance(arr, arrayish_types) - return ser_flat_vec(arr.ravel()) - -def ser_flat_vec(vec): - if vec.dtype in [np.int32, np.int64]: - return {"tag":"IntVec", "contents": map(int, vec)} - if vec.dtype in [np.float32, np.float64]: - return {"tag":"DoubleVec", "contents": map(float, vec)} - else: - assert False - -def des_op_and_args(obj): - tag = obj["tag"] - if tag == "JScalarBinOp": - binop_name, x_ser, y_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - y = IndexedAtom.des(y_ser) - return binop_name["tag"], [], [x, y] - if tag == "JScalarUnOp": - unop_name, x_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - return unop_name, [], [x] - elif tag == "JIota": - size = obj["contents"] - assert isinstance(size, int) - return "Iota", [size], [] - elif tag == "JId": - x_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - return "Id", [], [x] - elif tag == "JGet": - x_ser, y_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - y = IndexedAtom.des(y_ser) - return "Get", [], [x, y] - elif tag == "JThreeFry2x32": - x_ser, y_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - y = IndexedAtom.des(y_ser) - return "ThreeFry2x32", [], [x, y] - else: - raise Exception("Not implemented: " + str(tag)) - -global_env = {} - -def eval_op(op): - if op.op_name in ("FMul", "IMul"): - ans = eval_einsum(op) - return Atom("Lit", ans) - else: - broadcast_ans = eval_for(op) - sum_axes = tuple(i for (i, (_, fl)) in enumerate(op.binders) if fl == SumIdx) - if sum_axes == (): - return Atom("Lit", broadcast_ans) - else: - summed_ans = np.sum(broadcast_ans, axis=sum_axes) - return Atom("Lit", summed_ans) - -def eval_einsum(op): - assert op.op_name in ("FMul", "IMul") - x, y = op.args - x_axes = [str(i.name) for i in x.idxs] - y_axes = [str(i.name) for i in y.idxs] - out_axes = [str(i.name) for i, f in op.binders if f != SumIdx] - return jnp.einsum(x.atom.val, x_axes, y.atom.val, y_axes, out_axes) - -def eval_for(op): - if op.op_name in ("IAdd", "IMul", "FAdd", "FMul", "FDiv"): - x, y = op.args - x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val) - y_bc = broadcast_dims(op.all_idxs, y.idxs, y.atom.val) - if op.op_name in ("IAdd", "FAdd"): - return jnp.add(x_bc, y_bc) - elif op.op_name in ("IMul", "FMul"): - return jnp.multiply(x_bc, y_bc) - if op.op_name in ("FDiv",): - return jnp.divide(x_bc, y_bc) - else: - raise Exception("Not implemented: " + str(op.op_name)) - elif op.op_name == "Iota": - n, = op.size_args - val = jnp.arange(n) - val_bc = broadcast_dims(op.all_idxs, [], val) - return val_bc - elif op.op_name == "Id": - x, = op.args - x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val) - return x_bc - elif op.op_name == "Get": - x, idx = op.args - out_shape = [i.size for i in op.all_idxs] - x_idxs_used = get_stack_idxs_used(op.all_idxs, x.idxs) - leading_idx_arrays = [] - for i, idx_used in enumerate(x_idxs_used): - if idx_used: - leading_idx_arrays.append(nth_iota(out_shape, i)) - else: - pass - payload_idx_array = broadcast_dims(op.all_idxs, idx.idxs, idx.atom.val) - out = x.atom.val[tuple(leading_idx_arrays) + (payload_idx_array,)] - return out - elif op.op_name == "IntToReal": - x, = op.args - real_val = jnp.array(x.atom.val, dtype="float32") - x_bc = broadcast_dims(op.all_idxs, x.idxs, real_val) - return x_bc - elif op.op_name in ("FNeg", "INeg"): - x, = op.args - x_bc = broadcast_dims(op.all_idxs, x.idxs, jnp.negative(x.atom.val)) - return x_bc - elif op.op_name == "ThreeFry2x32": - convert_64_to_32s = lambda x: np.array([x]).view(np.uint32) - convert_32s_to_64 = lambda x: np.int64(np.array(x).view(np.int64).item()) - x, y = op.args - key, count = convert_64_to_32s(x.atom.val), convert_64_to_32s(y.atom.val) - result = convert_32s_to_64(random.threefry_2x32(key, count)) - x_bc = broadcast_dims(op.all_idxs, x.idxs, result) - return x_bc - else: - raise Exception("Unrecognized op: {}".format(op.op_name)) - -def broadcast_dims(for_idxs, idxs, x): - shape = [i.size for i in for_idxs] - idxs_used = get_stack_idxs_used(for_idxs, idxs) - bcast_dims = [i for i, b in enumerate(idxs_used) if b] - return lax.broadcast_in_dim(x, shape, bcast_dims) - -def broadcast_with(x, final_shape, idxs_used): - rem_shape = list(x.shape[sum(idxs_used):]) - reshape_shape = [size if use else 1 for (size, use) in zip(final_shape, idxs_used)] - x_singletons = jnp.reshape(x, reshape_shape + rem_shape) - return jnp.broadcast_to(x_singletons, final_shape + rem_shape) - -def nth_iota(shape, i): - size = shape[i] - iota = jnp.arange(size) - idxs_used = [Discard for _ in shape] - idxs_used[i] = Use - return broadcast_with(iota, shape, idxs_used) - -Use = True -Discard = False -def get_stack_idxs_used(for_idxs, idxs): - stack_vars = [] - cur_idxs = list(idxs) - for i in for_idxs: - if cur_idxs and i == cur_idxs[0]: - stack_vars.append(Use) - cur_idxs = cur_idxs[1:] - else: - stack_vars.append(Discard) - return stack_vars - -arrayish_types = (jnp.ndarray, np.ndarray, np.int64, np.float64, np.float32) - -def subst_op(env, op): - args = [IndexedAtom(subst_atom(env, x.atom), x.idxs) for x in op.args] - return Operation(op.binders, op.op_name, op.size_args, args) - -def subst_atom(env, x): - assert isinstance(x, Atom) - if x.case == "Var": - return env[x.var] - elif x.case == "Lit": - return x - else: - assert False - -def dtype_basetype(x): - if x in [np.int32, np.int64]: - return "IntType" - elif x in [np.float32, np.float64]: - return "RealType" - else: - assert False, x - -def vec_dtype(vec): - if vec["tag"] == "IntVec": - return np.int64 - elif vec["tag"] == "DoubleVec": - return np.float64 - else: - assert False - -def atom_as_var(x): - assert isinstance(x, Atom) - i = len(global_env) - name = Name("ArrayName", "arr", i) - v = Var(name, x.ty) - assert v not in global_env - global_env[v] = x - return v - -def eval_function_application(top_arg): - def run(): - f = JaxFunction.des(top_arg[0]) - args = [Atom("Var", Var.des(x)) for x in top_arg[1]] - env = global_env.copy() - args_subst = [subst_atom(env, arg) for arg in args] - for v, arg in zip(f.binders, args_subst): - env[v] = arg - for (v, op) in f.decls: - ans = eval_op(subst_op(env, op)) - if not (v.ty == ans.ty): - print(op) - raise Exception("Unexpected type. Expected {}, got {}".format(v.ty, ans.ty)) - env[v] = ans - return [subst_atom(env, r).val for r in f.results] - outs = run() - irdump = str(make_jaxpr(run)()) - return [atom_as_var(Atom("Lit", out)).ser() for out in outs], irdump - -def check_type(ty, val): - assert isinstance(ty, Ty) - -def retrieve_arrays(arrs): - vs = map(Var.des, arrs) - return [ser_array(global_env[v].val) for v in vs] - -def just_print_it(obj): - print(obj) - return () - -def run_server(functions): - readChan, writeChan = sys.argv[1:] - with open(writeChan, "w") as w: - for line in open(readChan): - (f_idx, arg) = json.loads(line) - try: - f = functions[f_idx] - ans = {"Right" : f(arg)} - except Exception as e: - traceback.print_exc() - ans = {"Left": traceback.format_exc()} - w.write(json.dumps(ans) + "\n") - w.flush() - -if __name__ == "__main__": - run_server([eval_function_application, - retrieve_arrays, - just_print_it]) diff --git a/misc/py/mnist_to_dxbo.py b/misc/py/mnist_to_dxbo.py deleted file mode 100644 index cf3b3e339..000000000 --- a/misc/py/mnist_to_dxbo.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import sys -import numpy as np -import dex_binary_object as dbo -sys.path.append("../jax") -from examples import datasets - -def oneHotToInt(xs): - xsInt = np.sum(xs * np.arange(10)[None,:], axis=1).astype(np.int64) - print(xsInt.shape) - assert np.max(xsInt) == 9 - return xsInt - -data = tuple(x.astype(np.float64) for x in datasets.mnist()) -train_images, train_labels, test_images, test_labels = data - -train_images_unflat = train_images.reshape((60000, 28, 28)) -test_images_unflat = test_images.reshape( (10000, 28, 28)) - -train_labels_int = oneHotToInt(train_labels) -test_labels_int = oneHotToInt(test_labels) - -data_out = (train_images_unflat, train_labels_int, - test_images_unflat, test_labels_int) - -with open("scratch/mnist.dxbo", "w") as f: - dbo.dump(data_out, f) diff --git a/python/dex/__init__.py b/python/dex/__init__.py index 6850c436d..46ba9e088 100644 --- a/python/dex/__init__.py +++ b/python/dex/__init__.py @@ -5,124 +5,43 @@ # https://developers.google.com/open-source/licenses/bsd import itertools as it +import sys import ctypes -import pathlib -import atexit -from enum import Enum -from typing import List - -__all__ = ['execute'] - -here = pathlib.Path(__file__).parent.absolute() - -lib = ctypes.cdll.LoadLibrary(here / 'libDex.so') - -def tagged_union(name: str, members: List[type]): - named_members = [(f"t{i}", member) for i, member in enumerate(members)] - payload = type(name + "Payload", (ctypes.Union,), {"_fields_": named_members}) - union = type(name, (ctypes.Structure,), { - "_fields_": [("tag", ctypes.c_uint64), ("payload", payload)], - "value": property(lambda self: getattr(self.payload, f"t{self.tag}")), - }) - return union - -CLit = tagged_union("Lit", [ctypes.c_int64, ctypes.c_int32, ctypes.c_int8, ctypes.c_double, ctypes.c_float]) -class CRectArray(ctypes.Structure): - _fields_ = [("data", ctypes.c_void_p), - ("shape_ptr", ctypes.POINTER(ctypes.c_int64)), - ("strides_ptr", ctypes.POINTER(ctypes.c_int64))] -CAtom = tagged_union("CAtom", [CLit, CRectArray]) -assert ctypes.sizeof(CAtom) == 4 * 8 - -class HsAtom(ctypes.Structure): pass -class HsContext(ctypes.Structure): pass - -_init = lib.dexInit -_init.restype = None -_init.argtypes = [] - -_fini = lib.dexFini -_fini.restype = None -_fini.argtypes = [] - -_create_context = lib.dexCreateContext -_create_context.restype = ctypes.POINTER(HsContext) -_create_context.argtypes = [] - -_destroy_context = lib.dexDestroyContext -_destroy_context.restype = None -_destroy_context.argtypes = [ctypes.POINTER(HsContext)] - -_print = lib.dexPrint -_print.restype = ctypes.c_char_p -_print.argtypes = [ctypes.POINTER(HsAtom)] - -_insert = lib.dexInsert -_insert.restype = ctypes.POINTER(HsContext) -_insert.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p, ctypes.POINTER(HsAtom)] - -_eval = lib.dexEval -_eval.restype = ctypes.POINTER(HsContext) -_eval.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p] - -_evalExpr = lib.dexEvalExpr -_evalExpr.restype = ctypes.POINTER(HsAtom) -_evalExpr.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p] - -_lookup = lib.dexLookup -_lookup.restype = ctypes.POINTER(HsAtom) -_lookup.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p] - -_toCAtom = lib.dexToCAtom -_toCAtom.restype = ctypes.c_int -_toCAtom.argtypes = [ctypes.POINTER(HsAtom), ctypes.POINTER(CAtom)] - -_getError = lib.dexGetError -_getError.restype = ctypes.c_char_p -_getError.argtypes = [] - -_init() -_nofree = False -@atexit.register -def _teardown(): - global _nofree - _fini() - _nofree = True # Don't destruct any Haskell objects after the RTS has been shutdown - - -def _as_cstr(x: str): - return ctypes.c_char_p(x.encode('ascii')) - -def _from_cstr(cx): - return cx.value.decode('ascii') +from typing import Any, List, Union +from . import api +from .native_function import NativeFunction +__all__ = [ + 'Module', + 'eval', +] class Module: __slots__ = ('_as_parameter_',) def __init__(self, source): - self._as_parameter_ = _eval(prelude, _as_cstr(source)) + self._as_parameter_ = api.eval(prelude, api.as_cstr(source)) if not self._as_parameter_: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() def __del__(self): - if _nofree: + if api.nofree: return - _destroy_context(self) + api.destroyContext(self) def __getattr__(self, name): - result = _lookup(self, _as_cstr(name)) + result = api.lookup(self, api.as_cstr(name)) if not result: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() return Atom(result, self) class Prelude(Module): __slots__ = () def __init__(self): - self._as_parameter_ = _create_context() + self._as_parameter_ = api.createContext() if not self._as_parameter_: - raise RuntimeError("Failed to initialize prelude!") + api.raise_from_dex() prelude = Prelude() @@ -130,14 +49,14 @@ def __init__(self): def eval(expr: str, module=prelude, _env=None): if _env is None: _env = module - result = _evalExpr(_env, _as_cstr(expr)) + result = api.evalExpr(_env, api.as_cstr(expr)) if not result: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() return Atom(result, module) class Atom: - __slots__ = ('_as_parameter_', 'module') + __slots__ = ('__weakref__', '_as_parameter_', 'module') def __init__(self, ptr, module): self._as_parameter_ = ptr @@ -148,7 +67,8 @@ def __del__(self): pass def __repr__(self): - return _print(self).decode('ascii') + # TODO: Free! + return api.from_cstr(api.print(self)) def __int__(self): return int(self._as_scalar()) @@ -157,12 +77,12 @@ def __float__(self): return float(self._as_scalar()) def _as_scalar(self): - result = CAtom() - success = _toCAtom(self, ctypes.pointer(result)) + result = api.CAtom() + success = api.toCAtom(self, ctypes.pointer(result)) if not success: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() value = result.value - if not isinstance(value, CLit): + if not isinstance(value, api.CLit): raise TypeError("Atom is not a scalar value") return value.value @@ -173,6 +93,12 @@ def __call__(self, *args): # NB: Atoms can contain arbitrary references if atom.module is not prelude and atom.module is not self.module: raise RuntimeError("Mixing atoms coming from different Dex modules is not supported yet!") - old_env, env = env, _insert(env, _as_cstr(f"python_arg{i}"), atom) - _destroy_context(old_env) + old_env, env = env, api.insert(env, api.as_cstr(f"python_arg{i}"), atom) + api.destroyContext(old_env) return eval(" ".join(f"python_arg{i}" for i in range(len(args) + 1)), module=self.module, _env=env) + + def compile(self): + func_ptr = api.compile(api.jit, self.module, self) + if not func_ptr: + api.raise_from_dex() + return NativeFunction(api.jit, func_ptr) diff --git a/python/dex/api.py b/python/dex/api.py new file mode 100644 index 000000000..fcd881697 --- /dev/null +++ b/python/dex/api.py @@ -0,0 +1,98 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import ctypes +import pathlib +import atexit +from typing import List + +here = pathlib.Path(__file__).parent.absolute() + +lib = ctypes.cdll.LoadLibrary(here / 'libDex.so') + +def tagged_union(name: str, members: List[type]): + named_members = [(f"t{i}", member) for i, member in enumerate(members)] + payload = type(name + "Payload", (ctypes.Union,), {"_fields_": named_members}) + union = type(name, (ctypes.Structure,), { + "_fields_": [("tag", ctypes.c_uint64), ("payload", payload)], + "value": property(lambda self: getattr(self.payload, f"t{self.tag}")), + }) + return union + +CLit = tagged_union("Lit", [ctypes.c_int64, ctypes.c_int32, ctypes.c_int8, ctypes.c_double, ctypes.c_float]) +class CRectArray(ctypes.Structure): + _fields_ = [("data", ctypes.c_void_p), + ("shape_ptr", ctypes.POINTER(ctypes.c_int64)), + ("strides_ptr", ctypes.POINTER(ctypes.c_int64))] +CAtom = tagged_union("CAtom", [CLit, CRectArray]) +assert ctypes.sizeof(CAtom) == 4 * 8 + +class HsAtom(ctypes.Structure): pass +class HsContext(ctypes.Structure): pass +class HsJIT(ctypes.Structure): pass +class NativeFunctionObj(ctypes.Structure): pass +class NativeFunctionSignature(ctypes.Structure): + _fields_ = [("arg", ctypes.c_char_p), + ("res", ctypes.c_char_p), + ("ccall", ctypes.c_char_p)] + + +HsAtomPtr = ctypes.POINTER(HsAtom) +HsContextPtr = ctypes.POINTER(HsContext) +HsJITPtr = ctypes.POINTER(HsJIT) +CAtomPtr = ctypes.POINTER(CAtom) +NativeFunctionSignaturePtr = ctypes.POINTER(NativeFunctionSignature) +NativeFunction = ctypes.POINTER(NativeFunctionObj) + +def dex_func(name, *signature): + argtypes, restype = signature[:-1], signature[-1] + f = getattr(lib, name) + f.restype = restype + f.argtypes = argtypes + return f + +init = dex_func('dexInit', None) +fini = dex_func('dexFini', None) +getError = dex_func('dexGetError', ctypes.c_char_p) + +createContext = dex_func('dexCreateContext', HsContextPtr) +destroyContext = dex_func('dexDestroyContext', HsContextPtr, None) + +eval = dex_func('dexEval', HsContextPtr, ctypes.c_char_p, HsContextPtr) +insert = dex_func('dexInsert', HsContextPtr, ctypes.c_char_p, HsAtomPtr, HsContextPtr) +evalExpr = dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr) +lookup = dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr) + +print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p) +toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int) + +createJIT = dex_func('dexCreateJIT', HsJITPtr) +destroyJIT = dex_func('dexDestroyJIT', HsJITPtr, None) +compile = dex_func('dexCompile', HsJITPtr, HsContextPtr, HsAtomPtr, NativeFunction) +unload = dex_func('dexUnload', HsJITPtr, NativeFunction, None) + +getFunctionSignature = dex_func('dexGetFunctionSignature', HsJITPtr, NativeFunction, NativeFunctionSignaturePtr) +freeFunctionSignature = dex_func('dexFreeFunctionSignature', NativeFunctionSignaturePtr) + +init() +jit = createJIT() +nofree = False +@atexit.register +def _teardown(): + global nofree + destroyJIT(jit) + fini() + nofree = True # Don't destruct any Haskell objects after the RTS has been shutdown + + +def as_cstr(x: str): + return ctypes.c_char_p(x.encode('ascii')) + +def from_cstr(cx): + return cx.decode('ascii') + +def raise_from_dex(): + raise RuntimeError(from_cstr(getError())) diff --git a/misc/py/test-dex2jax.py b/python/dex/interop/__init__.py similarity index 52% rename from misc/py/test-dex2jax.py rename to python/dex/interop/__init__.py index 553642e93..6b607710e 100644 --- a/misc/py/test-dex2jax.py +++ b/python/dex/interop/__init__.py @@ -1,14 +1,5 @@ -from __future__ import print_function -# Copyright 2019 Google LLC +# Copyright 2020 Google LLC # # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd - -import dex - -foo = dex.load("foo.dx") - -# print foo.addReals(1.0, 2.0) - -print(foo.f) diff --git a/python/dex/interop/jax.py b/python/dex/interop/jax.py new file mode 100644 index 000000000..df40bb50e --- /dev/null +++ b/python/dex/interop/jax.py @@ -0,0 +1,146 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from weakref import WeakKeyDictionary +from functools import partial +from itertools import count +import ctypes +import numpy as np + +import jax +from jax.lib import xla_client as xc +from jax.interpreters import xla + +from .. import Atom +from ..native_function import IdxRepTy, ScalarType, RectContArrayType + +def primitive(f): + if not isinstance(f, Atom): + raise TypeError("DexPrimitive expects a function atom as an argument") + return partial(dex_call_p.bind, func_atom=f) + +compiler_cache = WeakKeyDictionary() +def get_compiled(func_atom): + compiled = compiler_cache.get(func_atom, None) + if compiled is None: + compiled = compiler_cache[func_atom] = func_atom.compile() + return compiled + + +dex_call_p = jax.core.Primitive('dex_call') + +@dex_call_p.def_impl +def dex_call_impl(*args, func_atom): + return get_compiled(func_atom)(*args) + +# === abstract evaluation / shape inference === + +def dex_call_abstract_eval_with_shape(*args, func_atom): + # TODO: Make it possible to get the signature without compiling the function + native_func = get_compiled(func_atom) + arg_sig = native_func.explicit_argument_signature + res_sig = native_func.result_signature + if len(args) != len(arg_sig): + raise RuntimeError(f"Dex function expects {len(arg_sig)} arguments, but was given {len(args)}") + if not all(isinstance(arg, jax.core.ShapedArray) for arg in args): + raise RuntimeError("Cannot perform evaluation of Dex functions without known shapes") + # Check arguments and infer shape parameters + shape_vars = {} + for i, (arg, b) in enumerate(zip(args, arg_sig)): + expected_dtype = np.dtype(b.type.ctype) + if arg.dtype != expected_dtype: + raise RuntimeError(f"dtype mismatch in arg {i}: expected {expected_dtype}, got {arg.dtype}") + if isinstance(b.type, ScalarType): + expected_shape = () + elif isinstance(b.type, RectContArrayType): + expected_shape = b.type.shape + else: + raise AssertionError("Unhandled case!") + if len(arg.shape) != len(expected_shape): + raise RuntimeError(f"rank mismatch in arg {i}: expected {len(expected_shape)}, got {len(arg.shape)}") + inferred_shape = tuple( + size if isinstance(size, int) else shape_vars.setdefault(size, real_size) + for size, real_size in zip(expected_shape, arg.shape)) + if arg.shape != inferred_shape: + raise RuntimeError(f"shape mismatch in arg {i}: expected {inferred_shape}, got {arg.shape}") + # Infer result types + result_avals = [] + for b in res_sig: + dtype = np.dtype(b.type.ctype) + if isinstance(b.type, ScalarType): + shape = () + elif isinstance(b.type, RectContArrayType): + shape = tuple(shape_vars.get(size, size) for size in b.type.shape) + result_avals.append(jax.core.ShapedArray(shape, dtype)) + assert len(result_avals) == 1 # TODO: Make dex_call a multiple_results primitive + return result_avals[0], shape_vars + +@dex_call_p.def_abstract_eval +def dex_call_abstract_eval(*args, **kwargs): + return dex_call_abstract_eval_with_shape(*args, **kwargs)[0] + +# === xla translation === + +PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) +PyCapsule_New = ctypes.pythonapi.PyCapsule_New +PyCapsule_New.restype = ctypes.py_object +PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor) + +def make_custom_call_target(func_ptr): + return PyCapsule_New(func_ptr, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0)) + +# TODO: Better lifetime management. func_atoms will be quite often created on the fly +# at trace time when different transforms are applied, and I'm pretty sure that +# the XLA executables outlive jaxprs formed by tracing. +custom_call_id = count() +custom_call_cache = {} +def dex_call_cpu_translation(b, *args, func_atom): + xla_shapes = list(map(b.get_shape, args)) + result_aval, shape_vars = dex_call_abstract_eval_with_shape( + *(jax.core.ShapedArray(xshape.dimensions(), xshape.numpy_dtype()) + for xshape in xla_shapes), + func_atom=func_atom) + result_xshape = xc.Shape.array_shape(result_aval.dtype, result_aval.shape) + + custom_call = custom_call_cache.get(func_atom, None) + native = get_compiled(func_atom) + if custom_call is None: + assert len(args) == len(native.explicit_argument_signature) + assert 1 == len(native.result_signature) + custom_call_ctype = ctypes.CFUNCTYPE(None, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p * len(args))) + @custom_call_ctype + def trampoline(result_ptr, arg_ptr_array): + name_to_cval = {name: IdxRepTy(value) for name, value in shape_vars.items()} + for binder, ptr in zip(native.explicit_argument_signature, arg_ptr_array.contents): + if isinstance(binder.type, ScalarType): + cval = ctypes.cast(ptr, ctypes.POINTER(binder.type.arg_ctype)).contents + elif isinstance(binder.type, RectContArrayType): + cval = ctypes.cast(ptr, binder.type.arg_ctype) + else: + raise AssertionError("Unexpected binder type") + name_to_cval[binder.name] = cval + result_binder = native.result_signature[0] + name_to_cval[result_binder.name] = ctypes.cast(result_ptr, result_binder.type.ref_ctype) + native.callable(*(name_to_cval[name] for name in native.ccall_signature)) + + trampoline_addr = ctypes.c_void_p.from_param(trampoline) + custom_call_name = f"dex_custom_call{next(custom_call_id)}".encode('ascii') + xc.register_custom_call_target(custom_call_name, + make_custom_call_target(trampoline_addr)) + custom_call_cache[func_atom] = (custom_call_name, trampoline) + # TODO: Unregister custom calls at some point? + else: + custom_call_name, *_ = custom_call + return xc.ops.CustomCall(b, custom_call_name, operands=args, shape=result_xshape) + +jax.interpreters.xla.backend_specific_translations['cpu'][dex_call_p] = dex_call_cpu_translation + +# TODO +# jax.interpreters.batching.primitive_batchers[self.primitive] = ... +# jax.interpreters.ad.primitive_jvps[self.primitive] = ... +# jax.interpreters.ad.primitive_transposes[self.primitive] = ... diff --git a/python/dex/native_function.py b/python/dex/native_function.py new file mode 100644 index 000000000..6008e34c4 --- /dev/null +++ b/python/dex/native_function.py @@ -0,0 +1,221 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import sys +import ctypes +import string +import numpy as np +from typing import Any, List, Union, Callable, Dict +from dataclasses import dataclass +from . import api + +ScalarCType = Union[ + ctypes.c_int64, ctypes.c_int32, + ctypes.c_uint8, + ctypes.c_double, ctypes.c_float +] +IdxRepTy = ctypes.c_int32 + +@dataclass(frozen=True) +class ScalarType: + ctype: Any + from_ctype: Callable + + @property + def arg_ctype(self): return self.ctype + + @property + def ref_ctype(self): return ctypes.POINTER(self.ctype) + + def to_ctype(self, value, name_cvalue): + return self.ctype(value) + + def create(self, name_cvalue): + instance = self.ctype() + return ctypes.pointer(instance), lambda: self.from_ctype(instance) + + +@dataclass(frozen=True) +class RectContArrayType: + ctype: ScalarType + shape: List[Union[str, int]] + + @property + def arg_ctype(self): + return ctypes.POINTER(self.ctype) + + @property + def ref_ctype(self): + return ctypes.POINTER(self.ctype) + + def unsafe_array_ptr(self, array): + ptr, _ = array.__array_interface__['data'] + return ctypes.cast(ctypes.c_void_p(ptr), ctypes.POINTER(self.ctype)) + + def to_ctype(self, array, name_cvalue): + if not isinstance(array, np.ndarray): + array = np.asarray(array) + if array.ndim != len(self.shape): + raise ValueError(f"Expected a {len(self.shape)}D array, got {array.ndim}D") + expected_dtype = np.dtype(self.ctype) + if array.dtype != expected_dtype: + raise ValueError(f"Expected a {expected_dtype} array, got {array.dtype}") + expected_shape = tuple( + size if isinstance(size, int) else name_cvalue.setdefault(size, IdxRepTy(real_size)).value + for size, real_size in zip(self.shape, array.shape)) + if expected_shape != array.shape: + raise ValueError(f"Shape mismatch: expected {expected_shape}, but got {array.shape}") + if not array.flags['C_CONTIGUOUS']: + raise ValueError("Only contiguous arrays supported as arguments at the moment") + return self.unsafe_array_ptr(array) + + def create(self, name_cvalue): + shape = [size if isinstance(size, int) else name_cvalue[size].value + for size in self.shape] + result = np.empty(shape, dtype=self.ctype) + return self.unsafe_array_ptr(result), lambda: result + +NativeType = Union[ScalarType, RectContArrayType] + + +@dataclass(frozen=True) +class Binder: + name: str + type: NativeType + implicit: bool + + +class NativeFunction: + def __init__(self, jit, ptr): + self._as_parameter_ = ptr + self._jit = jit + sig_ptr = api.getFunctionSignature(jit, ptr) + if not sig_ptr: + raise RuntimeError("Failed to retrieve the function signature") + try: + signature = sig_ptr.contents + self.argument_signature = _SignatureParser(signature.arg).parse() + self.explicit_argument_signature = [arg for arg in self.argument_signature if not arg.implicit] + self.result_signature = _SignatureParser(signature.res).parse() + self.ccall_signature = [sys.intern(arg.decode('ascii')) for arg in signature.ccall.split(b',')] + finally: + api.freeFunctionSignature(sig_ptr) + + func_type = ctypes.CFUNCTYPE( + ctypes.c_int64, + *(arg.type.arg_ctype for arg in self.argument_signature), + *(res.type.ref_ctype for res in self.result_signature)) + self.callable = func_type(ctypes.cast(ptr, ctypes.c_void_p).value) + + def __del__(self): + if api.nofree: return + if hasattr(self, '_as_parameter_'): + api.unload(self._jit, self) + + def __call__(self, *args): + name_to_cval = {} + result_thunks = [] + assert len(self.explicit_argument_signature) == len(args) + for arg, binder in zip(args, self.explicit_argument_signature): + name_to_cval[binder.name] = binder.type.to_ctype(arg, name_to_cval) + for binder in self.result_signature: + value, result_thunk = binder.type.create(name_to_cval) + name_to_cval[binder.name] = value + result_thunks.append(result_thunk) + self.callable(*(name_to_cval[name] for name in self.ccall_signature)) + results = tuple(thunk() for thunk in result_thunks) + if len(results) == 1: + return results[0] + else: + return results + + +class _SignatureParser: + __slots__ = ('text', 'offset') + + def __init__(self, text): + self.text = text + + def consume(self, char: str): + assert self.text[self.offset] == ord(char) + self.offset += 1 + + def maybe_consume(self, char: str) -> bool: + if self.offset < len(self.text) and self.text[self.offset] == ord(char): + self.offset += 1 + return True + return False + + digit_codes = set(string.digits.encode('ascii')) + name_codes = set(string.ascii_letters.encode('ascii')) | digit_codes + + def parse_name(self) -> str: + end = self.offset + name_codes = self.name_codes + text = self.text + while text[end] in name_codes: + end += 1 + result = sys.intern(self.text[self.offset:end].decode('ascii')) + self.offset = end + return result + + scalar_types: Dict[bytes, ScalarType] = { + b'i64': ScalarType(ctypes.c_int64, np.int64), + b'i32': ScalarType(ctypes.c_int32, np.int32), + b'u8': ScalarType(ctypes.c_uint8, np.uint8), + b'f64': ScalarType(ctypes.c_double, np.float64), + b'f32': ScalarType(ctypes.c_float, np.float32), + } + + def parse_type(self) -> NativeType: + for name, scalar_type in self.scalar_types.items(): + if self.text.startswith(name, self.offset): + break + else: + raise RuntimeError(f"Invalid type specification: {sig[self.offset:self.offset+3].decode('ascii')}") + self.offset += len(name) + if self.maybe_consume('['): + if self.maybe_consume('?'): + raise RuntimeError("Only rectangular array types supported") + shape = [] + while True: + shape.append(self.parse_dim()) + if self.maybe_consume(']'): + break + else: + self.consume(',') + return RectContArrayType(scalar_type.ctype, shape) + else: + return scalar_type + + def parse_dim(self): + if self.text[self.offset] in self.digit_codes: + return self.parse_number() + else: + return self.parse_name() + + def parse_number(self) -> int: + end = self.offset + while self.text[end] in self.digit_codes: + end += 1 + result = int(self.text[self.offset:end].decode('ascii')) + self.offset = end + return result + + def parse(self): + self.offset = 0 + binders = [] + while True: + implicit = self.maybe_consume('?') + name = self.parse_name() + self.consume(':') + ty = self.parse_type() + binders.append(Binder(name, ty, implicit)) + if self.offset == len(self.text): + break + else: + self.consume(',') + return binders diff --git a/python/tests/api_test.py b/python/tests/api_test.py index 6282b4875..5b5606967 100644 --- a/python/tests/api_test.py +++ b/python/tests/api_test.py @@ -5,6 +5,8 @@ # https://developers.google.com/open-source/licenses/bsd import unittest +import ctypes +import numpy as np from textwrap import dedent # TODO: Write a proper setup.py instead of using this hack... diff --git a/python/tests/jax_test.py b/python/tests/jax_test.py new file mode 100644 index 000000000..02a4fbff8 --- /dev/null +++ b/python/tests/jax_test.py @@ -0,0 +1,53 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest +import ctypes +import numpy as np +from textwrap import dedent + +# TODO: Write a proper setup.py instead of using this hack... +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent)) + +import jax +import jax.numpy as jnp + +import dex +from dex.interop.jax import primitive + +def test_impl_scalar(): + add_two = primitive(dex.eval(r'\x:Float. x + 2.0')) + x = jnp.zeros((), dtype=np.float32) + np.testing.assert_allclose(add_two(x), x + 2.0) + +def test_impl_array(): + add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. x.i + 2.0')) + x = jnp.arange((10,), dtype=np.float32) + np.testing.assert_allclose(add_two(x), x + 2.0) + +def test_abstract_eval_simple(): + add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')) + x = jax.ShapeDtypeStruct((10,), np.float32) + output_shape = jax.eval_shape(add_two, x) + assert output_shape.shape == (10,) + assert output_shape.dtype == np.int32 + +def test_jit_scalar(): + add_two = primitive(dex.eval(r'\x:Float. x + 2.0')) + x = jnp.zeros((), dtype=np.float32) + np.testing.assert_allclose(jax.jit(add_two)(x), 2.0) + +def test_jit_array(): + add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')) + x = jnp.zeros((10,), dtype=np.float32) + np.testing.assert_allclose(jax.jit(add_two)(x), (x + 2.0).astype(np.int32)) + +def test_jit_scale(): + scale = primitive(dex.eval(r'\x:((Fin 10)=>Float) y:Float. for i. x.i * y')) + x = jnp.arange((10,), dtype=np.float32) + np.testing.assert_allclose(scale(x, 5.0), x * 5.0) diff --git a/python/tests/jit_test.py b/python/tests/jit_test.py new file mode 100644 index 000000000..e299b845b --- /dev/null +++ b/python/tests/jit_test.py @@ -0,0 +1,78 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest +import ctypes +import numpy as np +import itertools as it +from textwrap import dedent + +# TODO: Write a proper setup.py instead of using this hack... +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent)) + +import dex + +example_floats = list(map(np.float32, (-1.0, -0.5, 0.0, 0.5, 1.0))) +example_ints = [-10, -5, 0, 5, 10] + +def check_atom(dex_atom, reference, args_iter): + compiled = dex_atom.compile() + ran_any_iter = False + for args in args_iter: + ran_any_iter = True + np.testing.assert_allclose(compiled(*args), reference(*args), + rtol=1e-4, atol=1e-6) + assert ran_any_iter, "Empty argument iterator!" + +def expr_test(dex_source, reference, args_iter): + def test(): + return check_atom(dex.eval(dex_source), reference, args_iter) + return test + +test_sigmoid = expr_test(r"\x:Float. 1.0 / (1.0 + exp(-x))", + lambda x: np.float32(1.0) / (np.float32(1.0) + np.exp(-x)), + ((x,) for x in example_floats)) + +test_multi_arg = expr_test(r"\x:Float y:Float. atan2 x y", + np.arctan2, + ((x + 0.01, y) for x, y in it.product(example_floats, repeat=2) + if (x, y) != (0.0, 0.0))) + +test_int_arg = expr_test(r"\x:Int64 y:Int. I64ToI x + y", + lambda x, y: x + y, + it.product(example_ints, example_ints)) + +test_array_scalar = expr_test(r"\x:((Fin 10)=>Float). sum x", + np.sum, + [(np.arange(10, dtype=np.float32),)]) + +test_scalar_array = expr_test(r"\x:Int. for i:(Fin 10). x + ordinal i", + lambda x: x + np.arange(10, dtype=np.int32), + [(i,) for i in range(5)]) + +test_array_array = expr_test(r"\x:((Fin 10)=>Float). for i. exp x.i", + np.exp, + [(np.arange(10, dtype=np.float32),)]) + +def test_polymorphic_array_1d(): + m = dex.Module(dedent(""" + def addTwo (n: Int) ?-> (x: (Fin n)=>Float) : (Fin n)=>Float = for i. x.i + 2.0 + """)) + check_atom(m.addTwo, lambda x: x + 2, + [(np.arange(l, dtype=np.float32),) for l in (2, 5, 10)]) + +def test_polymorphic_array_2d(): + m = dex.Module(dedent(""" + def myTranspose (n: Int) ?-> (m: Int) ?-> + (x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float = + for i j. x.j.i + """)) + check_atom(m.myTranspose, lambda x: x.T, + [(np.arange(a*b, dtype=np.float32).reshape((a, b)),) + for a, b in it.product((2, 5, 10), repeat=2)]) + diff --git a/shell.nix b/shell.nix new file mode 100644 index 000000000..93215131e --- /dev/null +++ b/shell.nix @@ -0,0 +1,15 @@ +{ nixpkgs ? import {} }: +with nixpkgs; +stdenv.mkDerivation { + name = "dex"; + buildInputs = [ + cabal-install + haskell.compiler.ghc884 + llvm_9 + clang_9 + pkg-config + libpng + git + cacert + ]; +} diff --git a/src/Dex/Foreign/API.hs b/src/Dex/Foreign/API.hs new file mode 100644 index 000000000..f6c8349c7 --- /dev/null +++ b/src/Dex/Foreign/API.hs @@ -0,0 +1,43 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Dex.Foreign.API where + +import Foreign.Ptr +import Foreign.C + +import Syntax + +import Dex.Foreign.Context +import Dex.Foreign.Serialize +import Dex.Foreign.JIT + +-- Public API (commented out exports are defined in rts.c) + +-- Initialization and basic runtime +-- foreign export ccall "dexInit" _ :: IO () +-- foreign export ccall "dexFini" _ :: IO () +-- foreign export ccall "dexGetError" _ :: CString + +-- Context +foreign export ccall "dexCreateContext" dexCreateContext :: IO (Ptr Context) +foreign export ccall "dexDestroyContext" dexDestroyContext :: Ptr Context -> IO () +foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr Atom -> IO (Ptr Context) +foreign export ccall "dexEval" dexEval :: Ptr Context -> CString -> IO (Ptr Context) +foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) +foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) + +-- Serialization +foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString +foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt + +-- JIT +foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT) +foreign export ccall "dexDestroyJIT" dexDestroyJIT :: Ptr JIT -> IO () +foreign export ccall "dexCompile" dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO (Ptr NativeFunction) +foreign export ccall "dexUnload" dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO () +foreign export ccall "dexGetFunctionSignature" dexGetFunctionSignature :: Ptr JIT -> Ptr NativeFunction -> IO (Ptr ExportedSignature) +foreign export ccall "dexFreeFunctionSignature" dexFreeFunctionSignature :: Ptr ExportedSignature -> IO () diff --git a/src/foreign/API.hs b/src/Dex/Foreign/Context.hs similarity index 50% rename from src/foreign/API.hs rename to src/Dex/Foreign/Context.hs index 23e1da09b..7a0e3cbb1 100644 --- a/src/foreign/API.hs +++ b/src/Dex/Foreign/Context.hs @@ -4,19 +4,21 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module API where +module Dex.Foreign.Context ( + Context (..), + setError, + dexCreateContext, dexDestroyContext, + dexInsert, dexLookup, + dexEval, dexEvalExpr, + ) where import Control.Monad.State.Strict import Foreign.Ptr import Foreign.StablePtr -import Foreign.Storable -import Foreign.Marshal.Alloc import Foreign.C.String -import Foreign.C.Types import Data.String -import Data.Word import Data.Int import Data.Functor import Data.Foldable @@ -26,22 +28,10 @@ import Syntax hiding (sizeOf) import Type import TopLevel import Parser (parseExpr, exprAsModule) -import Serialize (pprintVal) import Env hiding (Tag) import PPrint --- Public API (commented out exports are defined in rts.c) --- foreign export ccall "dexInit" _ :: IO () --- foreign export ccall "dexFini" _ :: IO () --- foreign export ccall "dexGetError" _ :: CString -foreign export ccall "dexCreateContext" dexCreateContext :: IO (Ptr Context) -foreign export ccall "dexDestroyContext" dexDestroyContext :: Ptr Context -> IO () -foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString -foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr Atom -> IO (Ptr Context) -foreign export ccall "dexEval" dexEval :: Ptr Context -> CString -> IO (Ptr Context) -foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) -foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) -foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt +import Dex.Foreign.Util data Context = Context EvalConfig TopEnv @@ -56,10 +46,10 @@ dexCreateContext = do maybePreludeEnv <- evalPrelude evalConfig preludeSource case maybePreludeEnv of Right preludeEnv -> toStablePtr $ Context evalConfig preludeEnv - Left _ -> setError "Failed to initialize standard library" $> nullPtr + Left err -> nullPtr <$ setError ("Failed to initialize standard library: " ++ pprint err) where evalPrelude :: EvalConfig -> String -> IO (Either Err TopEnv) - evalPrelude opts contents = flip evalStateT mempty $ do + evalPrelude opts contents = flip evalStateT initTopEnv $ do results <- fmap snd <$> evalSource opts contents env <- get return $ env `unlessError` results @@ -72,9 +62,6 @@ dexCreateContext = do dexDestroyContext :: Ptr Context -> IO () dexDestroyContext = freeStablePtr . castPtrToStablePtr . castPtr -dexPrint :: Ptr Atom -> IO CString -dexPrint atomPtr = newCString =<< pprintVal =<< fromStablePtr atomPtr - dexEval :: Ptr Context -> CString -> IO (Ptr Context) dexEval ctxPtr sourcePtr = do Context evalConfig env <- fromStablePtr ctxPtr @@ -96,11 +83,11 @@ dexInsert ctxPtr namePtr atomPtr = do dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) dexEvalExpr ctxPtr sourcePtr = do Context evalConfig env <- fromStablePtr ctxPtr - maybeExpr <- parseExpr <$> peekCString sourcePtr - case maybeExpr of + source <- peekCString sourcePtr + case parseExpr source of Right expr -> do let (v, m) = exprAsModule expr - let block = SourceBlock 0 0 LogNothing "" (RunModule m) Nothing + let block = SourceBlock 0 0 LogNothing source (RunModule m) Nothing (resultEnv, Result [] maybeErr) <- evalSourceBlock evalConfig env block case maybeErr of Right () -> do @@ -118,61 +105,3 @@ dexLookup ctxPtr namePtr = do Just _ -> setError "Looking up an expression" $> nullPtr Nothing -> setError "Unbound name" $> nullPtr -dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt -dexToCAtom atomPtr resultPtr = do - atom <- fromStablePtr atomPtr - case atom of - Con con -> case con of - Lit (VecLit _) -> notSerializable - Lit l -> poke resultPtr (CLit l) $> 1 - _ -> notSerializable - _ -> notSerializable - where - notSerializable = setError "Unserializable atom" $> 0 - -dexFreeCAtom :: Ptr CAtom -> IO () -dexFreeCAtom = free - -data CAtom = CLit LitVal | CRectArray (Ptr ()) [Int] [Int] - -instance Storable CAtom where - sizeOf _ = tag + val + val + val - where tag = 8; val = 8 - alignment _ = 8 - peek addr = do - tag <- val @Word64 0 - case tag of - 0 -> do - litTag <- val @Word64 1 - CLit <$> case litTag of - 0 -> Int64Lit <$> val 2 - 1 -> Int32Lit <$> val 2 - 2 -> Word8Lit <$> val 2 - 3 -> Float64Lit <$> val 2 - 4 -> Float32Lit <$> val 2 - _ -> error "Invalid tag" - _ -> error "Invalid tag" - where - val :: forall a. Storable a => Int -> IO a - val i = peekByteOff (castPtr addr) (i * 8) - poke addr catom = case catom of - CLit lit -> do - val @Word64 0 0 - case lit of - Int64Lit v -> val @Word64 1 0 >> val 2 v - Int32Lit v -> val @Word64 1 1 >> val 2 v - Word8Lit v -> val @Word64 1 2 >> val 2 v - Float64Lit v -> val @Word64 1 3 >> val 2 v - Float32Lit v -> val @Word64 1 4 >> val 2 v - VecLit _ -> error "Unsupported" - PtrLit _ _ -> error "Unsupported" - CRectArray _ _ _ -> error "Unsupported" - where - val :: forall a. Storable a => Int -> a -> IO () - val i v = pokeByteOff (castPtr addr) (i * 8) v - -fromStablePtr :: Ptr a -> IO a -fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr - -toStablePtr :: a -> IO (Ptr a) -toStablePtr x = castPtr . castStablePtrToPtr <$> newStablePtr x diff --git a/src/Dex/Foreign/JIT.hs b/src/Dex/Foreign/JIT.hs new file mode 100644 index 000000000..d40a4b4a0 --- /dev/null +++ b/src/Dex/Foreign/JIT.hs @@ -0,0 +1,115 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE RecordWildCards #-} +{-# OPTIONS_GHC -Wno-orphans #-} + +module Dex.Foreign.JIT ( + JIT, NativeFunction, ExportedSignature, + dexCreateJIT, dexDestroyJIT, + dexGetFunctionSignature, dexFreeFunctionSignature, + dexCompile, dexUnload + ) where + +import Control.Monad.State.Strict + +import Foreign.Ptr +import Foreign.C.String +import Foreign.Storable +import Foreign.Marshal.Alloc + +import Data.IORef +import Data.Functor +import qualified Data.Map.Strict as M + +import LLVM.Target (TargetMachine) +import qualified LLVM.Relocation as R +import qualified LLVM.CodeModel as CM +import qualified LLVM.CodeGenOpt as CGO +import qualified LLVM.JIT +import qualified LLVM.Shims + +import Logging +import LLVMExec +import JIT +import Syntax hiding (sizeOf) +import Export + +import Dex.Foreign.Util +import Dex.Foreign.Context + +data NativeFunction = + NativeFunction { nativeModule :: LLVM.JIT.NativeModule + , nativeSignature :: ExportedSignature } +type NativeFunctionAddr = Ptr NativeFunction + +data JIT = ForeignJIT { jit :: LLVM.JIT.JIT + , jitTargetMachine :: TargetMachine + , addrTableRef :: IORef (M.Map NativeFunctionAddr NativeFunction) + } + +instance Storable ExportedSignature where + sizeOf _ = 3 * sizeOf (undefined :: Ptr ()) + alignment _ = alignment (undefined :: Ptr ()) + peek _ = error "peek not implemented for ExportedSignature" + poke addr sig = do + let strAddr = castPtr @ExportedSignature @CString addr + let (arg, res, ccall) = exportedSignatureDesc sig + pokeElemOff strAddr 0 =<< newCString arg + pokeElemOff strAddr 1 =<< newCString res + pokeElemOff strAddr 2 =<< newCString ccall + +dexCreateJIT :: IO (Ptr JIT) +dexCreateJIT = do + jitTargetMachine <- LLVM.Shims.newHostTargetMachine R.PIC CM.Large CGO.Aggressive + jit <- LLVM.JIT.createJIT jitTargetMachine + addrTableRef <- newIORef mempty + toStablePtr ForeignJIT{..} + +dexDestroyJIT :: Ptr JIT -> IO () +dexDestroyJIT jitPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + addrTable <- readIORef addrTableRef + forM_ (M.toList addrTable) $ \(_, m) -> LLVM.JIT.unloadNativeModule $ nativeModule m + LLVM.JIT.destroyJIT jit + LLVM.Shims.disposeTargetMachine jitTargetMachine + +dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO NativeFunctionAddr +dexCompile jitPtr ctxPtr funcAtomPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + Context _ env <- fromStablePtr ctxPtr + funcAtom <- fromStablePtr funcAtomPtr + let (impMod, nativeSignature) = prepareFunctionForExport env "userFunc" funcAtom + nativeModule <- execLogger Nothing $ \logger -> do + llvmAST <- impToLLVM logger impMod + LLVM.JIT.compileModule jit llvmAST + (standardCompilationPipeline logger ["userFunc"] jitTargetMachine) + funcPtr <- castFunPtrToPtr <$> LLVM.JIT.getFunctionPtr nativeModule "userFunc" + modifyIORef addrTableRef $ M.insert funcPtr NativeFunction{..} + return $ funcPtr + +dexGetFunctionSignature :: Ptr JIT -> NativeFunctionAddr -> IO (Ptr ExportedSignature) +dexGetFunctionSignature jitPtr funcPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + addrTable <- readIORef addrTableRef + case M.lookup funcPtr addrTable of + Nothing -> setError "Invalid function address" $> nullPtr + Just NativeFunction{..} -> putOnHeap nativeSignature + +dexFreeFunctionSignature :: Ptr ExportedSignature -> IO () +dexFreeFunctionSignature sigPtr = do + let strPtr = castPtr @ExportedSignature @CString sigPtr + free =<< peekElemOff strPtr 0 + free =<< peekElemOff strPtr 1 + free =<< peekElemOff strPtr 2 + free sigPtr + +dexUnload :: Ptr JIT -> NativeFunctionAddr -> IO () +dexUnload jitPtr funcPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + addrTable <- readIORef addrTableRef + LLVM.JIT.unloadNativeModule $ nativeModule $ addrTable M.! funcPtr + modifyIORef addrTableRef $ M.delete funcPtr diff --git a/src/Dex/Foreign/Serialize.hs b/src/Dex/Foreign/Serialize.hs new file mode 100644 index 000000000..8d882ee49 --- /dev/null +++ b/src/Dex/Foreign/Serialize.hs @@ -0,0 +1,77 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Dex.Foreign.Serialize ( + CAtom, + dexPrint, dexToCAtom + ) where + +import Data.Word +import Data.Functor + +import Foreign.C +import Foreign.Ptr +import Foreign.Storable + +import Syntax +import Serialize (pprintVal) + +import Dex.Foreign.Context (setError) +import Dex.Foreign.Util + +-- TODO: Free! +dexPrint :: Ptr Atom -> IO CString +dexPrint atomPtr = newCString =<< pprintVal =<< fromStablePtr atomPtr + +data CAtom = CLit LitVal | CRectArray (Ptr ()) [Int] [Int] + +instance Storable CAtom where + sizeOf _ = tag + val + val + val + where tag = 8; val = 8 + alignment _ = 8 + peek addr = do + tag <- val @Word64 0 + case tag of + 0 -> do + litTag <- val @Word64 1 + CLit <$> case litTag of + 0 -> Int64Lit <$> val 2 + 1 -> Int32Lit <$> val 2 + 2 -> Word8Lit <$> val 2 + 3 -> Float64Lit <$> val 2 + 4 -> Float32Lit <$> val 2 + _ -> error "Invalid tag" + _ -> error "Invalid tag" + where + val :: forall a. Storable a => Int -> IO a + val i = peekByteOff (castPtr addr) (i * 8) + poke addr catom = case catom of + CLit lit -> do + val @Word64 0 0 + case lit of + Int64Lit v -> val @Word64 1 0 >> val 2 v + Int32Lit v -> val @Word64 1 1 >> val 2 v + Word8Lit v -> val @Word64 1 2 >> val 2 v + Float64Lit v -> val @Word64 1 3 >> val 2 v + Float32Lit v -> val @Word64 1 4 >> val 2 v + VecLit _ -> error "Unsupported" + PtrLit _ _ -> error "Unsupported" + CRectArray _ _ _ -> error "Unsupported" + where + val :: forall a. Storable a => Int -> a -> IO () + val i v = pokeByteOff (castPtr addr) (i * 8) v + +dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt +dexToCAtom atomPtr resultPtr = do + atom <- fromStablePtr atomPtr + case atom of + Con con -> case con of + Lit (VecLit _) -> notSerializable + Lit l -> poke resultPtr (CLit l) $> 1 + _ -> notSerializable + _ -> notSerializable + where + notSerializable = setError "Unserializable atom" $> 0 diff --git a/src/Dex/Foreign/Util.hs b/src/Dex/Foreign/Util.hs new file mode 100644 index 000000000..de156e983 --- /dev/null +++ b/src/Dex/Foreign/Util.hs @@ -0,0 +1,24 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Dex.Foreign.Util (fromStablePtr, toStablePtr, putOnHeap) where + +import Foreign.Ptr +import Foreign.StablePtr +import Foreign.Storable +import Foreign.Marshal.Alloc + +fromStablePtr :: Ptr a -> IO a +fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr + +toStablePtr :: a -> IO (Ptr a) +toStablePtr x = castPtr . castStablePtrToPtr <$> newStablePtr x + +putOnHeap :: Storable a => a -> IO (Ptr a) +putOnHeap x = do + ptr <- malloc + poke ptr x + return ptr diff --git a/src/foreign/rts.c b/src/Dex/Foreign/rts.c similarity index 100% rename from src/foreign/rts.c rename to src/Dex/Foreign/rts.c diff --git a/src/dex.hs b/src/dex.hs index 9c36c9ae8..d6102f503 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -11,10 +11,12 @@ import System.Exit import Control.Monad import Control.Monad.State.Strict import Options.Applicative +import Text.PrettyPrint.ANSI.Leijen (text, hardline) import System.Posix.Terminal (queryTerminal) import System.Posix.IO (stdOutput) -import System.Exit + import System.Directory +import Data.List import Syntax import PPrint @@ -25,6 +27,8 @@ import Resources import TopLevel import Parser hiding (Parser) import LiveOutput +import Env (envNames) +import Export data ErrorHandling = HaltOnErr | ContinueOnErr data DocFmt = ResultOnly | TextDoc | HTMLDoc | JSONDoc @@ -44,8 +48,10 @@ runMode evalMode preludeFile opts = do env <- cached "prelude" key $ evalPrelude opts preludeFile let runEnv m = evalStateT m env case evalMode of - ReplMode prompt -> - runEnv $ runInputT defaultSettings $ forever (replLoop prompt opts) + ReplMode prompt -> do + let filenameAndDexCompletions = completeQuotedWord (Just '\\') "\"'" listFiles dexCompletions + let hasklineSettings = setComplete filenameAndDexCompletions defaultSettings + runEnv $ runInputT hasklineSettings $ forever (replLoop prompt opts) ScriptMode fname fmt _ -> do results <- runEnv $ evalFile opts fname printLitProg fmt results @@ -54,17 +60,19 @@ runMode evalMode preludeFile opts = do WebMode fname -> runWeb fname opts env WatchMode fname -> runTerminal fname opts env ExportMode dexPath objPath -> do - results <- fmap snd <$> (runEnv $ evalFile opts dexPath) + results <- fmap snd <$> runEnv (evalFile opts dexPath) let outputs = foldMap (\(Result outs _) -> outs) results let errors = foldMap (\case (Result _ (Left err)) -> [err]; _ -> []) results putStr $ foldMap (nonEmptyNewline . pprint) errors let exportedFuns = foldMap (\case (ExportedFun name f) -> [(name, f)]; _ -> []) outputs - exportFunctions objPath exportedFuns env opts + unless (backendName opts == LLVM) $ liftEitherIO $ + throw CompilerErr "Export only supported with the LLVM CPU backend" + exportFunctions objPath exportedFuns env evalPrelude :: EvalConfig -> Maybe FilePath -> IO TopEnv -evalPrelude opts fname = flip execStateT mempty $ do +evalPrelude opts fname = flip execStateT initTopEnv $ do source <- case fname of - Nothing -> return $ preludeSource + Nothing -> return preludeSource Just path -> liftIO $ readFile path result <- evalSource opts source void $ liftErrIO $ mapM (\(_, Result _ r) -> r) result @@ -78,6 +86,20 @@ replLoop prompt opts = do _ -> return () liftIO $ putStrLn $ pprint result +dexCompletions :: CompletionFunc (StateT TopEnv IO) +dexCompletions (line, _) = do + env <- get + let varNames = map pprint $ envNames env + -- note: line and thus word and rest have character order reversed + let (word, rest) = break (== ' ') line + let anywhereKeywords = ["def", "for", "rof", "case", "data", "where", "of", "if", + "then", "else", "interface", "instance", "do", "view"] + let startoflineKeywords = ["%bench \"", ":p", ":t", ":html", ":export"] + let candidates = (if null rest then startoflineKeywords else []) ++ + anywhereKeywords ++ varNames + let completions = map simpleCompletion $ filter (reverse word `isPrefixOf`) candidates + return (rest, completions) + liftErrIO :: MonadIO m => Except a -> m a liftErrIO (Left err) = liftIO $ putStrLn (pprint err) >> exitFailure liftErrIO (Right x) = return x @@ -106,56 +128,62 @@ printLitProg TextDoc prog = do isatty <- queryTerminal stdOutput putStr $ foldMap (uncurry (printLitBlock isatty)) prog printLitProg JSONDoc prog = - forM_ prog $ \(_, result) -> case toJSONStr result of + forM_ prog \(_, result) -> case toJSONStr result of "{}" -> return () s -> putStrLn s +nonEmptyNewline :: String -> String nonEmptyNewline [] = [] nonEmptyNewline l = l ++ ['\n'] parseOpts :: ParserInfo CmdOpts parseOpts = simpleInfo $ CmdOpts <$> parseMode - <*> (optional $ strOption $ long "prelude" <> metavar "FILE" <> help "Prelude file") + <*> optional (strOption $ long "prelude" <> metavar "FILE" <> help "Prelude file") <*> parseEvalOpts +helpOption :: String -> String -> Mod f a +helpOption optionName options = + helpDoc (Just (text optionName <> hardline <> text options)) + parseMode :: Parser EvalMode parseMode = subparser $ - (command "repl" $ simpleInfo $ - ReplMode <$> (strOption $ long "prompt" <> value ">=> " - <> metavar "STRING" <> help "REPL prompt")) - <> (command "web" $ simpleInfo (WebMode <$> sourceFileInfo )) - <> (command "watch" $ simpleInfo (WatchMode <$> sourceFileInfo )) - <> (command "export" $ simpleInfo (ExportMode <$> sourceFileInfo <*> objectFileInfo)) - <> (command "script" $ simpleInfo (ScriptMode <$> sourceFileInfo - <*> (option - (optionList [ ("literate" , TextDoc) - , ("result-only", ResultOnly) - , ("HTML" , HTMLDoc) - , ("JSON" , JSONDoc)]) - (long "outfmt" <> value TextDoc - <> help "Output format (literate(default)|result-only|HTML|JSON")) - <*> flag HaltOnErr ContinueOnErr ( - long "allow-errors" - <> help "Evaluate programs containing non-fatal type errors"))) + command "repl" (simpleInfo + (ReplMode <$> strOption (long "prompt" <> value ">=> " + <> metavar "STRING" <> help "REPL prompt"))) + <> command "web" (simpleInfo (WebMode <$> sourceFileInfo)) + <> command "watch" (simpleInfo (WatchMode <$> sourceFileInfo)) + <> command "export" (simpleInfo (ExportMode <$> sourceFileInfo <*> objectFileInfo)) + <> command "script" (simpleInfo (ScriptMode <$> sourceFileInfo + <*> option + (optionList [ ("literate" , TextDoc) + , ("result-only", ResultOnly) + , ("html" , HTMLDoc) + , ("json" , JSONDoc)]) + (long "outfmt" <> value TextDoc <> + helpOption "Output format" "literate (default) | result-only | html | json") + <*> flag HaltOnErr ContinueOnErr ( + long "allow-errors" + <> help "Evaluate programs containing non-fatal type errors"))) where sourceFileInfo = argument str (metavar "FILE" <> help "Source program") objectFileInfo = argument str (metavar "OBJFILE" <> help "Output path (.o file)") optionList :: [(String, a)] -> ReadM a -optionList opts = eitherReader $ \s -> case lookup s opts of +optionList opts = eitherReader \s -> case lookup s opts of Just x -> Right x Nothing -> Left $ "Bad option. Expected one of: " ++ show (map fst opts) parseEvalOpts :: Parser EvalConfig parseEvalOpts = EvalConfig - <$> (option - (optionList [ ("LLVM", LLVM) - , ("LLVM-CUDA", LLVMCUDA) - , ("LLVM-MC", LLVMMC) - , ("interp", Interp)]) - (long "backend" <> value LLVM <> help "Backend (LLVM(default)|LLVM-CUDA|interp)")) - <*> (optional $ strOption $ long "logto" + <$> option + (optionList [ ("llvm", LLVM) + , ("llvm-cuda", LLVMCUDA) + , ("llvm-mc", LLVMMC) + , ("interpreter", Interpreter)]) + (long "backend" <> value LLVM <> + helpOption "Backend" "llvm (default) | llvm-cuda | llvm-mc | interpreter") + <*> optional (strOption $ long "logto" <> metavar "FILE" <> help "File to log to" <> showDefault) diff --git a/src/lib/Actor.hs b/src/lib/Actor.hs index f152b2ba6..fbf0cb6e8 100644 --- a/src/lib/Actor.hs +++ b/src/lib/Actor.hs @@ -43,7 +43,7 @@ runActor (Actor m) = do linksRef <- newIORef [] chan <- newBackChan tid <- myThreadId - let p = (Proc Trap tid (asErrPChan chan)) + let p = Proc Trap tid (asErrPChan chan) runReaderT m (ActorConfig p chan linksRef) subChan :: (a -> b) -> PChan b -> PChan a @@ -123,7 +123,7 @@ receive :: MonadActor msg m => m msg receive = receiveF Just newBackChan :: IO (BackChan a) -newBackChan = liftM2 BackChan (newIORef []) (newChan) +newBackChan = liftM2 BackChan (newIORef []) newChan readBackChan :: BackChan a -> IO a readBackChan (BackChan ptr chan) = do xs <- readIORef ptr @@ -173,6 +173,6 @@ logServer = flip evalStateT (mempty, []) $ forever $ do Push x -> do modify $ onFst (<> x) subscribers <- gets snd - mapM_ (flip send x) subscribers + mapM_ (`send` x) subscribers -- TODO: state machine? diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index f17115bfd..1da5eac39 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -31,7 +31,9 @@ import GHC.Stack -- `DerivWrt` holds the (out-expr) variables that we're differentiating with -- respect to (including refs but not regions). -data DerivWrt = DerivWrt { activeVars :: Env Type, _activeEffs :: [Effect], rematVars :: Env Type } +data DerivWrt = DerivWrt { activeVars :: Env Type + , _activeEffs :: [Effect] + , rematVars :: Env Type } -- `Tangents` holds the tangent values and the region variables that are -- arguments to the linearized function. data TangentEnv = TangentEnv { tangentVals :: SubstEnv, activeRefs :: [Name], rematVals :: SubstEnv } @@ -42,10 +44,10 @@ newtype LinA a = LinA { runLinA :: PrimalM (a, TangentM a) } linearize :: Scope -> Atom -> Atom linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do - buildLam b PureArrow $ \x@(Var v) -> do + buildLam b PureArrow \x@(Var v) -> do (y, yt) <- flip runReaderT (DerivWrt (varAsEnv v) [] mempty) $ runLinA $ linearizeBlock (b@>x) block -- TODO: check linearity - fLin <- buildLam (fmap tangentType b) LinArrow $ \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty + fLin <- buildLam (fmap tangentType b) LinArrow \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty fLinChecked <- checkEmbed fLin return $ PairVal y fLinChecked @@ -107,7 +109,7 @@ linearizeExpr env expr = case expr of return (ans, applyLinToTangents linLam) where linearizeInactiveAlt (Abs bs body) = do - buildNAbs bs $ \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body + buildNAbs bs \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body _ -> LinA $ do expr' <- substEmbed env expr runLinA $ case expr' of @@ -133,12 +135,14 @@ linearizeOp op = case op of FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp Select p t f -> (Select p <$> la t <*> la f ) `bindLin` emitOp - PtrLoad _ -> emitWithZero -- XXX: This assumes that pointers are always constants + -- XXX: This assumes that pointers are always constants + PtrLoad _ -> emitWithZero + PtrStore _ _ -> emitDiscrete PtrOffset _ _ -> emitDiscrete + IOAlloc _ _ -> emitDiscrete + IOFree _ -> emitDiscrete TabCon ty xs -> (TabCon ty <$> traverse la xs) `bindLin` emitOp Inject _ -> emitDiscrete - GetPtr _ -> emitDiscrete - MakePtrType _ -> emitDiscrete SliceOffset _ _ -> emitDiscrete SliceCurry _ _ -> emitDiscrete VectorBinOp _ _ _ -> notImplemented @@ -168,6 +172,7 @@ linearizeOp op = case op of VariantLift ts v -> (VariantLift ts <$> la v) `bindLin` emitOp VariantSplit ts v -> (VariantSplit ts <$> la v) `bindLin` emitOp FFICall _ _ _ -> error $ "Can't differentiate through an FFI call" + ThrowException _ -> notImplemented where emitDiscrete = if isTrivialForAD (Op op) then LinA $ withZeroTangent <$> emitOp op @@ -251,18 +256,27 @@ linearizeHof :: SubstEnv -> Hof -> LinA Atom linearizeHof env hof = case hof of For ~(RegularFor d) ~(LamVal i body) -> LinA $ do i' <- mapM (substEmbed env) i - (ansWithLinTab, vi'') <- buildForAux d i' $ \i''@(Var vi'') -> + (ansWithLinTab, vi'') <- buildForAux d i' \i''@(Var vi'') -> (,vi'') <$> (willRemat vi'' $ tangentFunAsLambda $ linearizeBlock (env <> i@>i'') body) (ans, linTab) <- unzipTab ansWithLinTab - return (ans, buildFor d i' $ \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) + return (ans, buildFor d i' \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) Tile _ _ _ -> notImplemented RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader RunState val lam -> linearizeEff (Just val) lam True RunState (emitRunState "r") State + RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do + arrow' <- substEmbed env arrow + -- TODO: consider the possibility of other effects here besides IO + lam <- buildLam (Ignore UnitTy) arrow' \_ -> + tangentFunAsLambda $ linearizeBlock env body + result <- emit $ Hof $ RunIO lam + (ans, linLam) <- fromPair result + return (ans, applyLinToTangents linLam) -- TODO: Consider providing an upper bound for the number of while iterations as a hint. -- In the current form the best we can do is try to use some dynamically growing lists, -- but that won't work on the GPU. - While _ _ -> notImplemented + While _ -> notImplemented + CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" PTileReduce _ _ -> error "Unexpected PTileReduce" @@ -287,19 +301,19 @@ linearizeHof env hof = case hof of let (BinaryFunTy _ b _ _) = getType lam' let RefTy _ wTy = binderType b return $ emitter $ tangentType wTy - valEmitter $ \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do + valEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) - linearizeEffectFun :: EffectName -> Atom -> PrimalM (Atom, Var) - linearizeEffectFun effName ~(BinaryFunVal h ref eff body) = do + linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) + linearizeEffectFun rws ~(BinaryFunVal h ref eff body) = do h' <- mapM (substEmbed env) h - buildLamAux h' (const $ return PureArrow) $ \h''@(Var hVar) -> do + buildLamAux h' (const $ return PureArrow) \h''@(Var hVar) -> do let env' = env <> h@>h'' eff' <- substEmbed env' eff ref' <- mapM (substEmbed env') ref - buildLamAux ref' (const $ return $ PlainArrow eff') $ \ref''@(Var refVar) -> - extendWrt [refVar] [(effName, varName hVar)] $ + buildLamAux ref' (const $ return $ PlainArrow eff') \ref''@(Var refVar) -> + extendWrt [refVar] [RWSEffect rws (varName hVar)] $ (,refVar) <$> (tangentFunAsLambda $ linearizeBlock (env' <> ref@>ref'') body) linearizePrimCon :: Con -> LinA Atom @@ -329,7 +343,7 @@ linearizeAtom atom = case atom of Con con -> linearizePrimCon con Lam (Abs i (TabArrow, body)) -> LinA $ do wrt <- ask - return (atom, buildLam i TabArrow $ \i' -> + return (atom, buildLam i TabArrow \i' -> rematPrimal wrt $ linearizeBlock (i@>i') body) DataCon _ _ _ _ -> notImplemented -- Need to synthesize or look up a tangent ADT Record elems -> Record <$> traverse linearizeAtom elems @@ -382,7 +396,7 @@ addTangent x y = case getType x of RecordTy (NoExt tys) -> do elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) return $ Record $ restructure elems tys - TabTy b _ -> buildFor Fwd b $ \i -> bindM2 addTangent (tabGet x i) (tabGet y i) + TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) TC con -> case con of BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y @@ -407,20 +421,23 @@ tangentFunAsLambda :: LinA Atom -> PrimalM Atom tangentFunAsLambda m = do (ans, tanFun) <- runLinA m DerivWrt activeVars effs remats <- ask - let hs = map (Bind . (:>TyKind) . snd) effs + let hs = map (Bind . (:>TyKind) . effectRegion) effs let rematList = envAsVars remats liftM (PairVal ans) $ lift $ do - tanLam <- makeLambdas rematList $ \rematArgs -> - buildNestedLam PureArrow hs $ \hVals -> do + tanLam <- makeLambdas rematList \rematArgs -> + buildNestedLam PureArrow hs \hVals -> do let hVarNames = map (\(Var (v:>_)) -> v) hVals - let effs' = zipWith (\(effName, _) v -> (effName, v)) effs hVarNames + -- TODO: handle exception effect too + let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames -- want to use tangents here, not the original binders - let regionMap = newEnv (map ((:>()) . snd) effs) hVals + let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars - buildNestedLam PureArrow activeVarBinders $ \activeVarArgs -> - buildLam (Ignore UnitTy) (PlainArrow $ EffectRow effs' Nothing) $ \_ -> - runReaderT tanFun $ TangentEnv (newEnv (envNames activeVars) activeVarArgs) hVarNames (newEnv rematList $ fmap Var rematArgs) + buildNestedLam PureArrow activeVarBinders \activeVarArgs -> + buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) \_ -> + runReaderT tanFun $ TangentEnv + (newEnv (envNames activeVars) activeVarArgs) hVarNames + (newEnv rematList $ fmap Var rematArgs) case rematList of [] -> return tanLam _ -> deShadow tanLam <$> getScope @@ -433,7 +450,7 @@ tangentFunAsLambda m = do return $ Lam $ makeAbs (Bind v) (PureArrow, block) makeLambdas [] f = f [] - makeLambdas (v:vs) f = makeLambda v $ \x -> makeLambdas vs $ \xs -> f (x:xs) + makeLambdas (v:vs) f = makeLambda v \x -> makeLambdas vs \xs -> f (x:xs) -- This doesn't work if we have references inside pairs, tables etc. -- TODO: something more general and cleaner. @@ -442,6 +459,11 @@ tangentFunAsLambda m = do RefTy ~(Var h) a -> RefTy (regEnv ! h) $ tangentType a _ -> tangentType ty + effectRegion eff = case eff of + RWSEffect _ h -> h + ExceptionEffect -> error "TODO!" + IOEffect -> error "TODO!" + -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. applyLinToTangents :: Atom -> TangentM Atom applyLinToTangents f = do @@ -525,7 +547,7 @@ type TransposeM a = ReaderT TransposeEnv Embed a transpose :: Scope -> Atom -> Atom transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do - buildLam (Bind $ "ct" :> getType block) LinArrow $ \ct -> do + buildLam (Bind $ "ct" :> getType block) LinArrow \ct -> do snd <$> (flip runReaderT mempty $ withLinVar b $ transposeBlock block ct) transposeBlock :: Block -> Atom -> TransposeM () @@ -571,7 +593,7 @@ transposeExpr expr ct = case expr of void $ emit $ Case e' alts' UnitTy where transposeNonlinAlt (Abs bs body) = - buildNAbs bs $ \xs -> do + buildNAbs bs \xs -> do localNonlinSubst (newEnv bs xs) $ transposeBlock body ct return UnitVal @@ -588,8 +610,6 @@ transposeOp op ct = case op of else transposeAtom y =<< mul ct =<< substNonlin x ScalarBinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y ScalarBinOp _ _ _ -> notLinear - GetPtr _ -> notLinear - MakePtrType _ -> notLinear PrimEffect refArg m -> do refArg' <- substTranspose linRefSubst refArg let emitEff = emitOp . PrimEffect refArg' @@ -602,7 +622,7 @@ transposeOp op ct = case op of MPut x -> do transposeAtom x =<< emitEff MGet void $ emitEff $ MPut $ zeroAt $ getType x - TabCon ~(TabTy b _) es -> forM_ (enumerate es) $ \(i, e) -> do + TabCon ~(TabTy b _) es -> forM_ (enumerate es) \(i, e) -> do transposeAtom e =<< tabGet ct =<< intToIndexE (binderType b) (IdxRepVal $ fromIntegral i) IndexRef _ _ -> notImplemented FstRef _ -> notImplemented @@ -616,8 +636,11 @@ transposeOp op ct = case op of RecordSplit _ _ -> notImplemented VariantLift _ _ -> notImplemented VariantSplit _ _ -> notImplemented + PtrStore _ _ -> notLinear PtrLoad _ -> notLinear - PtrOffset _ _ -> notLinear + PtrOffset _ _ -> notLinear + IOAlloc _ _ -> notLinear + IOFree _ -> notLinear Inject _ -> notLinear SliceOffset _ _ -> notLinear SliceCurry _ _ -> notLinear @@ -625,9 +648,10 @@ transposeOp op ct = case op of ToOrdinal _ -> notLinear IdxSetSize _ -> notLinear ThrowError _ -> notLinear - FFICall _ _ _ -> notLinear + FFICall _ _ _ -> notLinear DataConTag _ -> notLinear ToEnum _ _ -> notLinear + ThrowException _ -> notLinear where -- Both nonlinear operations and operations on discrete types, where linearity doesn't make sense notLinear = error $ "Can't transpose a non-linear operation: " ++ pprint op @@ -655,32 +679,34 @@ linAtomRef a = error $ "Not a linear var: " ++ pprint a transposeHof :: Hof -> Atom -> TransposeM () transposeHof hof ct = case hof of For ~(RegularFor d) ~(Lam (Abs b (_, body))) -> - void $ buildFor (flipDir d) b $ \i -> do + void $ buildFor (flipDir d) b \i -> do ct' <- tabGet ct i localNonlinSubst (b@>i) $ transposeBlock body ct' return UnitVal where flipDir dir = case dir of Fwd -> Rev; Rev -> Fwd RunReader r ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do - (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) $ \ref -> do + (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ct return UnitVal transposeAtom r ctr RunWriter ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do (ctBody, ctEff) <- fromPair ct - void $ emitRunReader "r" ctEff $ \ref -> do + void $ emitRunReader "r" ctEff \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal RunState s ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do (ctBody, ctState) <- fromPair ct - (_, cts) <- (fromPair =<<) $ emitRunState "s" ctState $ \ref -> do + (_, cts) <- (fromPair =<<) $ emitRunState "s" ctState \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal transposeAtom s cts - Tile _ _ _ -> notImplemented - While _ _ -> notImplemented - Linearize _ -> error "Unexpected linearization" - Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + Tile _ _ _ -> notImplemented + While _ -> notImplemented + RunIO _ -> notImplemented + CatchException _ -> notImplemented + Linearize _ -> error "Unexpected linearization" + Transpose _ -> error "Unexpected transposition" + PTileReduce _ _ -> error "Unexpected PTileReduce" transposeAtom :: Atom -> Atom -> TransposeM () transposeAtom atom ct = case atom of @@ -694,7 +720,7 @@ transposeAtom atom ct = case atom of DataCon _ _ _ e -> void $ zipWithT transposeAtom e =<< getUnpacked ct Variant _ _ _ _ -> notImplemented TabVal b body -> - void $ buildFor Fwd b $ \i -> do + void $ buildFor Fwd b \i -> do ct' <- tabGet ct i localNonlinSubst (b@>i) $ transposeBlock body ct' return UnitVal @@ -741,11 +767,12 @@ freeLinVars x = do isLin :: HasVars a => a -> TransposeM Bool isLin x = not . null <$> freeLinVars x -isLinEff :: EffectSummary -> TransposeM Bool -isLinEff effs = do +isLinEff :: EffectRow -> TransposeM Bool +isLinEff (EffectRow effs Nothing) = do regions <- asks linRegions return $ not $ null $ effRegions `envIntersect` regions - where effRegions = newEnv (S.map snd effs) (repeat ()) + where effRegions = freeVars $ toList effs +isLinEff _ = error "Can't transpose polymorphic effects" emitCTToRef :: Maybe Atom -> Atom -> TransposeM () emitCTToRef mref ct = case mref of @@ -765,7 +792,7 @@ withLinVar :: Binder -> TransposeM a -> TransposeM (a, Atom) withLinVar b body = case singletonTypeVal (binderType b) of Nothing -> flip evalStateT Nothing $ do - ans <- emitRunWriter "ref" (binderType b) $ \ref -> do + ans <- emitRunWriter "ref" (binderType b) \ref -> do lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal (,) <$> (fromJust <$> get) <*> (getSnd ans) Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs index 3edb9df41..01aa6d062 100644 --- a/src/lib/Cat.hs +++ b/src/lib/Cat.hs @@ -18,7 +18,6 @@ module Cat (CatT, MonadCat, runCatT, look, extend, scoped, looks, extendLocal, -- Monad for tracking monoidal state import Control.Applicative -import Control.Monad.Fail import Control.Monad.State.Strict import Control.Monad.Reader import Control.Monad.Writer @@ -42,7 +41,7 @@ instance (Monoid env, Monad m) => MonadCat env (CatT env m) where put (fullState <> x, localState <> x) scoped (CatT m) = CatT $ do originalState <- get - put $ (fst originalState, mempty) + put (fst originalState, mempty) ans <- m newLocalState <- gets snd put originalState @@ -51,7 +50,7 @@ instance (Monoid env, Monad m) => MonadCat env (CatT env m) where instance MonadCat env m => MonadCat env (StateT s m) where look = lift look extend x = lift $ extend x - scoped m = StateT $ \s -> do + scoped m = StateT \s -> do ((ans, s'), env) <- scoped $ runStateT m s return $ ((ans, env), s') @@ -146,7 +145,7 @@ catTraverse f inj xs env = runCatT (traverse (asCat f inj) xs) env catFoldM :: (Monoid env, Traversable t, Monad m) => (env -> a -> m env) -> env -> t a -> m env -catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs $ \x -> do +catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs \x -> do cur <- look new <- lift $ f cur x extend new @@ -157,7 +156,7 @@ catFold f env xs = runIdentity $ catFoldM (\e x -> Identity $ f e x) env xs catMapM :: (Monoid env, Traversable t, Monad m) => (env -> a -> m (b, env)) -> env -> t a -> m (t b, env) -catMapM f env xs = flip runCatT env $ forM xs $ \x -> do +catMapM f env xs = flip runCatT env $ forM xs \x -> do cur <- look (y, new) <- lift $ f cur x extend new diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 12349af6e..4fd4d6a27 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -13,34 +13,39 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi, getAllowedEffects, withEffects, modifyAllowedEffects, buildLam, EmbedT, Embed, MonadEmbed, buildScoped, runEmbedT, - runSubstEmbed, runEmbed, getScope, embedLook, + runSubstEmbed, runEmbed, getScope, embedLook, liftEmbed, app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, + fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, select, substEmbed, substEmbedR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getFstRef, getSndRef, naryApp, appReduce, appTryReduce, buildAbs, buildFor, buildForAux, buildForAnn, buildForAnnAux, emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, - embedExtend, unpackConsList, emitRunWriter, emitRunState, + embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, + emitRunState, emitMaybeCase, emitWhile, buildDataDef, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, - traverseAtom, ptrOffset, ptrLoad, evalBlockE, substTraversalDef, - TraversalDef, traverseDecls, traverseDecl, traverseBlock, traverseExpr, + ptrOffset, ptrLoad, unsafePtrLoad, + evalBlockE, substTraversalDef, + TraversalDef, traverseDecls, traverseDecl, traverseDeclsOpen, + traverseBlock, traverseExpr, traverseAtom, clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, transformModuleAsBlock, dropSub, appReduceTraversalDef, indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where import Control.Applicative import Control.Monad -import Control.Monad.Fail import Control.Monad.Except hiding (Except) import Control.Monad.Reader import Control.Monad.Writer hiding (Alt) import Control.Monad.Identity import Control.Monad.State.Strict import Data.Foldable (toList) +import Data.List (elemIndex) +import Data.Maybe (fromJust) +import Data.String (fromString) import Data.Tuple (swap) import GHC.Stack @@ -104,13 +109,16 @@ emitOp op = emit $ Op op emitUnpack :: MonadEmbed m => Expr -> m [Atom] emitUnpack expr = getUnpacked =<< emit expr --- Assumes the decl binders are already fresh wrt current scope emitBlock :: MonadEmbed m => Block -> m Atom -emitBlock (Block decls result) = do - mapM_ emitDecl decls - case result of - Atom x -> return x - _ -> emit result +emitBlock block = emitBlockRec mempty block + +emitBlockRec :: MonadEmbed m => SubstEnv -> Block -> m Atom +emitBlockRec env (Block (Nest (Let ann b expr) decls) result) = do + expr' <- substEmbed env expr + x <- emitTo (binderNameHint b) ann expr' + emitBlockRec (env <> b@>x) $ Block decls result +emitBlockRec env (Block Empty (Atom atom)) = substEmbed env atom +emitBlockRec env (Block Empty expr) = substEmbed env expr >>= emit freshVarE :: MonadEmbed m => BinderInfo -> Binder -> m Var freshVarE bInfo b = do @@ -160,7 +168,7 @@ buildLam b arr body = buildDepEffLam b (const (return arr)) body buildDepEffLam :: MonadEmbed m => Binder -> (Atom -> m Arrow) -> (Atom -> m Atom) -> m Atom -buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr $ \x -> (,()) <$> fBody x +buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr \x -> (,()) <$> fBody x buildLamAux :: MonadEmbed m => Binder -> (Atom -> m Arrow) -> (Atom -> m (Atom, a)) -> m (Atom, a) @@ -176,7 +184,7 @@ buildLamAux b fArr fBody = do return (Lam $ makeAbs b' (arr, wrapDecls decls ans), aux) buildNAbs :: MonadEmbed m => Nest Binder -> ([Atom] -> m Atom) -> m Alt -buildNAbs bs body = liftM fst $ buildNAbsAux bs $ \xs -> (,()) <$> body xs +buildNAbs bs body = liftM fst $ buildNAbsAux bs \xs -> (,()) <$> body xs buildNAbsAux :: MonadEmbed m => Nest Binder -> ([Atom] -> m (Atom, a)) -> m (Alt, a) buildNAbsAux bs body = do @@ -186,6 +194,28 @@ buildNAbsAux bs body = do return (fmap Bind vs, result) return (Abs bs' $ wrapDecls decls ans, aux) +buildDataDef :: MonadEmbed m + => Name -> Nest Binder -> ([Atom] -> m [DataConDef]) -> m DataDef +buildDataDef tyConName paramBinders body = do + ((paramBinders', dataDefs), _) <- scopedDecls $ do + vs <- freshNestedBinders paramBinders + result <- body $ map Var $ toList vs + return (fmap Bind vs, result) + return $ DataDef tyConName paramBinders' dataDefs + +buildImplicitNaryLam :: MonadEmbed m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom +buildImplicitNaryLam Empty body = body [] +buildImplicitNaryLam (Nest b bs) body = + buildLam b ImplicitArrow \x -> do + bs' <- substEmbed (b@>x) bs + buildImplicitNaryLam bs' \xs -> body $ x:xs + +recGetHead :: Label -> Atom -> Atom +recGetHead l x = do + let (RecordTy (Ext r _)) = getType x + let i = fromJust $ elemIndex l $ map fst $ toList $ reflectLabels r + getProjection [i] x + buildScoped :: MonadEmbed m => m Atom -> m Block buildScoped m = do (ans, decls) <- scopedDecls m @@ -321,6 +351,15 @@ appReduce (Lam (Abs v (_, b))) a = runReaderT (evalBlockE substTraversalDef b) (v @> a) appReduce _ _ = error "appReduce expected a lambda as the first argument" +-- TODO: this would be more convenient if we could add type inference too +applyPreludeFunction :: MonadEmbed m => String -> [Atom] -> m Atom +applyPreludeFunction s xs = do + scope <- getScope + case envLookup scope fname of + Nothing -> error $ "Function not defined yet: " ++ s + Just (ty, _) -> naryApp (Var (fname:>ty)) xs + where fname = GlobalName (fromString s) + appTryReduce :: MonadEmbed m => Atom -> Atom -> m Atom appTryReduce f x = case f of Lam _ -> appReduce f x @@ -329,6 +368,10 @@ appTryReduce f x = case f of ptrOffset :: MonadEmbed m => Atom -> Atom -> m Atom ptrOffset x i = emitOp $ PtrOffset x i +unsafePtrLoad :: MonadEmbed m => Atom -> m Atom +unsafePtrLoad x = emit $ Hof $ RunIO $ Lam $ Abs (Ignore UnitTy) $ + (PlainArrow (oneEffect IOEffect), Block Empty (Op (PtrLoad x))) + ptrLoad :: MonadEmbed m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x @@ -341,6 +384,21 @@ unpackConsList xs = case getType xs of liftM (x:) $ unpackConsList rest _ -> error $ "Not a cons list: " ++ pprint (getType xs) +emitWhile :: MonadEmbed m => m Atom -> m () +emitWhile body = do + eff <- getAllowedEffects + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> body + void $ emit $ Hof $ While lam + +emitMaybeCase :: MonadEmbed m => Atom -> m Atom -> (Atom -> m Atom) -> m Atom +emitMaybeCase scrut nothingCase justCase = do + let (MaybeTy a) = getType scrut + nothingAlt <- buildNAbs Empty \[] -> nothingCase + justAlt <- buildNAbs (Nest (Bind ("x":>a)) Empty) \[x] -> justCase x + let (Abs _ nothingBody) = nothingAlt + let resultTy = getType nothingBody + emit $ Case scrut [nothingAlt, justAlt] resultTy + emitRunWriter :: MonadEmbed m => Name -> Type -> (Atom -> m Atom) -> m Atom emitRunWriter v ty body = do emit . Hof . RunWriter =<< mkBinaryEffFun Writer v ty body @@ -353,15 +411,16 @@ emitRunState :: MonadEmbed m => Name -> Atom -> (Atom -> m Atom) -> m Atom emitRunState v x0 body = do emit . Hof . RunState x0 =<< mkBinaryEffFun State v (getType x0) body -mkBinaryEffFun :: MonadEmbed m => EffectName -> Name -> Type -> (Atom -> m Atom) -> m Atom -mkBinaryEffFun newEff v ty body = do +mkBinaryEffFun :: MonadEmbed m => RWS -> Name -> Type -> (Atom -> m Atom) -> m Atom +mkBinaryEffFun rws v ty body = do eff <- getAllowedEffects - buildLam (Bind ("h":>TyKind)) PureArrow $ \r@(Var (rName:>_)) -> do - let arr = PlainArrow $ extendEffect (newEff, rName) eff + buildLam (Bind ("h":>TyKind)) PureArrow \r@(Var (rName:>_)) -> do + let arr = PlainArrow $ extendEffect (RWSEffect rws rName) eff buildLam (Bind (v:> RefTy r ty)) arr body buildForAnnAux :: MonadEmbed m => ForAnn -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) buildForAnnAux ann i body = do + -- TODO: consider only tracking the effects that are actually needed. eff <- getAllowedEffects (lam, aux) <- buildLamAux i (const $ return $ PlainArrow eff) body (,aux) <$> (emit $ Hof $ For ann lam) @@ -372,22 +431,23 @@ buildForAnn ann i body = fst <$> buildForAnnAux ann i (\x -> (,()) <$> body x) buildForAux :: MonadEmbed m => Direction -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) buildForAux = buildForAnnAux . RegularFor +-- Do we need this variant? buildFor :: MonadEmbed m => Direction -> Binder -> (Atom -> m Atom) -> m Atom buildFor = buildForAnn . RegularFor buildNestedLam :: MonadEmbed m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom buildNestedLam _ [] f = f [] buildNestedLam arr (b:bs) f = - buildLam b arr $ \x -> buildNestedLam arr bs $ \xs -> f (x:xs) + buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) tabGet :: MonadEmbed m => Atom -> Atom -> m Atom tabGet x i = emit $ App x i unzipTab :: MonadEmbed m => Atom -> m (Atom, Atom) unzipTab tab = do - fsts <- buildLam (Bind ("i":>binderType v)) TabArrow $ \i -> + fsts <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM fst $ app tab i >>= fromPair - snds <- buildLam (Bind ("i":>binderType v)) TabArrow $ \i -> + snds <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM snd $ app tab i >>= fromPair return (fsts, snds) where TabTy v _ = getType tab @@ -453,9 +513,9 @@ instance Monad m => MonadEmbed (EmbedT m) where instance MonadEmbed m => MonadEmbed (ReaderT r m) where embedLook = lift embedLook embedExtend x = lift $ embedExtend x - embedScoped m = ReaderT $ \r -> embedScoped $ runReaderT m r + embedScoped m = ReaderT \r -> embedScoped $ runReaderT m r embedAsk = lift embedAsk - embedLocal v m = ReaderT $ \r -> embedLocal v $ runReaderT m r + embedLocal v m = ReaderT \r -> embedLocal v $ runReaderT m r instance MonadEmbed m => MonadEmbed (StateT s m) where embedLook = lift embedLook @@ -573,6 +633,14 @@ scopedDecls m = do (ans, (_, decls)) <- embedScoped m return (ans, decls) +liftEmbed :: MonadEmbed m => Embed a -> m a +liftEmbed action = do + envR <- embedAsk + envC <- embedLook + let (ans, envC') = runIdentity $ runEmbedT' action (envR, envC) + embedExtend envC' + return ans + -- === generic traversal === type TraversalDef m = (Decl -> m SubstEnv, Expr -> m Expr, Atom -> m Atom) @@ -646,7 +714,7 @@ traverseExpr def@(_, _, fAtom) expr = case expr of where traverseAlt (Abs bs body) = do bs' <- mapM (mapM fAtom) bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ evalBlockE def body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ evalBlockE def body traverseAtom :: forall m . (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m -> Atom -> m Atom @@ -683,7 +751,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of BoxedRef b ptr size body -> do ptr' <- fAtom ptr size' <- buildScoped $ evalBlockE def size - Abs b' (decls, body') <- buildAbs b $ \x -> + Abs b' (decls, body') <- buildAbs b \x -> extendR (b@>x) $ evalBlockE def (Block Empty $ Atom body) case decls of Empty -> return $ BoxedRef b' ptr' size' body' @@ -701,7 +769,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of traverseAAlt (Abs bs a) = do bs' <- mapM (mapM fAtom) bs - (Abs bs'' b) <- buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ fAtom a + (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ fAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error "ACase alternative traversal has emitted decls or exprs!" @@ -778,7 +846,7 @@ indexToIntE idx = case getType idx of (offsets, _) <- scanM (\sz prev -> (prev,) <$> iadd sz prev) sizes (IdxRepVal 0) -- Build and apply a case expression alts <- flip mapM (zip (toList offsets) (toList types)) $ - \(offset, subty) -> buildNAbs (toNest [Ignore subty]) $ \[subix] -> do + \(offset, subty) -> buildNAbs (toNest [Ignore subty]) \[subix] -> do i <- indexToIntE subix iadd offset i emit $ Case idx alts IdxRepTy diff --git a/src/lib/Env.hs b/src/lib/Env.hs index bfb2dd93e..456c613ab 100644 --- a/src/lib/Env.hs +++ b/src/lib/Env.hs @@ -39,6 +39,7 @@ data NameSpace = | InferenceName | SumName | FFIName + | TypeClassGenName -- names generated for type class dictionaries | AbstractedPtrName -- used in `abstractPtrLiterals` in Imp lowering | TopFunctionName -- top-level Imp functions | AllocPtrName -- used for constructing dests in Imp lowering @@ -163,6 +164,7 @@ env ! v = case envLookup env v of isGlobal :: VarP ann -> Bool isGlobal (GlobalName _ :> _) = True isGlobal (GlobalArrayName _ :> _) = True +isGlobal (Name TypeClassGenName _ _ :> _) = True isGlobal _ = False isGlobalBinder :: BinderP ann -> Bool diff --git a/src/lib/Export.hs b/src/lib/Export.hs new file mode 100644 index 000000000..0db91a501 --- /dev/null +++ b/src/lib/Export.hs @@ -0,0 +1,226 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RecordWildCards #-} + +module Export ( + exportFunctions, prepareFunctionForExport, exportedSignatureDesc, + ExportedSignature (..), ExportArrayType (..), ExportArg (..), ExportResult (..), + ) where + +import Control.Monad.State.Strict +import Control.Monad.Writer hiding (pass) +import qualified Data.Text as T +import Data.String +import Data.Foldable +import Data.List (nub, intercalate) + +import Algebra +import Syntax +import Embed +import Cat +import Env +import Type +import Simplify +import Imp +import JIT +import Logging +import LLVMExec +import PPrint +import Optimize + +exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> IO () +exportFunctions objPath funcs env = do + let names = fmap fst funcs + unless (length (nub names) == length names) $ liftEitherIO $ + throw CompilerErr "Duplicate export names" + modules <- forM funcs $ \(name, funcAtom) -> do + let (impModule, _) = prepareFunctionForExport env name funcAtom + (,[name]) <$> execLogger Nothing (flip impToLLVM impModule) + exportObjectFile objPath modules + + +type CArgList = [IBinder] -- ^ List of arguments to the C call +data CArgEnv = CArgEnv { -- | Maps scalar atom binders to their CArgs. All atoms are Vars. + cargScalarScope :: Env Atom + -- | Tracks the CArg names used so far (globally scoped, unlike Embed) + , cargScope :: Env () } +type CArgM = WriterT CArgList (CatT CArgEnv Embed) + +instance Semigroup CArgEnv where + (CArgEnv a1 a2) <> (CArgEnv b1 b2) = CArgEnv (a1 <> b1) (a2 <> b2) + +instance Monoid CArgEnv where + mempty = CArgEnv mempty mempty + +runCArg :: CArgEnv -> CArgM a -> Embed (a, [IBinder], CArgEnv) +runCArg initEnv m = repack <$> runCatT (runWriterT m) initEnv + where repack ((ans, cargs), env) = (ans, cargs, env) + +prepareFunctionForExport :: TopEnv -> String -> Atom -> (ImpModule, ExportedSignature) +prepareFunctionForExport env nameStr func = do + -- Create a module that simulates an application of arguments to the function + -- TODO: Assert that the type of func is closed? + let ((dest, cargs, apiDesc), (_, decls)) = flip runEmbed (freeVars func) $ do + (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func + let (atomArgs, exportedArgSig) = unzip args + resultAtom <- naryApp func atomArgs + void $ emitTo outputName PlainLet $ Atom resultAtom + ((resultDest, exportedResSig), cdestArgs, _) <- runCArg cargEnv $ createDest mempty $ getType resultAtom + let cargs' = cargArgs <> cdestArgs + let exportedCCallSig = fmap (\(Bind (v:>_)) -> v) cargs' + return (resultDest, cargs', ExportedSignature{..}) + + let coreModule = Module Core decls mempty + let defunctionalized = simplifyModule env coreModule + let Module _ optDecls optBindings = optimizeModule defunctionalized + let (_, LetBound PlainLet outputExpr) = optBindings ! outputName + let block = Block optDecls outputExpr + + let name = Name TopFunctionName (fromString nameStr) 0 + let (_, impModule, _) = toImpModule env LLVM CEntryFun name cargs (Just dest) block + (impModule, apiDesc) + where + outputName = GlobalName "_ans_" + + createArgs :: Type -> CArgM [(Atom, ExportArg)] + createArgs ty = case ty of + PiTy b arrow result | arrow /= TabArrow -> do + argSubst <- looks cargScalarScope + let visibility = case arrow of + PlainArrow Pure -> ExplicitArg + PlainArrow _ -> error $ "Effectful functions cannot be exported" + ImplicitArrow -> ImplicitArg + _ -> error $ "Unexpected type for an exported function: " ++ pprint ty + (:) <$> createArg visibility (subst (argSubst, mempty) b) <*> createArgs result + _ -> return [] + + createArg :: ArgVisibility -> Binder -> CArgM (Atom, ExportArg) + createArg vis b = case ty of + BaseTy bt@(Scalar sbt) -> do + ~v@(Var (name:>_)) <- newCVar bt + extend $ mempty { cargScalarScope = b @> (Var $ name :> BaseTy bt) } + return (v, ExportScalarArg vis name sbt) + TabTy _ _ -> createTabArg vis mempty ty + _ -> error $ "Unsupported arg type: " ++ pprint ty + where ty = binderType b + + createTabArg :: ArgVisibility -> IndexStructure -> Type -> CArgM (Atom, ExportArg) + createTabArg vis idx ty = case ty of + BaseTy bt@(Scalar sbt) -> do + ~v@(Var (name:>_)) <- newCVar (ptrTy bt) + destAtom <- ptrLoad =<< applyIdxs v idx + funcArgScope <- looks cargScope + let exportArg = ExportArrayArg vis name $ case getRectShape funcArgScope idx of + Just rectShape -> RectContArrayPtr sbt rectShape + Nothing -> GeneralArrayPtr sbt + return (destAtom, exportArg) + TabTy b elemTy -> do + buildLamAux b (const $ return TabArrow) $ \(Var i) -> do + elemTy' <- substEmbed (b@>Var i) elemTy + createTabArg vis (idx <> Nest (Bind i) Empty) elemTy' + _ -> unsupported + where unsupported = error $ "Unsupported table type suffix: " ++ pprint ty + + createDest :: IndexStructure -> Type -> CArgM (Atom, ExportResult) + createDest idx ty = case ty of + BaseTy bt@(Scalar sbt) -> do + ~v@(Var (name:>_)) <- newCVar (ptrTy bt) + dest <- Con . BaseTypeRef <$> applyIdxs v idx + funcArgScope <- looks cargScope + let exportResult = case idx of + Empty -> ExportScalarResultPtr name sbt + _ -> ExportArrayResult name $ case getRectShape funcArgScope idx of + Just rectShape -> RectContArrayPtr sbt rectShape + Nothing -> GeneralArrayPtr sbt + return (dest, exportResult) + TabTy b elemTy -> do + (destTab, exportResult) <- buildLamAux b (const $ return TabArrow) $ \(Var i) -> do + elemTy' <- substEmbed (b@>Var i) elemTy + createDest (idx <> Nest (Bind i) Empty) elemTy' + return (Con $ TabRef destTab, exportResult) + _ -> unsupported + where unsupported = error $ "Unsupported result type: " ++ pprint ty + + -- TODO: I guess that the address space depends on the backend? + -- TODO: Have an ExternalPtr tag? + ptrTy ty = PtrType (Heap CPU, ty) + + getRectShape :: Env () -> IndexStructure -> Maybe [Either Name Int] + getRectShape scope idx = traverse (dimShape . binderType) $ toList idx + where + dimShape dimTy = case dimTy of + Fin (IdxRepVal n) -> Just $ Right $ fromIntegral n + Fin (Var v) | v `isin` scope -> Just $ Left $ varName v + _ -> Nothing + + newCVar :: BaseType -> CArgM Atom + newCVar bt = do + name <- genFresh (Name CArgName "arg" 0) <$> looks cargScope + extend $ mempty { cargScope = name @> () } + tell [Bind $ name :> bt] + return $ Var $ name :> BaseTy bt + +-- === Exported function signature === + +data ExportArrayType = GeneralArrayPtr ScalarBaseType + | RectContArrayPtr ScalarBaseType [Either Name Int] +data ArgVisibility = ImplicitArg | ExplicitArg +data ExportArg = ExportArrayArg ArgVisibility Name ExportArrayType + | ExportScalarArg ArgVisibility Name ScalarBaseType +data ExportResult = ExportArrayResult Name ExportArrayType + | ExportScalarResultPtr Name ScalarBaseType +data ExportedSignature = + ExportedSignature { exportedArgSig :: [ExportArg] + , exportedResSig :: ExportResult + , exportedCCallSig :: [Name] + } + +-- Serialization + +exportedSignatureDesc :: ExportedSignature -> (String, String, String) +exportedSignatureDesc ExportedSignature{..} = + ( intercalate "," $ fmap show exportedArgSig + , show exportedResSig + , intercalate "," $ fmap showCArgName exportedCCallSig + ) + +showExportSBT :: ScalarBaseType -> String +showExportSBT sbt = case sbt of + Int64Type -> "i64" + Int32Type -> "i32" + Word8Type -> "u8" + Float64Type -> "f64" + Float32Type -> "f32" + +showCArgName :: Name -> String +showCArgName ~name@(Name namespace tag idx) = case namespace of + CArgName -> T.unpack tag <> show idx + _ -> error $ "Expected a CArgName namespace: " ++ show name + +instance Show ExportArrayType where + show arr = case arr of + GeneralArrayPtr sbt -> showExportSBT sbt <> "[?]" + RectContArrayPtr sbt shape -> showExportSBT sbt <> showShape shape + where + showShape shape = "[" <> (intercalate "," $ fmap showDim shape) <> "]" + showDim size = case size of + Left name -> showCArgName name + Right lit -> show lit + +instance Show ExportArg where + show arg = case arg of + ExportArrayArg vis name ty -> showVis vis <> showCArgName name <> ":" <> show ty + ExportScalarArg vis name sbt -> showVis vis <> showCArgName name <> ":" <> showExportSBT sbt + where + showVis ImplicitArg = "?" + showVis ExplicitArg = "" + +instance Show ExportResult where + show res = case res of + ExportArrayResult name ty -> showCArgName name <> ":" <> show ty + ExportScalarResultPtr name sbt -> showCArgName name <> ":" <> showExportSBT sbt diff --git a/src/lib/Flops.hs b/src/lib/Flops.hs deleted file mode 100644 index 7ad8a3bbe..000000000 --- a/src/lib/Flops.hs +++ /dev/null @@ -1,92 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE OverloadedStrings #-} - -module Flops (impFunctionFlops) where - -import Control.Monad.Reader -import Control.Monad.Writer -import qualified Data.Map.Strict as M -import Data.Text.Prettyprint.Doc hiding (group) - -import Syntax -import Env -import PPrint - -data Term = Term Int [(Name, Int)] deriving (Show, Eq, Ord) -type Count = [Term] -newtype Profile = Profile (M.Map String Count) - -type FlopM a = ReaderT Term (Writer Profile) a - -impFunctionFlops :: ImpFunction -> Profile -impFunctionFlops (FFIFunction _) = mempty -impFunctionFlops (ImpFunction _ _ body) = - snd $ runWriter (runReaderT (flops body) (litTerm 1)) - -flops :: ImpBlock -> FlopM () -flops (ImpBlock statements _) = void $ traverse declFlops statements - -declFlops :: ImpDecl -> FlopM () -declFlops (ImpLet _ instr) = instrFlops instr - -instrFlops :: ImpInstr -> FlopM () -instrFlops instr = case instr of - IFor _ _ size block -> local (mulTerm $ evalSizeExpr size) $ flops block - ICond _ _ _ -> return () -- TODO: Implement - IWhile _ _ -> return () -- TODO: Implement - ILaunch _ _ _ -> return () -- TODO: Implement - IPrimOp op -> do - n <- ask - tell $ Profile $ M.singleton (showPrimName $ OpExpr op) [n] - _ -> return () - -evalSizeExpr :: IExpr -> Term -evalSizeExpr (IVar (v:>_)) = varTerm v -evalSizeExpr expr = error $ "Not implemented: " ++ pprint expr - -litTerm :: Int -> Term -litTerm n = Term n [] - -varTerm :: Name -> Term -varTerm v = Term 1 [(v, 1)] - -mulTerm :: Term -> Term -> Term -mulTerm (Term n xs) (Term n' xs') = Term (n * n') (xs <> xs') - -canonicalizeCount :: Count -> Count -canonicalizeCount terms = - let terms' = groupReduce (+) [(term, coeff) | - Term coeff term <- map canonicalizeTerm terms] - in [Term coeff term | (term, coeff) <- terms'] - -canonicalizeTerm :: Term -> Term -canonicalizeTerm (Term coeff term) = Term coeff (groupReduce (+) term) - -prettyCount :: Count -> Doc ann -prettyCount terms = - hsep $ punctuate " +" $ map pretty terms' - where terms' = canonicalizeCount terms - -groupReduce :: Ord a => (b -> b -> b) -> [(a,b)] -> [(a,b)] -groupReduce f pairs = M.toAscList $ foldr (M.unionWith f) mempty $ - map (uncurry M.singleton) pairs - -instance Semigroup Profile where - Profile m <> Profile m' = Profile $ M.unionWith (<>) m m' - -instance Monoid Profile where - mempty = Profile mempty - mappend = (<>) - -instance Pretty Profile where - pretty (Profile m) = vsep $ [pretty b <+> prettyCount c - | (b, c) <- M.toAscList m] - -instance Pretty Term where - pretty (Term coeff term) = pretty coeff <+> - hsep ([pretty v <> "^" <> pretty pow | (v, pow) <- term]) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index a9c967ed0..c264c5b2b 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -80,10 +80,10 @@ toImpModule :: TopEnv -> Backend -> CallingConvention -> Name -> (ImpFunction, ImpModule, AtomRecon) toImpModule env backend cc entryName argBinders maybeDest block = do let standaloneFunctions = - for (requiredFunctions env block) $ \(v, f) -> + for (requiredFunctions env block) \(v, f) -> runImpM initCtx inVarScope $ toImpStandalone v f runImpM initCtx inVarScope $ do - (reconAtom, impBlock) <- scopedBlock $ translateTopLevel (maybeDest, block) + (reconAtom, impBlock) <- scopedBlock $ translateTopLevel env (maybeDest, block) otherFunctions <- toList <$> looks envFunctions let ty = IFunType cc (map binderAnn argBinders) (impBlockType impBlock) let mainFunction = ImpFunction (entryName:>ty) argBinders impBlock @@ -98,9 +98,13 @@ toImpModule env backend cc entryName argBinders maybeDest block = do requiredFunctions :: HasVars a => Scope -> a -> [(Name, Atom)] requiredFunctions scope expr = - for (transitiveClosure getParents immediateParents) $ \fname -> do - let (_, LetBound _ (Atom f)) = scope ! fname - (fname, f) + flip foldMap (transitiveClosure getParents immediateParents) \fname -> + case scope ! fname of + (_, LetBound _ (Atom f)) -> [(fname, f)] + -- we treat runtime-supplied global constants (e.g. the virtual stdout + -- channel) as lambda-bound. TODO: consider a new annotation. + (_, LamBound _) -> [] + _ -> error "Shouldn't have other free variables left" where getParents :: Name -> [Name] getParents fname = envNames $ freeVars $ scope ! fname @@ -110,14 +114,15 @@ requiredFunctions scope expr = -- We don't emit any results when a destination is provided, since they are already -- going to be available through the dest. -translateTopLevel :: WithDest Block -> ImpM (AtomRecon, [IExpr]) -translateTopLevel (maybeDest, block) = do +translateTopLevel :: TopEnv -> WithDest Block -> ImpM (AtomRecon, [IExpr]) +translateTopLevel topEnv (maybeDest, block) = do outDest <- case maybeDest of Nothing -> makeAllocDest Unmanaged $ getType block Just dest -> return dest handleErrors $ void $ translateBlock mempty (Just outDest, block) resultAtom <- destToAtom outDest - let vsOut = envAsVars $ freeVars resultAtom + -- Some names in topEnv refer to global constants, like the virtual stdout channel + let vsOut = envAsVars $ freeVars resultAtom `envDiff` topEnv let reconAtom = Abs (toNest $ [Bind (v:>ty) | (v:>(ty, _)) <- vsOut]) resultAtom let resultIExprs = case maybeDest of Nothing -> [IVar (v:>fromScalarType ty) | (v:>(ty, _)) <- vsOut] @@ -140,7 +145,7 @@ toImpStandalone fname ~(LamVal b body) = do impBlock <- scopedErrBlock $ do arg <- destToAtom argDest void $ translateBlock (b@>arg) (Just outDest, body) - let bs = for ptrSizes $ \(Bind (v:>BaseTy ty), _) -> Bind $ v:>ty + let bs = for ptrSizes \(Bind (v:>BaseTy ty), _) -> Bind $ v:>ty let fTy = IFunType CEntryFun (map binderAnn bs) (impBlockType impBlock) return $ ImpFunction (fname:>fTy) bs impBlock @@ -148,7 +153,7 @@ translateBlock :: SubstEnv -> WithDest Block -> ImpM Atom translateBlock env destBlock = do let (decls, result, copies) = splitDest destBlock env' <- (env<>) <$> catFoldM translateDecl env decls - forM_ copies $ \(dest, atom) -> copyAtom dest =<< impSubst env' atom + forM_ copies \(dest, atom) -> copyAtom dest =<< impSubst env' atom translateExpr env' result translateDecl :: SubstEnv -> WithDest Decl -> ImpM SubstEnv @@ -237,7 +242,7 @@ toImpOp :: WithDest (PrimOp Atom) -> ImpM Atom toImpOp (maybeDest, op) = case op of TabCon (TabTy b _) rows -> do dest <- allocDest maybeDest resultTy - forM_ (zip [0..] rows) $ \(i, row) -> do + forM_ (zip [0..] rows) \(i, row) -> do ithDest <- destGet dest =<< intToIndex (binderType b) (IIdxRepVal i) copyAtom ithDest row destToAtom dest @@ -271,14 +276,20 @@ toImpOp (maybeDest, op) = case op of IndexRef refDest i -> returnVal =<< destGet refDest i FstRef ~(Con (ConRef (PairCon ref _ ))) -> returnVal ref SndRef ~(Con (ConRef (PairCon _ ref))) -> returnVal ref + IOAlloc ty n -> do + ptr <- emitAlloc (Heap CPU, ty) (fromScalarAtom n) + returnVal $ toScalarAtom ptr + IOFree ptr -> do + emitStatement $ Free $ fromScalarAtom ptr + return UnitVal + PtrOffset arr (IdxRepVal 0) -> returnVal arr PtrOffset arr off -> do buf <- impOffset (fromScalarAtom arr) (fromScalarAtom off) returnVal $ toScalarAtom buf PtrLoad arr -> returnVal . toScalarAtom =<< loadAnywhere (fromScalarAtom arr) - GetPtr tab -> do - (dest, ptr) <- makeAllocDestForPtr (getType tab) - copyAtom dest tab - returnVal ptr + PtrStore ptr x -> do + store (fromScalarAtom ptr) (fromScalarAtom x) + return UnitVal SliceOffset ~(Con (IndexSliceVal n _ tileOffset)) idx -> do i' <- indexToInt idx i <- iaddI (fromScalarAtom tileOffset) i' @@ -351,7 +362,7 @@ toImpHof env (maybeDest, hof) = do Select (toScalarAtom isLast) (toScalarAtom elemsUntilEnd) (toScalarAtom usualChunkSize)) - emitLoop "li" Fwd (fromScalarAtom threadChunkSize) $ \li -> do + emitLoop "li" Fwd (fromScalarAtom threadChunkSize) \li -> do i <- li `iaddI` chunkStart let idx = Con $ ParIndexCon idxTy $ toScalarAtom i ithDest <- destGet dest idx @@ -362,19 +373,19 @@ toImpHof env (maybeDest, hof) = do cond <- liftM snd $ scopedBlock $ do i <- destToAtom iPtr inRange <- (fromScalarAtom i) `iltI` n + emitWhen inRange $ do + let idx = Con $ ParIndexCon idxTy i + ithDest <- destGet dest idx + void $ translateBlock (env <> b @> idx) (Just ithDest, body) + copyAtom iPtr . toScalarAtom =<< iaddI (fromScalarAtom i) + (fromScalarAtom numThreads) return ((), [inRange]) - wbody <- scopedErrBlock $ do - i <- destToAtom iPtr - let idx = Con $ ParIndexCon idxTy i - ithDest <- destGet dest idx - void $ translateBlock (env <> b @> idx) (Just ithDest, body) - copyAtom iPtr . toScalarAtom =<< iaddI (fromScalarAtom i) (fromScalarAtom numThreads) - emitStatement $ IWhile cond wbody + emitStatement $ IWhile cond destToAtom dest _ -> do n <- indexSetSize idxTy dest <- allocDest maybeDest resultTy - emitLoop (binderNameHint b) d n $ \i -> do + emitLoop (binderNameHint b) d n \i -> do idx <- intToIndex idxTy i ithDest <- destGet dest idx void $ translateBlock (env <> b @> idx) (Just ithDest, body) @@ -382,13 +393,13 @@ toImpHof env (maybeDest, hof) = do For ParallelFor ~fbody@(LamVal b _) -> do idxTy <- impSubst env $ binderType b dest <- allocDest maybeDest resultTy - buildKernel idxTy $ \LaunchInfo{..} buildBody -> do - liftM (,()) $ buildBody $ \ThreadInfo{..} -> do + buildKernel idxTy \LaunchInfo{..} buildBody -> do + liftM (,()) $ buildBody \ThreadInfo{..} -> do let threadBody = fst $ flip runSubstEmbed (freeVars fbody) $ - buildLam (Bind $ "hwidx" :> threadRange) PureArrow $ \hwidx -> + buildLam (Bind $ "hwidx" :> threadRange) PureArrow \hwidx -> appReduce fbody =<< (emitOp $ Inject hwidx) let threadDest = Con $ TabRef $ fst $ flip runSubstEmbed (freeVars dest) $ - buildLam (Bind $ "hwidx" :> threadRange) TabArrow $ \hwidx -> + buildLam (Bind $ "hwidx" :> threadRange) TabArrow \hwidx -> indexDest dest =<< (emitOp $ Inject hwidx) void $ toImpHof env (Just threadDest, For (RegularFor Fwd) threadBody) destToAtom dest @@ -400,12 +411,12 @@ toImpHof env (maybeDest, hof) = do nTiles <- n `idivI` tileLen epilogueOff <- nTiles `imulI` tileLen nEpilogue <- n `isubI` epilogueOff - emitLoop (binderNameHint tb) Fwd nTiles $ \iTile -> do + emitLoop (binderNameHint tb) Fwd nTiles \iTile -> do tileOffset <- toScalarAtom <$> iTile `imulI` tileLen let tileAtom = Con $ IndexSliceVal idxTy tileIdxTy tileOffset tileDest <- fromEmbed $ sliceDestDim d dest tileOffset tileIdxTy void $ translateBlock (env <> tb @> tileAtom) (Just tileDest, tBody) - emitLoop (binderNameHint sb) Fwd nEpilogue $ \iEpi -> do + emitLoop (binderNameHint sb) Fwd nEpilogue \iEpi -> do i <- iEpi `iaddI` epilogueOff idx <- intToIndex idxTy i sDest <- fromEmbed $ indexDestDim d dest idx @@ -415,16 +426,16 @@ toImpHof env (maybeDest, hof) = do idxTy <- impSubst env idxTy' (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do + (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do let widIdxTy = Fin $ toScalarAtom numWorkgroups let tidIdxTy = Fin $ toScalarAtom workgroupSize wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType - mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do + mappingKernelBody <- buildBody \ThreadInfo{..} -> do let TC (ParIndexRange _ gtid nthr) = threadRange let scope = freeVars mappingDest let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do - buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do + buildLam (Bind $ "hwidx":>threadRange) TabArrow \hwidx -> do indexDest mappingDest =<< (emitOp $ Inject hwidx) wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid @@ -436,12 +447,12 @@ toImpHof env (maybeDest, hof) = do -- TODO: Skip the reduction kernel if unnecessary? -- TODO: Reduce sequentially in the CPU backend? -- TODO: Actually we only need the previous-power-of-2 many threads - buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do + buildKernel widIdxTy \LaunchInfo{..} buildBody -> do -- We only do a one-level reduciton in the workgroup, so it is correct -- only if the end up scheduling a single workgroup. moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups guardBlock moreThanOneGroup $ emitStatement IThrowError - redKernelBody <- buildBody $ \ThreadInfo{..} -> + redKernelBody <- buildBody \ThreadInfo{..} -> workgroupReduce tid finalAccDest wgResArr numTileWorkgroups return (redKernelBody, ()) PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest @@ -455,20 +466,21 @@ toImpHof env (maybeDest, hof) = do let arrIdxTy = binderType arrIdxB offPtr <- alloc IdxRepTy copyAtom offPtr $ toScalarAtom elemCountDown2 + let wbody = do + off <- fromScalarAtom <$> destToAtom offPtr + loadIdx <- iaddI tid off + shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) + guardBlock shouldAdd $ do + threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid + addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx + emitStatement ISyncWorkgroup + copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) cond <- liftM snd $ scopedBlock $ do off <- fromScalarAtom <$> destToAtom offPtr cond <- emitInstr $ IPrimOp $ ScalarBinOp (ICmp Greater) off (IIdxRepVal 0) + emitWhen cond wbody return ((), [cond]) - wbody <- scopedErrBlock $ do - off <- fromScalarAtom <$> destToAtom offPtr - loadIdx <- iaddI tid off - shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) - guardBlock shouldAdd $ do - threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid - addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx - emitStatement ISyncWorkgroup - copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) - emitStatement $ IWhile cond wbody + emitStatement $ IWhile cond firstThread <- tid `iltI` (IIdxRepVal 1) guardBlock firstThread $ copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid @@ -480,17 +492,15 @@ toImpHof env (maybeDest, hof) = do let getNext = imulI (IIdxRepVal 2) . fromScalarAtom =<< destToAtom rPtr cond <- liftM snd $ scopedBlock $ do canGrow <- getNext >>= (`iltI` x) + emitWhen canGrow $ copyAtom rPtr . toScalarAtom =<< getNext return ((), [canGrow]) - wbody <- scopedErrBlock $ do - copyAtom rPtr . toScalarAtom =<< getNext - emitStatement $ IWhile cond wbody + emitStatement $ IWhile cond fromScalarAtom <$> destToAtom rPtr - While ~(Lam (Abs _ (_, cond))) ~(Lam (Abs _ (_, body))) -> do - cond' <- liftM snd $ scopedBlock $ do - ans <- translateBlock env (Nothing, cond) + While ~(Lam (Abs _ (_, body))) -> do + body' <- liftM snd $ scopedBlock $ do + ans <- translateBlock env (Nothing, body) return ((), [fromScalarAtom ans]) - body' <- scopedErrBlock $ void $ translateBlock env (Nothing, body) - emitStatement $ IWhile cond' body' + emitStatement $ IWhile body' return UnitVal RunReader r ~(BinaryFunVal _ ref _ body) -> do rDest <- alloc $ getType r @@ -507,8 +517,11 @@ toImpHof env (maybeDest, hof) = do copyAtom sDest =<< impSubst env s void $ translateBlock (env <> ref @> sDest) (Just aDest, body) PairVal <$> destToAtom aDest <*> destToAtom sDest + RunIO ~(Lam (Abs _ (_, body))) -> + translateBlock env (maybeDest, body) Linearize _ -> error "Unexpected Linearize" Transpose _ -> error "Unexpected Transpose" + CatchException _ -> error "Unexpected CatchException" data LaunchInfo = LaunchInfo { numWorkgroups :: IExpr, workgroupSize :: IExpr } data ThreadInfo = ThreadInfo { tid :: IExpr, wid :: IExpr, threadRange :: Type } @@ -539,7 +552,7 @@ buildKernel idxTy f = do LLVMCUDA -> (CUDAKernelLaunch, GPU) LLVMMC -> (MCThreadLaunch , CPU) backend -> error $ "Shouldn't be launching kernels from " ++ show backend - ((kernelBody, aux), env) <- scoped $ f LaunchInfo{..} $ \mkBody -> + ((kernelBody, aux), env) <- scoped $ f LaunchInfo{..} \mkBody -> withDevice dev $ withLevel ThreadLevel $ scopedErrBlock $ do gtid <- iaddI tid =<< imulI wid wsz let threadRange = TC $ ParIndexRange idxTy (toScalarAtom gtid) (toScalarAtom nthr) @@ -572,7 +585,7 @@ type DestM = ReaderT DestEnv (CatT (Env (Type, Block)) Embed) makeDest :: AllocInfo -> Type -> Embed ([(Binder, Atom)], Dest) makeDest allocInfo ty = do (dest, ptrs) <- flip runCatT mempty $ flip runReaderT env $ makeDestRec ty - ptrs' <- forM (envPairs ptrs) $ \(v, (ptrTy, numel)) -> do + ptrs' <- forM (envPairs ptrs) \(v, (ptrTy, numel)) -> do numel' <- emitBlock numel return (Bind (v:>ptrTy), numel') return (ptrs', dest) @@ -589,7 +602,7 @@ makeDestRec ty = case ty of makeDestRec ty makeBoxes (envPairs ptrs) dest else do - lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow $ \(Var i) -> do + lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow \(Var i) -> do bodyTy' <- substEmbed (b@>Var i) bodyTy withEnclosingIdxs (Bind i) $ makeDestRec bodyTy' return $ Con $ TabRef lam @@ -605,7 +618,7 @@ makeDestRec ty = case ty of "Dependent data constructors only allowed for single-constructor types" tag <- rec TagRepTy let dcs' = applyDataDefParams def params - contents <- forM dcs' $ \(DataConDef _ bs) -> forM (toList bs) (rec . binderType) + contents <- forM dcs' \(DataConDef _ bs) -> forM (toList bs) (rec . binderType) return $ Con $ ConRef $ SumAsProd ty tag contents RecordTy (NoExt types) -> (Con . RecordRef) <$> forM types rec VariantTy (NoExt types) -> do @@ -641,7 +654,7 @@ makeBaseTypePtr ty = do -- where they could cast shadows let addrSpace = chooseAddrSpace allocInfo numel let ptrName = genFresh (Name AllocPtrName "ptr" 0) (scope <> ptrScope) - let ptrTy = PtrTy (AllocatedPtr, addrSpace, ty) + let ptrTy = PtrTy (addrSpace, ty) extend (ptrName @> (ptrTy, numel)) let ptr = Var (ptrName :> ptrTy) applyIdxs ptr idxs @@ -680,6 +693,11 @@ copyAtom (Con dest) src = case (dest, src) of (ConRef (SumAsProd _ tag payload), DataCon _ _ con x) -> do copyAtom tag (TagRepVal $ fromIntegral con) zipWithM_ copyAtom (payload !! con) x + (ConRef (SumAsProd _ tagDest payloadDest), Con (SumAsProd _ tag payload)) -> do + copyAtom tagDest tag + unless (all null payload) $ -- optimization + emitSwitch (fromScalarAtom tag) $ + zipWith (zipWithM_ copyAtom) payloadDest payload (ConRef destCon, Con srcCon) -> zipWithRefConM copyAtom destCon srcCon (RecordRef refs, Record vals) | fmap (const ()) refs == fmap (const ()) vals -> do @@ -703,15 +721,15 @@ copyDataConArgs bindings args = loadDest :: MonadEmbed m => Dest -> m Atom loadDest (BoxedRef b ptrPtr _ body) = do - ptr <- ptrLoad ptrPtr + ptr <- unsafePtrLoad ptrPtr body' <- substEmbed (b@>ptr) body loadDest body' loadDest (DataConRef def params bs) = do DataCon def params 0 <$> loadDataConArgs bs loadDest (Con dest) = do case dest of - BaseTypeRef ptr -> ptrLoad ptr - TabRef (TabVal b body) -> buildLam b TabArrow $ \i -> do + BaseTypeRef ptr -> unsafePtrLoad ptr + TabRef (TabVal b body) -> buildLam b TabArrow \i -> do body' <- substEmbed (b@>i) body result <- emitBlock body' loadDest result @@ -735,7 +753,7 @@ loadDataConArgs (Nest (DataConRefBinding b ref) rest) = do indexDestDim :: MonadEmbed m => Int->Dest -> Atom -> m Dest indexDestDim 0 dest i = indexDest dest i -indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do +indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j indexDestDim (d-1) dest' i where @@ -748,7 +766,7 @@ indexDest dest _ = error $ pprint dest sliceDestDim :: MonadEmbed m => Int -> Dest -> Atom -> Type -> m Dest sliceDestDim 0 dest i sliceIdxTy = sliceDest dest i sliceIdxTy -sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do +sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j sliceDestDim (d-1) dest' i sliceIdxTy where @@ -757,7 +775,7 @@ sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do sliceDest :: MonadEmbed m => Dest -> Atom -> Type -> m Dest sliceDest ~(Con (TabRef tab@(TabVal b _))) i sliceIdxTy = (Con . TabRef) <$> do - buildFor Fwd (Bind ("j" :> sliceIdxTy)) $ \j -> do + buildFor Fwd (Bind ("j" :> sliceIdxTy)) \j -> do j' <- indexToIntE j ioff <- iadd j' i vidx <- intToIndexE (binderType b) ioff @@ -781,26 +799,15 @@ makeAllocDestWithPtrs allocTy ty = do backend <- asks impBackend curDev <- asks curDevice (ptrsSizes, dest) <- fromEmbed $ makeDest (backend, curDev, allocTy) ty - (env, ptrs) <- flip foldMapM ptrsSizes $ \(Bind (ptr:>PtrTy ptrTy), size) -> do + (env, ptrs) <- flip foldMapM ptrsSizes \(Bind (ptr:>PtrTy ptrTy), size) -> do ptr' <- emitAlloc ptrTy $ fromScalarAtom size case ptrTy of - (_, Heap _, _) | allocTy == Managed -> extendAlloc ptr' + (Heap _, _) | allocTy == Managed -> extendAlloc ptr' _ -> return () return (ptr @> toScalarAtom ptr', [ptr']) dest' <- impSubst env dest return (dest', ptrs) --- TODO: deallocation! -makeAllocDestForPtr :: Type -> ImpM (Dest, Atom) -makeAllocDestForPtr ty = do - (ptrSizes, dest) <- fromEmbed $ makeDest (LLVM, CPU, Unmanaged) ty - case ptrSizes of - [(Bind (ptr:>PtrTy ptrTy), size)] -> do - ptr' <- emitAlloc ptrTy $ fromScalarAtom size - dest' <- impSubst (ptr @> toScalarAtom ptr') dest - return (dest', toScalarAtom ptr') - _ -> error $ "expected a single pointer" - splitDest :: WithDest Block -> ([WithDest Decl], WithDest Expr, [(Dest, Atom)]) splitDest (maybeDest, (Block decls ans)) = do case (maybeDest, ans) of @@ -813,7 +820,7 @@ splitDest (maybeDest, (Block decls ans)) = do let closureCopies = fmap (\(n, (d, t)) -> (d, Var $ n :> t)) (envPairs $ varDests `envDiff` foldMap letBoundVars decls) - let destDecls = flip fmap (toList decls) $ \d -> case d of + let destDecls = flip fmap (toList decls) \d -> case d of Let _ b _ -> (fst <$> varDests `envLookup` b, d) (destDecls, (Nothing, ans), gatherCopies ++ closureCopies) _ -> (fmap (Nothing,) $ toList decls, (maybeDest, ans), []) @@ -831,12 +838,13 @@ splitDest (maybeDest, (Block decls ans)) = do (_, Con (Lit _)) -> tell [(dest, result)] -- This is conservative, in case the type is dependent. We could do better. (DataConRef _ _ _, DataCon _ _ _ _) -> tell [(dest, result)] + -- This is conservative. Without it, we hit bugs like #348 + (Con (ConRef (SumAsProd _ _ _)), _) -> tell [(dest, result)] (Con (ConRef destCon), Con srcCon) -> zipWithRefConM gatherVarDests destCon srcCon (Con (RecordRef items), Record items') | fmap (const ()) items == fmap (const ()) items' -> do zipWithM_ gatherVarDests (toList items) (toList items') - (Con (ConRef (SumAsProd _ _ _)), _) -> tell [(dest, result)] -- TODO (_, ProjectElt _ _) -> tell [(dest, result)] -- TODO: is this reasonable? _ -> unreachable where @@ -865,10 +873,10 @@ chooseAddrSpace (backend, curDev, allocTy) numel = case allocTy of else Heap mainDev | otherwise -> Heap mainDev where mainDev = case backend of - LLVM -> CPU - LLVMMC -> CPU - LLVMCUDA -> GPU - Interp -> error "Shouldn't be compiling with interpreter backend" + LLVM -> CPU + LLVMMC -> CPU + LLVMCUDA -> GPU + Interpreter -> error "Shouldn't be compiling with interpreter backend" isSmall :: Block -> Bool isSmall numel = case numel of @@ -883,7 +891,7 @@ isSmall numel = case numel of allocateBuffer :: AddressSpace -> Bool -> BaseType -> IExpr -> ImpM IExpr allocateBuffer addrSpace mustFree b numel = do - buffer <- emitAlloc (AllocatedPtr, addrSpace, b) numel + buffer <- emitAlloc (addrSpace, b) numel when mustFree $ extendAlloc buffer return buffer @@ -910,7 +918,7 @@ toScalarAtom ie = case ie of ILit l -> Con $ Lit l IVar (v:>b) -> Var (v:>BaseTy b) -fromScalarType :: Type -> IType +fromScalarType :: HasCallStack => Type -> IType fromScalarType (BaseTy b) = b fromScalarType ty = error $ "Not a scalar type: " ++ pprint ty @@ -941,7 +949,7 @@ zipTabDestAtom f ~dest@(Con (TabRef (TabVal b _))) ~src@(TabVal b' _) = do error $ "Mismatched dimensions: " <> pprint b <> " != " <> pprint b' let idxTy = binderType b n <- indexSetSize idxTy - emitLoop "i" Fwd n $ \i -> do + emitLoop "i" Fwd n \i -> do idx <- intToIndex idxTy i destIndexed <- destGet dest idx srcIndexed <- translateExpr mempty (Nothing, App src idx) @@ -951,8 +959,6 @@ zipWithRefConM :: Monad m => (Dest -> Atom -> m ()) -> Con -> Con -> m () zipWithRefConM f destCon srcCon = case (destCon, srcCon) of (PairCon d1 d2, PairCon s1 s2) -> f d1 s1 >> f d2 s2 (UnitCon, UnitCon) -> return () - (SumAsProd _ tagRef xssRef, SumAsProd _ tag xss) -> - f tagRef tag >> zipWithM_ (zipWithM f) xssRef xss (IntRangeVal _ _ iRef, IntRangeVal _ _ i) -> f iRef i (IndexRangeVal _ _ _ iRef, IndexRangeVal _ _ _ i) -> f iRef i _ -> error $ "Unexpected ref/val " ++ pprint (destCon, srcCon) @@ -971,6 +977,10 @@ addToAtom dest src = case (dest, src) of updated <- emitInstr $ IPrimOp $ op FAdd cur x' storeAnywhere ptr' updated (Con (TabRef _), TabVal _ _) -> zipTabDestAtom addToAtom dest src + (Con (ConRef (SumAsProd _ _ payloadDest)), Con (SumAsProd _ tag payload)) -> do + unless (all null payload) $ -- optimization + emitSwitch (fromScalarAtom tag) $ + zipWith (zipWithM_ addToAtom) payloadDest payload (Con (ConRef destCon), Con srcCon) -> zipWithRefConM addToAtom destCon srcCon (Con (RecordRef dests), Record srcs) -> zipWithM_ addToAtom (toList dests) (toList srcs) @@ -979,7 +989,7 @@ addToAtom dest src = case (dest, src) of loadAnywhere :: IExpr -> ImpM IExpr loadAnywhere ptr = do curDev <- asks curDevice - let (PtrType (_, addrSpace, ty)) = getIType ptr + let (PtrType (addrSpace, ty)) = getIType ptr case addrSpace of Heap ptrDev | ptrDev /= curDev -> do localPtr <- allocateStackSingleton ty @@ -990,7 +1000,7 @@ loadAnywhere ptr = do storeAnywhere :: IExpr -> IExpr -> ImpM () storeAnywhere ptr val = do curDev <- asks curDevice - let (PtrType (_, addrSpace, ty)) = getIType ptr + let (PtrType (addrSpace, ty)) = getIType ptr case addrSpace of Heap ptrDev | ptrDev /= curDev -> do localPtr <- allocateStackSingleton ty @@ -1049,6 +1059,9 @@ alloc ty = makeAllocDest Managed ty handleErrors :: ImpM () -> ImpM () handleErrors m = m `catchError` (const $ emitStatement IThrowError) +emitWhen :: IExpr -> ImpM () -> ImpM () +emitWhen cond doIfTrue = emitSwitch cond [return (), doIfTrue] + -- TODO: Consider targeting LLVM's `switch` instead of chained conditionals. emitSwitch :: IExpr -> [ImpM ()] -> ImpM () emitSwitch testIdx = rec 0 @@ -1093,7 +1106,7 @@ restructureScalarOrPairTypeRec ty _ = error $ "Not a scalar or pair: " ++ pprint emitMultiReturnInstr :: ImpInstr -> ImpM [IExpr] emitMultiReturnInstr instr = do - vs <- forM (impInstrTypes instr) $ \ty -> freshVar ("v":>ty) + vs <- forM (impInstrTypes instr) \ty -> freshVar ("v":>ty) emitImpDecl $ ImpLet (map Bind vs) instr return (map IVar vs) @@ -1117,7 +1130,7 @@ extendAlloc :: IExpr -> ImpM () extendAlloc v = extend $ mempty { envPtrsToFree = [v] } emitAlloc :: HasCallStack => PtrType -> IExpr -> ImpM IExpr -emitAlloc (_, addr, ty) n = emitInstr $ Alloc addr ty n +emitAlloc (addr, ty) n = emitInstr $ Alloc addr ty n scopedErrBlock :: ImpM () -> ImpM ImpBlock scopedErrBlock body = liftM snd $ scopedBlock $ handleErrors body $> ((),[]) @@ -1167,6 +1180,7 @@ instance Checkable ImpFunction where checkValid f@(ImpFunction (_:> IFunType cc _ _) bs block) = addContext ctx $ do let scope = foldMap (binderAsEnv . fmap (const ())) bs let env = foldMap (binderAsEnv ) bs + <> fmap (fromScalarType . fst) initTopEnv void $ flip runReaderT (env, deviceFromCallingConvention cc) $ flip runStateT scope $ checkBlock block where ctx = "Checking:\n" ++ pprint f @@ -1189,15 +1203,14 @@ checkDecl decl@(ImpLet bs instr) = addContext ctx $ do instrTypeChecked :: ImpInstr -> ImpCheckM [IType] instrTypeChecked instr = case instr of IFor _ i size block -> do - checkInt size + checkIdxRep size checkBinder i assertEq (binderAnn i) (getIType size) $ "Mismatch between the loop iterator and upper bound type" [] <- withTypeEnv (i @> getIType size) $ checkBlock block return [] - IWhile cond body -> do - [condTy] <- checkBlock cond - assertEq (Scalar Word8Type) condTy $ "Not a bool: " ++ pprint cond - [] <- checkBlock body + IWhile body -> do + [condTy] <- checkBlock body + assertEq (Scalar Word8Type) condTy $ "Not a bool: " ++ pprint body return [] ICond predicate consequent alternative -> do predTy <- checkIExpr predicate @@ -1214,23 +1227,26 @@ instrTypeChecked instr = case instr of mapM_ checkIExpr args IPrimOp op -> (:[]) <$> checkImpOp op ICastOp dt x -> (:[]) <$> do - case getIType x of - Scalar _ -> return () - _ -> throw CompilerErr $ "Invalid cast source type: " ++ pprint dt - case dt of - Scalar _ -> return () - _ -> throw CompilerErr $ "Invalid cast destination type: " ++ pprint dt + let st = getIType x + case (dt, st) of + (PtrType _, PtrType _) -> return () + (Scalar _, Scalar _) -> return () + (Scalar Int64Type, PtrType _) -> return () + (PtrType _, Scalar Int64Type) -> return () + _ -> throw CompilerErr $ + "Can't cast " ++ pprint st ++ " to " ++ pprint dt return dt - Alloc a ty _ -> (:[]) <$> do + Alloc a ty n -> (:[]) <$> do + checkIdxRep n when (a /= Stack) assertHost - return $ PtrType (AllocatedPtr, a, ty) + return $ PtrType (a, ty) MemCopy dest src numel -> [] <$ do - PtrType (_, _, destTy) <- checkIExpr dest - PtrType (_, _, srcTy) <- checkIExpr src + PtrType (_, destTy) <- checkIExpr dest + PtrType (_, srcTy) <- checkIExpr src assertEq destTy srcTy "pointer type mismatch" - checkInt numel + checkIdxRep numel Store dest val -> [] <$ do - PtrType (_, addr, ty) <- checkIExpr dest + PtrType (addr, ty) <- checkIExpr dest checkAddrAccessible addr valTy <- checkIExpr val assertEq ty valTy "Type mismatch in store" @@ -1267,6 +1283,11 @@ checkIExpr expr = case expr of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just x -> return x +checkIdxRep :: IExpr -> ImpCheckM () +checkIdxRep expr = do + t <- checkIExpr expr + assertEq IIdxRepTy t $ "Not an index rep tye: " ++ pprint t + checkInt :: IExpr -> ImpCheckM () checkInt expr = do bt <- checkIExpr expr @@ -1291,12 +1312,12 @@ checkImpOp op = do checkIntBaseType False $ BaseTy ibt return $ Scalar ty PtrLoad ref -> do - PtrType (_, addr, ty) <- return ref + PtrType (addr, ty) <- return ref checkAddrAccessible addr return ty PtrOffset ref _ -> do -- TODO: check offset too - PtrType (_, addr, ty) <- return ref - return $ PtrType (DerivedPtr, addr, ty) + PtrType (addr, ty) <- return ref + return $ PtrType (addr, ty) _ -> error $ "Not allowed in Imp IR: " ++ pprint op where checkEq :: (Pretty a, Show a, Eq a) => a -> a -> ImpCheckM () @@ -1324,13 +1345,13 @@ impInstrTypes :: ImpInstr -> [IType] impInstrTypes instr = case instr of IPrimOp op -> [impOpType op] ICastOp t _ -> [t] - Alloc a ty _ -> [PtrType (AllocatedPtr, a, ty)] + Alloc a ty _ -> [PtrType (a, ty)] Store _ _ -> [] Free _ -> [] IThrowError -> [] MemCopy _ _ _ -> [] IFor _ _ _ _ -> [] - IWhile _ _ -> [] + IWhile _ -> [] ICond _ _ _ -> [] ILaunch _ _ _ -> [] ISyncWorkgroup -> [] @@ -1360,8 +1381,8 @@ impOpType pop = case pop of Select _ x _ -> getIType x VectorPack xs -> Vector ty where Scalar ty = getIType $ head xs VectorIndex x _ -> Scalar ty where Vector ty = getIType x - PtrLoad ref -> ty where PtrType (_, _, ty) = getIType ref - PtrOffset ref _ -> PtrType (DerivedPtr, addr, ty) where PtrType (_, addr, ty) = getIType ref + PtrLoad ref -> ty where PtrType (_, ty) = getIType ref + PtrOffset ref _ -> PtrType (addr, ty) where PtrType (addr, ty) = getIType ref _ -> unreachable where unreachable = error $ "Not allowed in Imp IR: " ++ pprint pop diff --git a/src/lib/Imp/Embed.hs b/src/lib/Imp/Embed.hs index ef437e01a..315a00226 100644 --- a/src/lib/Imp/Embed.hs +++ b/src/lib/Imp/Embed.hs @@ -147,8 +147,7 @@ traverseImpInstr def instr = case instr of b' <- freshIVar b IFor dir (Bind b') <$> traverseIExpr size <*> (extendValSubst (b @> IVar b') $ traverseImpBlock def body) - IWhile cond body -> - IWhile <$> traverseImpBlock def cond <*> traverseImpBlock def body + IWhile body -> IWhile <$> traverseImpBlock def body ICond cond tb fb -> ICond <$> traverseIExpr cond <*> traverseImpBlock def tb <*> traverseImpBlock def fb diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index bd398e8c3..ed8f3c1d2 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -19,6 +19,7 @@ import Data.Functor import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import Data.String (fromString) +import qualified Data.Set as S import Data.Text.Prettyprint.Doc import Syntax @@ -49,7 +50,7 @@ inferModule :: TopEnv -> UModule -> Except Module inferModule scope (UModule decls) = do ((), (bindings, decls')) <- runUInferM mempty scope $ mapM_ (inferUDecl True) decls - let bindings' = envFilter bindings $ \(_, b) -> case b of + let bindings' = envFilter bindings \(_, b) -> case b of DataBoundTypeCon _ -> True DataBoundDataCon _ _ -> True _ -> False @@ -67,7 +68,7 @@ checkSigma expr reqCon sTy = case sTy of WithSrc _ (ULam b arrow' body) | arrow' == void arrow -> checkULam b body piTy _ -> do - buildLam (Bind ("a":> absArgType piTy)) arrow $ \x@(Var v) -> + buildLam (Bind ("a":> absArgType piTy)) arrow \x@(Var v) -> checkLeaks [v] $ checkSigma expr reqCon $ snd $ applyAbs piTy x _ -> checkOrInferRho expr (reqCon sTy) @@ -151,20 +152,20 @@ checkOrInferRho (WithSrc pos expr) reqTy = do addEffects $ arrowEff arr' appVal <- emitZonked $ App fVal xVal' instantiateSigma appVal >>= matchRequirement - UPi (pat, kind) arr ty -> do + UPi (pat, ann) arr ty -> do -- TODO: make sure there's no effect if it's an implicit or table arrow -- TODO: check leaks - kind' <- checkUType kind + ann' <- checkAnn ann piTy <- case pat of - Just pat' -> withNameHint ("pat" :: Name) $ buildPi b $ \x -> - withBindPat pat' x $ (,) <$> mapM checkUEff arr <*> checkUType ty - where b = case pat' of + UnderscoreUPat -> buildPi (Ignore ann') $ const $ + (,) <$> mapM checkUEffRow arr <*> checkUType ty + _ -> withNameHint ("pat" :: Name) $ buildPi b \x -> + withBindPat pat x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty + where b = case pat of -- Note: The binder name becomes part of the type, so we -- need to keep the same name used in the pattern. - WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>kind') - _ -> Ignore kind' - Nothing -> buildPi (Ignore kind') $ const $ - (,) <$> mapM checkUEff arr <*> checkUType ty + WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>ann') + _ -> Ignore ann' matchRequirement piTy UDecl decl body -> do env <- inferUDecl False decl @@ -181,7 +182,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do case scrutTy' of TypeCon def params -> do let conDefs = applyDataDefParams def params - altsSorted <- forM (enumerate conDefs) $ \(i, DataConDef _ bs) -> do + altsSorted <- forM (enumerate conDefs) \(i, DataConDef _ bs) -> do case lookup (ConAlt i) alts' of Nothing -> return $ Abs (fmap (Ignore . binderType) bs) $ Block Empty $ Op $ ThrowError reqTy' @@ -255,7 +256,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do val' <- checkSigma val reqCon ty' matchRequirement val' UPrimExpr prim -> do - prim' <- forM prim $ \e -> do + prim' <- forM prim \e -> do e' <- inferRho e scope <- getScope return $ typeReduceAtom scope e' @@ -308,7 +309,7 @@ lookupSourceVar v = do Nothing -> do scope <- getScope let v' = asGlobal $ varName v - case envLookup scope (v':>()) of + case envLookup scope v' of Just (_, DataBoundTypeCon def ) -> return $ TypeCon def [] Just (_, DataBoundDataCon def con) -> return $ DataCon def [] con [] Just (ty, _) -> return $ Var $ v':>ty @@ -318,10 +319,9 @@ unpackTopPat :: LetAnn -> UPat -> Expr -> UInferM () unpackTopPat letAnn pat expr = do atom <- emit expr bindings <- bindPat pat atom - void $ flip traverseNames bindings $ \name val -> do + void $ flip traverseNames bindings \name val -> do let name' = asGlobal name - scope <- getScope - when (name' `isin` scope) $ throw RepeatedVarErr $ pprint $ name' + checkNotInScope name' emitTo name' letAnn $ Atom val inferUDecl :: Bool -> UDecl -> UInferM SubstEnv @@ -342,29 +342,115 @@ inferUDecl topLevel (ULet letAnn (p, ann) rhs) = do else bindPat p val inferUDecl True (UData tc dcs) = do (tc', paramBs) <- inferUConDef tc - scope <- getScope - when (tc' `isin` scope) $ throw RepeatedVarErr $ pprint $ getName tc' - let paramVars = map (\(Bind v) -> v) $ toList paramBs -- TODO: refresh things properly - (dcs', _) <- embedScoped $ - extendR (newEnv paramBs (map Var paramVars)) $ do - extendScope (foldMap boundVars paramBs) - mapM inferUConDef dcs - let dataDef = DataDef tc' paramBs $ map (uncurry DataConDef) dcs' - let tyConTy = getType $ TypeCon dataDef [] - extendScope $ tc' @> (tyConTy, DataBoundTypeCon dataDef) - forM_ (zip [0..] dcs') $ \(i, (dc,_)) -> do - -- Retrieving scope at every step to avoid duplicate constructor names - scope' <- getScope - when (dc `isin` scope') $ throw RepeatedVarErr $ pprint $ getName dc - let ty = getType $ DataCon dataDef [] i [] - extendScope $ dc @> (ty, DataBoundDataCon dataDef i) + dataDef <- buildDataDef tc' paramBs \params -> do + extendR (newEnv paramBs params) $ forM dcs \dc -> + uncurry DataConDef <$> inferUConDef dc + checkDataDefShadows dataDef + emitConstructors dataDef + return mempty +inferUDecl True (UInterface superclasses tc methods) = do + (tc', paramBs) <- inferUConDef tc + dataDef <- buildDataDef tc' paramBs \params -> do + extendR (newEnv paramBs params) $ do + conName <- freshClassGenName + superclasses' <- mkLabeledItems <$> mapM mkSuperclass superclasses + methods' <- mkLabeledItems <$> mapM mkMethod methods + return $ ClassDictDef conName superclasses' methods' + checkDataDefShadows dataDef + emitConstructors dataDef + emitSuperclassGetters dataDef + emitMethodGetters dataDef return mempty -inferUDecl False (UData _ _) = error "data definitions should be top-level" +inferUDecl True (UInstance argBinders instanceTy methods) = do + instanceDict <- checkInstance argBinders instanceTy methods + let instanceName = Name TypeClassGenName "instance" 0 + void $ emitTo instanceName InstanceLet $ Atom instanceDict + return mempty +inferUDecl False (UData _ _ ) = error "data definitions should be top-level" +inferUDecl False (UInterface _ _ _) = error "interface definitions should be top-level" +inferUDecl False (UInstance _ _ _) = error "instance definitions should be top-level" + +freshClassGenName :: MonadEmbed m => m Name +freshClassGenName = do + scope <- getScope + let v' = genFresh (Name TypeClassGenName "classgen" 0) scope + embedExtend $ asFst $ v' @> (UnitTy, UnknownBinder) + return v' + +mkMethod :: UAnnBinder -> UInferM (Label, Type) +mkMethod (Ignore _) = error "Methods must have names" +mkMethod (Bind (v:>ty)) = do + ty' <- checkUType ty + return (nameToLabel v, ty') + +mkSuperclass :: UType -> UInferM (Label, Type) +mkSuperclass ty = do + ty' <- checkUType ty + -- TODO: think about the scope of these names + l <- freshClassGenName + return (nameToLabel l, ty') + +-- TODO: just make Name and Label the same thing +nameToLabel :: Name -> Label +nameToLabel = pprint + +mkLabeledItems :: [(Label, a)] -> LabeledItems a +mkLabeledItems items = foldMap (uncurry labeledSingleton) items + +emitConstructors :: DataDef -> UInferM () +emitConstructors def@(DataDef tyConName _ dataConDefs) = do + let tyConTy = getType $ TypeCon def [] + checkNotInScope tyConName + extendScope $ tyConName @> (tyConTy, DataBoundTypeCon def) + forM_ (zip [0..] dataConDefs) \(i, DataConDef dataConName _) -> do + let dataConTy = getType $ DataCon def [] i [] + checkNotInScope dataConName + extendScope $ dataConName @> (dataConTy, DataBoundDataCon def i) + +emitMethodGetters :: DataDef -> UInferM () +emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do + forM_ (getLabels methodTys) \l -> do + f <- buildImplicitNaryLam paramBs \params -> do + buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do + return $ recGetHead l $ getProjection [1] dict + let methodName = GlobalName $ fromString l + checkNotInScope methodName + emitTo methodName PlainLet $ Atom f +emitMethodGetters (DataDef _ _ _) = error "Not a class dictionary" + +emitSuperclassGetters :: MonadEmbed m => DataDef -> m () +emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = do + forM_ (getLabels superclassTys) \l -> do + f <- buildImplicitNaryLam paramBs \params -> do + buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do + return $ recGetHead l $ getProjection [0] dict + getterName <- freshClassGenName + emitTo getterName SuperclassLet $ Atom f +emitSuperclassGetters (DataDef _ _ _) = error "Not a class dictionary" + +checkNotInScope :: Name -> UInferM () +checkNotInScope v = do + scope <- getScope + when (v `isin` scope) $ throw RepeatedVarErr $ pprint v + +checkDataDefShadows :: DataDef -> UInferM () +checkDataDefShadows (DataDef tc _ dataCons) = do + checkShadows $ tc:dcs + where dcs = [dc | DataConDef dc _ <- dataCons] + +checkShadows :: [Name] -> UInferM () +checkShadows vs = do + mapM_ checkNotInScope vs + case repeated vs of + [] -> return () + (v:_) -> throw RepeatedVarErr $ pprint v inferUConDef :: UConDef -> UInferM (Name, Nest Binder) inferUConDef (UConDef v bs) = do (bs', _) <- embedScoped $ checkNestedBinders bs - return (asGlobal v, bs') + let v' = asGlobal v + checkNotInScope v' + return (v', bs') checkNestedBinders :: Nest UAnnBinder -> UInferM (Nest Binder) checkNestedBinders Empty = return Empty @@ -381,7 +467,7 @@ inferULam (p, ann) arr body = do argTy <- checkAnn ann -- TODO: worry about binder appearing in arrow? buildLam (Bind $ patNameHint p :> argTy) arr - $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ inferSigma body + \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ inferSigma body checkULam :: UPatAnn -> UExpr -> PiType -> UInferM Atom checkULam (p, ann) body piTy = do @@ -389,13 +475,54 @@ checkULam (p, ann) body piTy = do checkAnn ann >>= constrainEq argTy buildDepEffLam (Bind $ patNameHint p :> argTy) ( \x -> return $ fst $ applyAbs piTy x) - $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ + \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x -checkUEff :: EffectRow -> UInferM EffectRow -checkUEff (EffectRow effs t) = do - effs' <- forM effs $ \(effName, region) -> (effName,) <$> lookupVarName TyKind region - t' <- forM t $ \tv -> lookupVarName EffKind tv +checkInstance :: Nest UPatAnnArrow -> UType -> [UMethodDef] -> UInferM Atom +checkInstance Empty ty methods = do + ty' <- checkUType ty + case ty' of + TypeCon def@(DataDef className _ _) params -> + case applyDataDefParams def params of + ClassDictDef _ superclassTys methodTys -> do + let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys + methods' <- checkMethodDefs className methodTys methods + return $ ClassDictCon def params superclassHoles methods' + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty +checkInstance (Nest ((p, ann), arrow) rest) ty methods = do + case arrow of + ImplicitArrow -> return () + ClassArrow -> return () + _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow + argTy <- checkAnn ann + buildLam (Bind $ patNameHint p :> argTy) (fromUArrow arrow) \x@(Var v) -> + checkLeaks [v] $ withBindPat p x $ checkInstance rest ty methods + + +checkMethodDefs :: Name -> LabeledItems Type -> [UMethodDef] + -> UInferM (LabeledItems Atom) +checkMethodDefs className methodTys methods = do + methods' <- liftM mkLabeledItems $ forM methods \(UMethodDef (v:>()) rhs) -> do + let v' = nameToLabel v + case lookupLabelHead methodTys v' of + Nothing -> throw TypeErr $ + pprint v ++ " is not a method of " ++ pprint className + Just methodTy -> do + rhs' <- checkSigma rhs Suggest methodTy + return (v', rhs') + forM_ (reflectLabels methods') \(l,i) -> + when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l + forM_ (reflectLabels methodTys) \(l,_) -> + case lookupLabelHead methods' l of + Nothing -> throw TypeErr $ "Missing method: " ++ pprint l + Just _ -> return () + return methods' + +checkUEffRow :: EffectRow -> UInferM EffectRow +checkUEffRow (EffectRow effs t) = do + effs' <- liftM S.fromList $ mapM checkUEff $ toList effs + t' <- forM t \tv -> lookupVarName EffKind tv return $ EffectRow effs' t' where lookupVarName :: Type -> Name -> UInferM Name @@ -405,6 +532,15 @@ checkUEff (EffectRow effs t) = do constrainEq ty ty' return v' +checkUEff :: Effect -> UInferM Effect +checkUEff eff = case eff of + RWSEffect rws region -> do + (Var (v:>ty)) <- lookupSourceVar (region:>()) + constrainEq TyKind ty + return $ RWSEffect rws v + ExceptionEffect -> return ExceptionEffect + IOEffect -> return IOEffect + data CaseAltIndex = ConAlt Int | VariantAlt Label Int | VariantTailAlt (LabeledItems ()) @@ -415,7 +551,7 @@ checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do (conIdx, patTys) <- checkCasePat pat scrutineeTy let (subPats, subPatTys) = unzip patTys let bs = zipWith (\p ty -> Bind $ patNameHint p :> ty) subPats subPatTys - alt <- buildNAbs (toNest bs) $ \xs -> + alt <- buildNAbs (toNest bs) \xs -> withBindPats (zip subPats xs) $ checkRho body reqTy return (conIdx, alt) @@ -533,7 +669,7 @@ bindPat' (WithSrc pos pat) val = addSrcContext pos $ case pat of throw TypeErr $ "Incorrect length of table pattern: table index set has " <> pprint (length idxs) <> " elements but there are " <> pprint (length ps) <> " patterns." - flip foldMapM (zip ps idxs) $ \(p, i) -> do + flip foldMapM (zip ps idxs) \(p, i) -> do v <- lift $ emitZonked $ App val i bindPat' p v @@ -606,13 +742,16 @@ inferTabCon xs reqTy = do return (tabTy, xs') emitZonked $ Op $ TabCon tabTy xs' +fromUArrow :: UArrow -> Arrow +fromUArrow arr = fmap (const Pure) arr + -- Bool flag is just to tweak the reported error message fromPiType :: Bool -> UArrow -> Type -> UInferM PiType fromPiType _ _ (Pi piTy) = return piTy -- TODO: check arrow fromPiType expectPi arr ty = do a <- freshType TyKind b <- freshType TyKind - let piTy = Abs (Ignore a) (fmap (const Pure) arr, b) + let piTy = Abs (Ignore a) (fromUArrow arr, b) if expectPi then constrainEq (Pi piTy) ty else constrainEq ty (Pi piTy) return piTy @@ -630,9 +769,20 @@ emitZonked expr = zonk expr >>= emit addEffects :: EffectRow -> UInferM () addEffects eff = do - eff' <- openEffectRow eff - allowedEffects <- getAllowedEffects - constrainEq (Eff allowedEffects) (Eff eff') + allowed <- checkAllowedUnconditionally eff + unless allowed $ do + allowedEffects <- getAllowedEffects + eff' <- openEffectRow eff + constrainEq (Eff allowedEffects) (Eff eff') + +checkAllowedUnconditionally :: EffectRow -> UInferM Bool +checkAllowedUnconditionally Pure = return True +checkAllowedUnconditionally eff = do + eff' <- zonk eff + effAllowed <- getAllowedEffects >>= zonk + return $ case checkExtends effAllowed eff' of + Left _ -> False + Right () -> True openEffectRow :: EffectRow -> UInferM EffectRow openEffectRow (EffectRow effs Nothing) = extendEffRow effs <$> freshEff @@ -747,7 +897,7 @@ runSolverT m = liftM fst $ flip runCatT mempty $ do applyDefaults :: MonadCat SolverEnv m => m () applyDefaults = do vs <- looks unsolved - forM_ (envPairs vs) $ \(v, k) -> case k of + forM_ (envPairs vs) \(v, k) -> case k of EffKind -> addSub v $ Eff Pure _ -> return () where addSub v ty = extend $ SolverEnv mempty (v@>ty) @@ -771,8 +921,8 @@ checkLeaks tvs m = do unless (null $ resultTypeLeaks) $ throw TypeErr $ "Leaked local variable `" ++ pprint (head resultTypeLeaks) ++ "` in result type " ++ pprint (getType ans) - forM_ (solverSub env) $ \ty -> - forM_ tvs $ \tv -> + forM_ (solverSub env) \ty -> + forM_ tvs \tv -> throwIf (tv `occursIn` ty) TypeErr $ "Leaked type variable: " ++ pprint tv extend env return ans @@ -792,7 +942,7 @@ freshType EffKind = Eff <$> freshEff freshType k = Var . (:>k) <$> freshInferenceName k freshEff :: (MonadError Err m, MonadCat SolverEnv m) => m EffectRow -freshEff = EffectRow [] . Just <$> freshInferenceName EffKind +freshEff = EffectRow mempty . Just <$> freshInferenceName EffKind constrainEq :: (MonadCat SolverEnv m, MonadError Err m) => Type -> Type -> m () @@ -875,19 +1025,16 @@ unifyEff r1 r2 = do vs <- looks solverVars case (r1', r2') of _ | r1' == r2' -> return () - (r, EffectRow [] (Just v)) | v `isin` vs -> bindQ (v:>EffKind) (Eff r) - (EffectRow [] (Just v), r) | v `isin` vs -> bindQ (v:>EffKind) (Eff r) - (EffectRow effs1@(_:_) t1, EffectRow effs2@(_:_) t2) -> do - let extras1 = effs1 `setDiff` effs2 - let extras2 = effs2 `setDiff` effs1 + (r, EffectRow effs (Just v)) | S.null effs && v `isin` vs -> bindQ (v:>EffKind) (Eff r) + (EffectRow effs (Just v), r) | S.null effs && v `isin` vs -> bindQ (v:>EffKind) (Eff r) + (EffectRow effs1 t1, EffectRow effs2 t2) | not (S.null effs1 || S.null effs2) -> do + let extras1 = effs1 `S.difference` effs2 + let extras2 = effs2 `S.difference` effs1 newRow <- freshEff - unifyEff (EffectRow [] t1) (extendEffRow extras2 newRow) - unifyEff (extendEffRow extras1 newRow) (EffectRow [] t2) + unifyEff (EffectRow mempty t1) (extendEffRow extras2 newRow) + unifyEff (extendEffRow extras1 newRow) (EffectRow mempty t2) _ -> throw TypeErr "" -setDiff :: Eq a => [a] -> [a] -> [a] -setDiff xs ys = filter (`notElem` ys) xs - bindQ :: (MonadCat SolverEnv m, MonadError Err m) => Var -> Type -> m () bindQ v t | v `occursIn` t = throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) | hasSkolems t = throw TypeErr "Can't unify with skolem vars" diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index f9f43fa18..3f9911119 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -66,7 +66,9 @@ evalExpr env expr = case expr of evalBlock env $ applyNaryAbs (alts !! i) (xss !! i) _ -> error $ "Not implemented: SumAsProd with tag " ++ pprint expr _ -> error $ "Unexpected scrutinee: " ++ pprint e - _ -> error $ "Not implemented: " ++ pprint expr + Hof hof -> case hof of + RunIO ~(Lam (Abs _ (_, body))) -> evalBlock env body + _ -> error $ "Not implemented: " ++ pprint expr evalOp :: Op -> InterpM Atom evalOp expr = case expr of @@ -94,14 +96,14 @@ evalOp expr = case expr of "randunif" -> Float64Val $ c_unif x where [Int64Val x] = args "threefry2x32" -> Int64Val $ c_threefry x y where [Int64Val x, Int64Val y] = args _ -> error $ "FFI function not recognized: " ++ name - PtrOffset (Con (Lit (PtrLit (_, a, t) p))) (IdxRepVal i) -> - return $ Con $ Lit $ PtrLit (DerivedPtr, a, t) $ p `plusPtr` (sizeOf t * fromIntegral i) - PtrLoad (Con (Lit (PtrLit (_, Heap CPU, t) p))) -> Con . Lit <$> loadLitVal p t - PtrLoad (Con (Lit (PtrLit (_, Heap GPU, t) p))) -> + PtrOffset (Con (Lit (PtrLit (a, t) p))) (IdxRepVal i) -> + return $ Con $ Lit $ PtrLit (a, t) $ p `plusPtr` (sizeOf t * fromIntegral i) + PtrLoad (Con (Lit (PtrLit (Heap CPU, t) p))) -> Con . Lit <$> loadLitVal p t + PtrLoad (Con (Lit (PtrLit (Heap GPU, t) p))) -> allocaBytes (sizeOf t) $ \hostPtr -> do loadCUDAArray hostPtr p (sizeOf t) Con . Lit <$> loadLitVal hostPtr t - PtrLoad (Con (Lit (PtrLit (_, Stack, _) _))) -> + PtrLoad (Con (Lit (PtrLit (Stack, _) _))) -> error $ "Unexpected stack pointer in interpreter" ToOrdinal idxArg -> case idxArg of Con (IntRangeVal _ _ i) -> return i diff --git a/src/lib/JAX.hs b/src/lib/JAX.hs deleted file mode 100644 index 7810b7220..000000000 --- a/src/lib/JAX.hs +++ /dev/null @@ -1,711 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE PatternSynonyms #-} -{-# OPTIONS_GHC -w #-} -- XXX: Disable once fixed -{-# OPTIONS_GHC -Wno-orphans #-} - -module JAX (JAtom (..), JVar, typeToJType, jTypeToType, - JExpr, JaxFunction, toJaxFunction, simplifyJaxFunction, - dceJaxFunction) where - -import Control.Applicative -import Control.Monad.Except hiding (Except) -import Control.Monad.Reader -import Control.Monad.Writer -import Control.Monad.State.Strict -import Data.Aeson hiding (Array) -import Data.Maybe -import Data.Text.Prettyprint.Doc -import GHC.Generics -import GHC.Stack - -import Env -import Syntax -import PPrint -import Type -import Cat -import Array - --- === JAXish IR === - -type AxisSize = Int -type JVar = VarP JType -type IdxVar = VarP AxisSize -data IdxFlavor = MapIdx | SumIdx deriving (Generic, Show, Eq) - -data JDecl = JLet JVar JFor deriving (Generic, Show, Eq) -data JExpr = JExpr [JDecl] [JAtom] deriving (Generic, Show, Eq) -data JAtom = JLit [Int] Array | JVar JVar deriving (Generic, Show, Eq) -data IdxAtom = IdxAtom JAtom [IdxVar] deriving (Generic, Show, Eq) -data JType = JType [AxisSize] ScalarBaseType deriving (Generic, Show, Eq) -data JaxFunction = JaxFunction [JVar] JExpr deriving (Generic, Show, Eq) - -type JOp = JOpP IdxAtom -data JOpP e = JId e - | JIota AxisSize - | JGet e e - | JScalarBinOp BinOp e e - | JThreeFry2x32 e e - | JScalarUnOp UnOp e - deriving (Generic, Functor, Foldable, Traversable, Show, Eq) - -data TmpAtom = TmpLeaf IdxAtom - | TmpRefName Var - | TmpCon (PrimCon TmpAtom) - deriving (Generic, Show, Eq) - -data JFor = JFor [(IdxVar, IdxFlavor)] (JOpP IdxAtom) - deriving (Generic, Show, Eq) - -type JScope = Env () -- TODO: put bindings here too - --- === lowering from Expr === - -type JEmbedEnv = (JScope, [JDecl]) -type JSubstEnv = Env TmpAtom -type EffectState = Env (Int, TmpAtom) -type IdxEnv = [IdxVar] -- for i j. --> [i, j] -type JaxM = ReaderT IdxEnv (StateT EffectState (Cat JEmbedEnv)) - -runJaxM :: JaxM a -> a -runJaxM m = fst $ flip runCat mempty $ - flip evalStateT mempty $ flip runReaderT mempty m - -toJaxFunction :: ([Var], Block) -> JaxFunction -toJaxFunction (vs, block) = runJaxM $ do - vs' <- mapM freshVar vs - let env = newEnv vs $ map varToTmpAtom vs - (result, (_, decls)) <- scoped $ do - result <- toJaxBlock env block - return $ flattenAtom result - let jvs = map (fmap typeToJType) vs' - return $ JaxFunction jvs $ JExpr decls result - -varToTmpAtom :: Var -> TmpAtom -varToTmpAtom v = TmpLeaf $ IdxAtom (JVar $ fmap typeToJType v) [] - -flattenAtom :: TmpAtom -> [JAtom] -flattenAtom atom = - execWriter $ traverseArrayLeaves atom $ \(IdxAtom x []) -> do - tell [x] - return $ IdxAtom x [] - -toJaxBlock :: JSubstEnv -> Block -> JaxM TmpAtom -toJaxBlock env (Block [] result) = toJaxExpr env result -toJaxBlock env (Block (decl:decls) result) = do - env' <- toJaxDecl env decl - toJaxBlock (env <> env') body - where body = Block decls result - -toJaxDecl :: JSubstEnv -> Decl -> JaxM JSubstEnv -toJaxDecl env (Let _ v rhs) = do - ans <- toJaxExpr env rhs - return $ v @> ans - -toJaxAtom :: JSubstEnv -> Atom -> TmpAtom -toJaxAtom env atom = case atom of - Var v@(_:>RefTy _ _) -> TmpRefName v - Var v -> fromMaybe (error "lookup failed") $ envLookup env v - Con (Lit x) -> tmpAtomScalarLit x - Con con -> TmpCon $ fmap (toJaxAtom env) con - _ -> error $ "Not implemented: " ++ pprint atom - -toJaxExpr :: JSubstEnv -> Expr -> JaxM TmpAtom -toJaxExpr env expr = case expr of - -- For _ (LamExpr i@(_ :> FixedIntRange 0 n) body) -> do - -- idxEnv <- ask - -- -- TODO: scope this to avoid burning through names - -- i' <-freshIdxVar n - -- iotaVar <- emitJFor $ JFor [] $ JIota n - -- let iotaAtom = iotaVarAsIdx (FixedIntRange 0 n) $ IdxAtom (JVar iotaVar) [i'] - -- let env' = env <> i @> iotaAtom - -- ans <- extendR [i'] $ toJaxBlock env' body - -- liftM (TmpCon . AFor (varAnn i)) $ traverseArrayLeaves ans $ \x -> do - -- ansVar <- emitJFor $ JFor (map (,MapIdx) (idxEnv ++ [i'])) $ JId x - -- return $ IdxAtom (JVar ansVar) idxEnv - -- TabGet xs i -> do - -- let (TmpCon (AFor _ tab)) = toJaxAtom env xs - -- let i' = toJaxAtom env i - -- traverseArrayLeaves tab $ \x -> emitOp $ JGet x $ fromScalarAtom i' - Op op -> toJaxOp $ fmap (toJaxAtom env) op - -toJaxOp :: PrimOp TmpAtom -> JaxM TmpAtom -toJaxOp op = case op of - ScalarBinOp op x y -> liftM toScalarAtom $ - emitOp $ JScalarBinOp op (fromScalarAtom x) (fromScalarAtom y) - IndexAsInt x -> liftM toScalarAtom $ - emitOp $ JId (fromScalarAtom x) - ScalarUnOp op x -> liftM toScalarAtom $ - emitOp $ JScalarUnOp op (fromScalarAtom x) - PrimEffect (TmpRefName refVar) m -> do - case m of - MTell x -> do - (depth, curAccum) <- gets (! refVar) - xSum <- sumPoly depth x - newAccum <- local (take depth) $ addPoly curAccum xSum - modify (<> refVar @> (depth, newAccum)) - return $ TmpCon $ UnitCon - _ -> error $ "Not implemented: " ++ show op - -- RecGet x i -> do - -- case x of - -- TmpCon (RecCon r) -> return $ recGet r i - -- val -> error $ "Expected a record, got: " ++ show val - FFICall s _ args | s == "threefry2x32" -> liftM toScalarAtom $ - emitOp $ JThreeFry2x32 (fromScalarAtom x) (fromScalarAtom y) - where x:y:[] = args - _ -> error $ "Not implemented: " ++ show op - --- toJaxHof :: PrimHof TmpAtom (LamExpr, JSubstEnv) -> JaxM TmpAtom --- toJaxHof hof = case hof of --- RunWriter (LamExpr refVar _ body, env) -> do --- idxEnvDepth <- asks length --- let (RefTy wTy) = varAnn refVar --- wInit <- zerosAt wTy --- modify (<> refVar @> (idxEnvDepth, wInit)) --- aResult <- toJaxBlock env body --- wFinal <- gets $ snd . (! refVar) --- modify $ envDelete (varName refVar) --- return $ TmpCon $ RecCon $ Tup [aResult, wFinal] --- _ -> error $ "Not implemented: " ++ show hof - -iotaVarAsIdx :: Type -> IdxAtom -> TmpAtom -iotaVarAsIdx = undefined --- iotaVarAsIdx n x = TmpCon $ AsIdx n $ toScalarAtom x - -fromScalarAtom :: HasCallStack => TmpAtom -> IdxAtom -fromScalarAtom atom = case atom of - TmpCon (Coerce _ x) -> fromScalarAtom x - --TmpCon (AGet (TmpLeaf x)) -> x - _ -> error $ "Not a scalar atom: " ++ show atom - -toScalarAtom :: IdxAtom -> TmpAtom -toScalarAtom x = undefined --TmpCon $ AGet $ TmpLeaf x - -traverseArrayLeaves :: HasCallStack => Monad m => TmpAtom -> (IdxAtom -> m IdxAtom) -> m TmpAtom -traverseArrayLeaves atom f = case atom of - TmpCon con -> liftM TmpCon $ case con of - --AFor n body -> liftM (AFor n) $ traverseArrayLeaves body f - --AGet (TmpLeaf x) -> liftM (AGet . TmpLeaf) $ f x - _ -> error $ "Not implemented: " ++ show atom - TmpLeaf x -> liftM TmpLeaf $ f x - TmpRefName _ -> error "Unexpected reference name" - -typeToJType :: Type -> JType -typeToJType ty = case ty of - TC (JArrayType dims b) -> JType dims b - _ -> error $ "Not a jax type: " ++ pprint ty - -jTypeToType :: JType -> Type -jTypeToType ty = case ty of - JType shape b -> TC $ JArrayType shape b - -emitOp :: JOpP IdxAtom -> JaxM IdxAtom -emitOp op = do - idxEnv <- ask - v <- emitJFor $ JFor (map (,MapIdx) idxEnv) op - return $ IdxAtom (JVar v) idxEnv - -zerosAt :: Type -> JaxM TmpAtom -zerosAt ty = case ty of - BaseTy (Scalar FloatType) -> return $ tmpAtomScalarLit $ FloatLit 0.0 - _ -> error "Not implemented" - -addPoly :: TmpAtom -> TmpAtom -> JaxM TmpAtom -addPoly x y = case getType x of - BaseTy (Scalar FloatType) -> liftM toScalarAtom $ - emitOp $ JScalarBinOp FAdd (fromScalarAtom x) (fromScalarAtom y) - ty -> error $ "Not implemented: " ++ pprint ty - -sumPoly :: Int -> TmpAtom -> JaxM TmpAtom -sumPoly depth atom = do - idxEnv <- ask - let (forIdxs, sumIdxs) = splitAt depth idxEnv - let idxBinders = zip forIdxs (repeat MapIdx) - <> zip sumIdxs (repeat SumIdx) - traverseArrayLeaves atom $ \x -> do - v <- emitJFor $ JFor idxBinders $ JId x - return $ IdxAtom (JVar v) forIdxs - -tmpAtomScalarLit :: LitVal -> TmpAtom -tmpAtomScalarLit x = toScalarAtom $ IdxAtom (JLit [] $ arrayFromScalar x) [] - -instance HasType TmpAtom where - typeCheck atom = case atom of - TmpLeaf idxAtom -> return $ jTypeToType $ getJType idxAtom - TmpRefName _ -> undefined - TmpCon con -> undefined - --- === Simplification pass on JAX IR === - -type BindingEnv = Env (VarUsage, JFor) -type SimpEnv = (Env JAtom, BindingEnv) -type SimpM = Cat JEmbedEnv - -pattern JForId :: JAtom -> JFor -pattern JForId x = JFor [] (JId (IdxAtom x [])) - -simplifyJaxFunction :: JaxFunction -> JaxFunction -simplifyJaxFunction (JaxFunction vs expr) = fst $ flip runCat mempty $ do - vs' <- mapM freshVar vs - let env = (newEnv vs (map JVar vs'), mempty) - (result', (_, decls')) <- scoped $ simplifyJaxExpr env expr - return $ JaxFunction vs' $ JExpr decls' result' - -simplifyJaxExpr :: SimpEnv -> JExpr -> SimpM [JAtom] -simplifyJaxExpr env expr@(JExpr decls results) = do - let usageEnv = collectUsage expr - (_, env') <- flip runCatT env $ mapM (simplifyJaxDecl usageEnv) decls - let (substEnv, _) = env <> env' - return $ fmap (jSubst substEnv) results - -simplifyJaxDecl :: UsageEnv -> JDecl -> CatT SimpEnv SimpM () -simplifyJaxDecl usageEnv (JLet v jfor) = do - (substEnv, bindingEnv) <- look - let usage = lookupUse usageEnv v - let jfor' = simpFix (simplifyJFor bindingEnv) $ jSubst substEnv jfor - case jfor' of - JForId x -> extend $ asFst (v @> x) - _ -> do - vOut <- lift $ emitJFor jfor' - extend $ (v @> JVar vOut, vOut @> (usage, jfor')) - -simplifyJFor :: BindingEnv -> JFor -> Maybe JFor -simplifyJFor env jfor@(JFor idxs op) = - liftM (JFor idxs) (mapParallel (inlineFromId env) op) - <|> inlineGetIota env jfor - <|> inlineIntoId env jfor - <|> liftM (JFor idxs) (algebraicSimp op) - <|> checkProgress etaReduce jfor - -inlineGetIota :: BindingEnv -> JFor -> Maybe JFor -inlineGetIota env (JFor idxBinders op) = do - let idxEnv = map fst idxBinders - JGet (IdxAtom x xIdxs) (IdxAtom (JVar v) idxs) <- return op - (_, varDef) <- envLookup env v - (JFor [] (JIota _), [i]) <- return $ betaReduce varDef idxs - let idxs' = xIdxs ++ [i] - -- TODO: have a more direct way to check index ordering condition - case checkIdxEnv idxs' idxEnv of - Left _ -> Nothing - Right () -> return $ JFor idxBinders $ JId $ IdxAtom x idxs' - -inlineIntoId :: BindingEnv -> JFor -> Maybe JFor -inlineIntoId env (JFor idxs op) = do - JId (IdxAtom (JVar v) appIdxs) <- return op - (UsedOnce, jfor) <- envLookup env v - let idxScope = foldMap ((@>()) . fst) idxs - let jforFresh = refreshIdxVars idxScope jfor - (jfor', []) <- return $ betaReduce jforFresh appIdxs - let (JFor idxs' op') = refreshIdxVars idxScope jfor' - return $ JFor (idxs <> idxs') op' - -inlineFromId :: BindingEnv -> IdxAtom -> Maybe IdxAtom -inlineFromId env idxAtom = do - IdxAtom (JVar v) idxs <- return idxAtom - (_, jfor) <- envLookup env v - (JFor [] (JId (IdxAtom x idxs')), idxs'') <- return $ betaReduce jfor idxs - return $ IdxAtom x (idxs' <> idxs'') - -algebraicSimp :: JOp -> Maybe JOp -algebraicSimp op = case op of - JScalarBinOp FAdd x y - | fromScalarLit x == Just (FloatLit 0) -> Just $ JId y - | fromScalarLit y == Just (FloatLit 0) -> Just $ JId x - _ -> Nothing - -fromScalarLit :: IdxAtom -> Maybe LitVal -fromScalarLit (IdxAtom (JLit [] x) []) = scalarFromArray x -fromScalarLit _ = Nothing - --- === variable usage pass === - -data VarUsage = Unused | UsedOnce | ArbitraryUse deriving (Show, Eq) - -type UsageEnv = MonMap Name VarUsage - -collectUsage :: JExpr -> UsageEnv -collectUsage (JExpr decls result) = snd $ flip runCat mempty $ do - extend $ useFreeVars ArbitraryUse result - forM_ (reverse decls) $ \(JLet v jfor) -> do - use <- looks $ flip lookupUse v - case use of - Unused -> return () - _ -> extend $ useFreeVars UsedOnce jfor - -lookupUse :: UsageEnv -> VarP ann -> VarUsage -lookupUse env (v:>_) = monMapLookup env v - -useFreeVars :: HasJVars a => VarUsage -> a -> UsageEnv -useFreeVars use x = foldMap (useVar use) $ envNames $ freeJVars x - -useVar :: VarUsage -> Name -> UsageEnv -useVar use v = monMapSingle v use - -instance Semigroup VarUsage where - Unused <> use = use - use <> Unused = use - _ <> _ = ArbitraryUse - -instance Monoid VarUsage where - mempty = Unused - -dceJaxFunction :: JaxFunction -> JaxFunction -dceJaxFunction (JaxFunction vs expr@(JExpr decls result)) = - JaxFunction vs (JExpr decls' result) - where - decls' = filter (\(JLet v _) -> lookupUse useEnv v /= Unused) decls - useEnv = collectUsage expr - --- === JAX IR builder === - -emitJFor :: MonadCat JEmbedEnv m => JFor -> m JVar -emitJFor jfor = do - v <- freshVar ("v":> getJType jfor) - extend $ (v @> (), [JLet v jfor]) - return v - -freshVar :: MonadCat JEmbedEnv m => VarP ann -> m (VarP ann) -freshVar v = do - scope <- looks fst - let v' = rename v scope - extend $ asFst (v' @> ()) - return v' - -freshIdxVar :: MonadCat JEmbedEnv m => AxisSize -> m IdxVar -freshIdxVar n = do - scope <- looks fst - let nameChoices = [Name JaxIdx name 0 | name <- ["i", "j", "k"]] - let v = renameChoice nameChoices scope :> n - extend $ asFst (v @> ()) - return v - --- === JAXy IR Types === - -type IdxTyEnv = [IdxVar] -type JTypeEnv = (Env JType, IdxEnv) - -instance Checkable JaxFunction where - checkValid (JaxFunction vs body) = do - let argTys = map varAnn vs - void $ checkJExprType (newEnv vs argTys, []) body - -checkJExprType :: JTypeEnv -> JExpr -> Except [JType] -checkJExprType initEnv (JExpr decls results) = - liftM fst $ flip runCatT initEnv $ do - forM_ decls $ \(JLet v@(_:>reqTy) jfor) -> do - env <- look - ty <- checkJType env jfor - assertEq reqTy ty "Annotation" - extend (v @> ty, []) - env <- look - forM results $ checkJType env - -class HasJType a where - getJType :: a -> JType - checkJType :: MonadError Err m => JTypeEnv -> a -> m JType - -instance HasJType JFor where - getJType (JFor idxs op) = JType (shape ++ shape') b - where - shape = [n | (_:>n, MapIdx) <- idxs] - (JType shape' b) = getJType op - - checkJType env jfor@(JFor idxs op) = - addContext ("\nChecking: " ++ pprint jfor) $ do - let idxBinders = map fst idxs - checkBinders env idxBinders - let env' = env <> (mempty, idxBinders) - let shape = [n | (_:>n, MapIdx) <- idxs] - (JType shape' b) <- checkJType env' op - return (JType (shape ++ shape') b) - -assertNoMutualShadows :: (MonadError Err m, Pretty b, Traversable f) - => f (VarP b) -> m () -assertNoMutualShadows bs = - void $ flip runCatT mempty $ forM bs $ \b -> do - env <- look - checkNoShadow env b - extend (b@>()) - -checkBinders :: (MonadError Err m, Pretty ann) => JTypeEnv -> [VarP ann] -> m () -checkBinders env bs = do - mapM_ (checkNoShadow (fst env)) bs - assertNoMutualShadows bs - -instance HasJType IdxAtom where - getJType (IdxAtom x idxs) = JType (drop (length idxs) shape) b - where JType shape b = getJType x - - checkJType (env, idxEnv) (IdxAtom x idxs) = do - JType shape b <- checkJType (env, []) x - throwIf (length idxs > length shape) CompilerErr $ - "Too many indices: " ++ pprint idxs - forM_ (zip idxs shape) $ \((_:>nAnn), nArr) -> - assertEq nArr nAnn "Index size doesn't match array shape" - checkIdxEnv idxs idxEnv - return $ JType (drop (length idxs) shape) b - -checkIdxEnv :: MonadError Err m => [IdxVar] -> IdxTyEnv -> m () -checkIdxEnv [] _ = return () -checkIdxEnv (i:_) [] = throw CompilerErr $ "Index not in env " ++ pprint i -checkIdxEnv (i:idxs) (i':idxEnv) - | varName i == varName i' = do - assertEq i' i "Index size doesn't match index env" - checkIdxEnv idxs idxEnv - | otherwise = checkIdxEnv (i:idxs) idxEnv - -instance HasJType JAtom where - getJType atom = case atom of - JVar (_:> ty) -> ty - JLit shape arr -> JType shape b - where (_, Scalar b) = arrayType arr - - checkJType (env,_) atom = case atom of - JVar v@(_:> ty) -> do - case envLookup env v of - Just reqTy -> do - assertEq reqTy ty "JVar" - return ty - _ -> throw CompilerErr $ "Lookup failed: " ++ pprint v - JLit shape arr -> return $ JType shape b - where (_, Scalar b) = arrayType arr - -instance (Pretty a, HasJType a) => HasJType (JOpP a) where - getJType op = ignoreExcept $ addContext ("Getting type of: " ++ pprint op) $ - traverseJOpType $ fmap getJType op - checkJType env op = do - op' <- traverse (checkJType env) op - traverseJOpType op' - -traverseJOpType :: MonadError Err m => JOpP JType -> m JType -traverseJOpType jop = case jop of - JScalarBinOp op xTy' yTy' -> do - assertEq (JType [] xTy) xTy' "Arg type mismatch" - assertEq (JType [] yTy) yTy' "Arg type mismatch" - return $ JType [] outTy - where (xTy, yTy, outTy) = binOpType op - JScalarUnOp op xTy' -> do - assertEq (JType [] xTy) xTy' "Arg type mismatch" - return $ JType [] outTy - where (xTy, outTy) = unOpType op - JThreeFry2x32 xTy yTy -> do - assertEq (JType [] IntType) xTy "Arg type mismatch" - assertEq (JType [] IntType) yTy "Arg type mismatch" - return $ JType [] IntType - JId ty -> return $ ty - JIota n -> return $ JType [n] IntType - JGet (JType (_:shape) b) idxTy -> do - assertEq (JType [] IntType) idxTy "Arg type mismatch" - return $ JType shape b - JGet (JType [] _) _ -> error "Attempting to index zero-dim array" - --- === free vars and substitutions === - -class HasJVars a where - freeJVars :: a -> Env () - jSubst :: Env JAtom -> a -> a - -instance HasJVars JFor where - freeJVars (JFor _ op) = freeJVars op - jSubst env (JFor idxs op) = JFor idxs $ jSubst env op - -instance HasJVars JAtom where - freeJVars x = case x of - JLit _ _ -> mempty - JVar v -> v @> () - jSubst env x = case x of - JLit _ _ -> x - JVar v -> env ! v - -instance HasJVars IdxAtom where - freeJVars (IdxAtom x _) = freeJVars x - jSubst env (IdxAtom x idxs) = IdxAtom (jSubst env x) idxs - -instance (Traversable f, HasJVars a) => HasJVars (f a) where - freeJVars xs = foldMap freeJVars xs - jSubst env op = fmap (jSubst env) op - -etaReduce :: JFor -> JFor -etaReduce (JFor [] op) = JFor [] op -etaReduce (JFor (b:bs) op) = do - let (JFor bs' op') = etaReduce (JFor bs op) - fromMaybe (JFor (b:bs') op') $ do - (i, MapIdx) <- return b - [] <- return bs' - JId (IdxAtom x idxs) <- return op' - (idxs', i') <- unsnoc idxs - unless (i == i') Nothing - return $ JFor bs' $ JId $ IdxAtom x idxs' - -betaReduce :: JFor -> [IdxVar] -> (JFor, [IdxVar]) -betaReduce jfor idxs = do - let freeVs = foldMap (@>()) idxs - let jfor' = refreshIdxVars freeVs jfor - betaReduceRec jfor' idxs - -betaReduceRec :: JFor -> [IdxVar] -> (JFor, [IdxVar]) -betaReduceRec jfor [] = (jfor, []) -betaReduceRec jfor idxs = do - let Just (rest, i) = unsnoc idxs - let (jfor', idxs') = betaReduceRec jfor rest - fromMaybe (jfor', idxs' ++ [i]) $ do - [] <- return idxs' - JFor ((b,MapIdx):bs) op <- return jfor' - return (JFor bs $ substOp (b @> i) op, []) - -refreshIdxVars :: JScope -> JFor -> JFor -refreshIdxVars scope (JFor binders op) = do - let (idxs, flavors) = unzip binders - let idxs' = fst $ renames idxs () scope - JFor (zip idxs' flavors) $ substOp (newEnv idxs idxs') op - --- TODO: extend `HasJVars` to handle index var substitution too -substOp :: Env IdxVar -> JOp -> JOp -substOp env op = flip fmap op $ \(IdxAtom x atomIdxs) -> - IdxAtom x $ map trySubst atomIdxs - where trySubst v = fromMaybe v (envLookup env v) - --- TODO: make a right-appending list we can actually pattern-match on -unsnoc :: [a] -> Maybe ([a], a) -unsnoc xs = case reverse xs of - [] -> Nothing - x:rest -> Just (reverse rest, x) - --- === simplification combinators === - --- Simplifiers must only produce `Just` if some progress was made. --- (e.g. avoid `mySimp x = trySimp x <|> pure x`) - -simpFix :: Eq a => (a -> Maybe a) -> a -> a -simpFix f x = case f x of - Nothing -> x - Just x' -> simpFix f x' - --- TODO: more efficient implementation without using Eq -mapParallel :: (Eq a, Eq (f a), Functor f) => (a -> Maybe a) -> f a -> Maybe (f a) -mapParallel f = checkProgress (fmap (\x -> fromMaybe x (f x))) - -checkProgress :: Eq a => (a -> a) -> a -> Maybe a -checkProgress f x | x' == x = Nothing - | otherwise = Just x' - where x' = f x - --- === instances === - -instance Pretty JaxFunction where - pretty (JaxFunction vs body) = "lambda" <+> pretty vs <> hardline <> pretty body - -instance Pretty JExpr where - pretty (JExpr decls results) = - foldMap (\d -> pretty d <> hardline) decls <> "results:" <+> pretty results - -instance Pretty IdxAtom where - pretty (IdxAtom x idxs) = pretty x <> foldMap (\(i:>_) -> "." <> pretty i) idxs - -instance Pretty JAtom where - pretty (JLit _ x) = pretty $ scalarFromArray x - pretty (JVar (v:>_)) = pretty v - -instance Pretty JDecl where - pretty (JLet v rhs) = pretty v <+> "=" <+> pretty rhs - -instance Pretty a => Pretty (JOpP a) where - pretty op = prettyOpName op <+> foldMap (\x -> parens (pretty x) <> " ") op - -instance Pretty JType where - pretty (JType s b) = pretty b <> pretty s - -instance Pretty JFor where - pretty (JFor [] op) = pretty op - pretty jfor@(JFor ((_,flavor):_) _) = - pretty s <+> prettyJForCtx flavor jfor - where - s :: String - s = case flavor of MapIdx -> "for" - SumIdx -> "sum" -instance Pretty TmpAtom where - pretty _ = "" - -prettyJForCtx :: IdxFlavor -> JFor -> Doc ann -prettyJForCtx flavor jfor@(JFor idxs op) = case idxs of - [] -> " . " <> pretty op - (i, flavor'):rest - | flavor == flavor' -> pretty (varName i) <+> - prettyJForCtx flavor (JFor rest op) - | otherwise -> pretty jfor - -prettyOpName :: JOpP a -> Doc ann -prettyOpName jop = case jop of - JScalarBinOp op _ _ -> pretty $ show op - JScalarUnOp op _ -> pretty $ show op - JThreeFry2x32 _ _ -> "threefry2x32" - JIota n -> "iota@" <> pretty n - JGet _ _ -> "get" - JId _ -> "id" - -instance ToJSON JDecl -instance FromJSON JDecl - -instance ToJSON JaxFunction -instance FromJSON JaxFunction - -instance ToJSON JExpr -instance FromJSON JExpr - -instance ToJSON JFor -instance FromJSON JFor - -instance ToJSON JAtom -instance FromJSON JAtom - -instance ToJSON IdxAtom -instance FromJSON IdxAtom - -instance ToJSON IdxFlavor -instance FromJSON IdxFlavor - -instance (ToJSON ann) => ToJSON (VarP ann) -instance (FromJSON ann) => FromJSON (VarP ann) - -instance (ToJSON e) => ToJSON (JOpP e) -instance (FromJSON e) => FromJSON (JOpP e) - -instance ToJSON JType -instance FromJSON JType - -instance ToJSON Name -instance FromJSON Name - -instance ToJSON NameSpace -instance FromJSON NameSpace - -instance ToJSON BinOp -instance FromJSON BinOp - -instance ToJSON UnOp -instance FromJSON UnOp - -instance ToJSON CmpOp -instance FromJSON CmpOp - -instance ToJSON LitVal -instance FromJSON LitVal - -instance ToJSON BaseType -instance FromJSON BaseType - -instance ToJSON ScalarBaseType -instance FromJSON ScalarBaseType - -instance ToJSON Array -instance FromJSON Array - -instance ToJSON Vec -instance FromJSON Vec diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index 8ddf1e67e..860829bc3 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -33,7 +33,6 @@ import Control.Monad.State.Strict import Control.Monad.Reader import Data.ByteString.Short (toShort) import qualified Data.ByteString.Char8 as B -import Data.List (concat) import Data.String import Data.Foldable import Data.Text.Prettyprint.Doc @@ -91,18 +90,20 @@ compileFunction logger fun@(ImpFunction f bs body) = case cc of extraSpecs <- gets funSpecs return ([L.GlobalDefinition mainFun], extraSpecs) EntryFun requiresCUDA -> return $ runCompile CPU $ do + (streamFDParam , streamFDOperand ) <- freshParamOpPair attrs $ i32 (argPtrParam , argPtrOperand ) <- freshParamOpPair attrs $ hostPtrTy i64 (resultPtrParam, resultPtrOperand) <- freshParamOpPair attrs $ hostPtrTy i64 - argOperands <- forM (zip [0..] argTys) $ \(i, ty) -> + initializeOutputStream streamFDOperand + argOperands <- forM (zip [0..] argTys) \(i, ty) -> gep argPtrOperand (i64Lit i) >>= castLPtr (scalarTy ty) >>= load when (toBool requiresCUDA) ensureHasCUDAContext results <- extendOperands (newEnv bs argOperands) $ compileBlock body - forM_ (zip [0..] results) $ \(i, x) -> + forM_ (zip [0..] results) \(i, x) -> gep resultPtrOperand (i64Lit i) >>= castLPtr (L.typeOf x) >>= flip store x mainFun <- makeFunction (asLLVMName name) - [argPtrParam, resultPtrParam] (Just $ i64Lit 0) + [streamFDParam, argPtrParam, resultPtrParam] (Just $ i64Lit 0) extraSpecs <- gets funSpecs - return ([L.GlobalDefinition mainFun], extraSpecs) + return ([L.GlobalDefinition mainFun, outputStreamPtrDef], extraSpecs) where attrs = [L.NoAlias, L.NoCapture, L.NonNull] CUDAKernelLaunch -> do (CUDAKernel kernelText) <- compileCUDAKernel logger $ impKernelToLLVMGPU fun @@ -200,8 +201,8 @@ compileInstr instr = case instr of IFor d i n body -> [] <$ do n' <- compileExpr n compileLoop d i n' $ compileVoidBlock body - IWhile cond body -> [] <$ do - compileWhile (head <$> compileBlock cond) (compileVoidBlock body) + IWhile body -> [] <$ do + compileWhile (head <$> compileBlock body) ICond p cons alt -> [] <$ do p' <- compileExpr p >>= (`asIntWidth` i1) compileIf p' (compileVoidBlock cons) (compileVoidBlock alt) @@ -265,15 +266,15 @@ compileInstr instr = case instr of GPU -> cuMemAlloc elemTy numBytes where elemTy = scalarTy t Free ptr -> [] <$ do - let PtrType (_, addr, _) = getIType ptr + let PtrType (addr, _) = getIType ptr ptr' <- compileExpr ptr case addr of Heap CPU -> free ptr' Heap GPU -> cuMemFree ptr' Stack -> error "Shouldn't be freeing alloca" MemCopy dest src numel -> [] <$ do - let PtrType (_, destAddr, ty) = getIType dest - let PtrType (_, srcAddr , _ ) = getIType src + let PtrType (destAddr, ty) = getIType dest + let PtrType (srcAddr , _ ) = getIType src destDev <- deviceFromAddr destAddr srcDev <- deviceFromAddr srcAddr dest' <- compileExpr dest >>= castVoidPtr @@ -302,6 +303,10 @@ compileInstr instr = case instr of GT -> emitInstr dt $ L.FPTrunc x dt [] (L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ L.FPToSI x dt [] (L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ L.SIToFP x dt [] + (L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x + (L.IntegerType 64 , ptrTy@(L.PointerType _ _)) -> + emitInstr ptrTy $ L.IntToPtr x ptrTy [] + (L.PointerType _ _, L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 [] _ -> error $ "Unsupported cast" ICall f@(fname:> IFunType cc argTys resultTys) args -> do -- TODO: consider having a separate calling convention specification rather @@ -310,8 +315,7 @@ compileInstr instr = case instr of let resultTys' = map scalarTy resultTys case cc of FFIFun -> do - let [resultTy] = resultTys' - ans <- emitInstr resultTy $ externCall (makeFunSpec f) args' + ans <- emitExternCall (makeFunSpec f) args' return [ans] FFIMultiResultFun -> do resultPtr <- makeMultiResultAlloc resultTys' @@ -366,14 +370,13 @@ compileIf cond tb fb = do fb finishBlock (L.Br contName []) contName -compileWhile :: Compile Operand -> Compile () -> Compile () -compileWhile compileCond compileBody = do +compileWhile :: Compile Operand -> Compile () +compileWhile compileBody = do loopBlock <- freshName "whileLoop" nextBlock <- freshName "whileCont" - entryCond <- compileCond >>= (`asIntWidth` i1) + entryCond <- compileBody >>= (`asIntWidth` i1) finishBlock (L.CondBr entryCond loopBlock nextBlock []) loopBlock - compileBody - loopCond <- compileCond >>= (`asIntWidth` i1) + loopCond <- compileBody >>= (`asIntWidth` i1) finishBlock (L.CondBr loopCond loopBlock nextBlock []) nextBlock throwRuntimeError :: Compile () @@ -580,6 +583,21 @@ _gpuDebugPrint i32Val = do genericPtrTy ty = L.PointerType ty $ L.AddrSpace 0 vprintfSpec = ExternFunSpec "vprintf" i32 [] [] [genericPtrTy i8, genericPtrTy i8] +-- Takes a single int64 payload. TODO: implement a varargs version +_debugPrintf :: String -> Operand -> Compile () +_debugPrintf fmtStr x = do + let chars = map (C.Int 8) $ map (fromIntegral . fromEnum) fmtStr ++ [0] + let formatStrArr = L.ConstantOperand $ C.Array i8 chars + formatStrPtr <- alloca (length chars) i8 + castLPtr (L.typeOf formatStrArr) formatStrPtr >>= (`store` formatStrArr) + void $ emitExternCall printfSpec [formatStrPtr, x] + where printfSpec = ExternFunSpec "printf" i32 [] [] [hostVoidp, i64] + +_debugPrintfPtr :: String -> Operand -> Compile () +_debugPrintfPtr s x = do + x' <- emitInstr i64 $ L.PtrToInt x i64 [] + _debugPrintf s x' + compileBlock :: ImpBlock -> Compile [Operand] compileBlock (ImpBlock Empty result) = traverse compileExpr result compileBlock (ImpBlock (Nest decl rest) result) = do @@ -604,7 +622,7 @@ compileExpr expr = case expr of packArgs :: [Operand] -> Compile Operand packArgs elems = do arr <- alloca (length elems) hostVoidp - forM_ (zip [0..] elems) $ \(i, e) -> do + forM_ (zip [0..] elems) \(i, e) -> do eptr <- alloca 1 $ L.typeOf e store eptr e earr <- gep arr $ i32Lit i @@ -613,7 +631,7 @@ packArgs elems = do unpackArgs :: Operand -> [L.Type] -> Compile [Operand] unpackArgs argArrayPtr types = - forM (zip [0..] types) $ \(i, ty) -> do + forM (zip [0..] types) \(i, ty) -> do argVoidPtr <- gep argArrayPtr $ i64Lit i argPtr <- castLPtr (hostPtrTy ty) argVoidPtr load =<< load argPtr @@ -621,7 +639,7 @@ unpackArgs argArrayPtr types = makeMultiResultAlloc :: [L.Type] -> Compile Operand makeMultiResultAlloc tys = do resultsPtr <- alloca (length tys) hostVoidp - forM_ (zip [0..] tys) $ \(i, ty) -> do + forM_ (zip [0..] tys) \(i, ty) -> do ptr <- alloca 1 ty >>= castVoidPtr resultsPtrOffset <- gep resultsPtr $ i32Lit i store resultsPtrOffset ptr @@ -629,7 +647,7 @@ makeMultiResultAlloc tys = do loadMultiResultAlloc :: [L.Type] -> Operand -> Compile [Operand] loadMultiResultAlloc tys ptr = - forM (zip [0..] tys) $ \(i, ty) -> + forM (zip [0..] tys) \(i, ty) -> gep ptr (i32Lit i) >>= load >>= castLPtr ty >>= load runMCKernel :: ExternFunSpec @@ -765,7 +783,7 @@ scalarTy b = case b of Float64Type -> fp64 Float32Type -> fp32 Vector sb -> L.VectorType (fromIntegral vectorWidth) $ scalarTy $ Scalar sb - PtrType (_, s, t) -> L.PointerType (scalarTy t) (lAddress s) + PtrType (s, t) -> L.PointerType (scalarTy t) (lAddress s) hostPtrTy :: L.Type -> L.Type hostPtrTy ty = L.PointerType ty $ L.AddrSpace 0 @@ -798,7 +816,7 @@ asIntWidth :: Operand -> L.Type -> Compile Operand asIntWidth op ~expTy@(L.IntegerType expWidth) = case compare expWidth opWidth of LT -> emitInstr expTy $ L.Trunc op expTy [] EQ -> return op - GT -> emitInstr expTy $ L.ZExt op expTy [] + GT -> emitInstr expTy $ L.SExt op expTy [] where ~(L.IntegerType opWidth) = L.typeOf op freshParamOpPair :: [L.ParameterAttribute] -> L.Type -> Compile (Parameter, Operand) @@ -858,16 +876,40 @@ cpuBinaryIntrinsic op x y = case L.typeOf x of floatIntrinsic ty name = ExternFunSpec (L.mkName name) ty [] [] [ty, ty] callFloatIntrinsic ty name = emitExternCall (floatIntrinsic ty name) [x, y] +-- === Output stream === + +outputStreamPtrLName :: L.Name +outputStreamPtrLName = asLLVMName outputStreamPtrName + +outputStreamPtrDef :: L.Definition +outputStreamPtrDef = L.GlobalDefinition $ L.globalVariableDefaults + { L.name = outputStreamPtrLName + , L.type' = hostVoidp + , L.linkage = L.Private + , L.initializer = Just $ C.Null hostVoidp } + +outputStreamPtr :: Operand +outputStreamPtr = L.ConstantOperand $ C.GlobalReference + (hostPtrTy hostVoidp) outputStreamPtrLName + +initializeOutputStream :: Operand -> Compile () +initializeOutputStream streamFD = do + streamPtr <- emitExternCall fdopenFun [streamFD] + store outputStreamPtr streamPtr + +outputStreamEnv :: OperandEnv +outputStreamEnv = outputStreamPtrName @> outputStreamPtr + -- === Compile monad utilities === runCompile :: Device -> Compile a -> a runCompile dev m = evalState (runReaderT m env) initState where - env = CompileEnv mempty dev + env = CompileEnv outputStreamEnv dev initState = CompileState [] [] [] "start_block" mempty mempty mempty extendOperands :: OperandEnv -> Compile a -> Compile a -extendOperands openv = local $ \env -> env { operandEnv = (operandEnv env) <> openv } +extendOperands openv = local \env -> env { operandEnv = (operandEnv env) <> openv } lookupImpVar :: IVar -> Compile Operand lookupImpVar v = asks ((! v) . operandEnv) @@ -885,7 +927,7 @@ freshName :: Name -> Compile L.Name freshName v = do used <- gets usedNames let v' = genFresh v used - modify $ \s -> s { usedNames = used <> v' @> () } + modify \s -> s { usedNames = used <> v' @> () } return $ nameToLName v' where nameToLName :: Name -> L.Name @@ -945,6 +987,9 @@ mallocFun = ExternFunSpec "malloc_dex" (hostPtrTy i8) [L.NoAlias] [] [i64] freeFun :: ExternFunSpec freeFun = ExternFunSpec "free_dex" L.VoidType [] [] [hostPtrTy i8] +fdopenFun :: ExternFunSpec +fdopenFun = ExternFunSpec "fdopen_w" (hostPtrTy i8) [L.NoAlias] [] [i32] + boolTy :: L.Type boolTy = i8 diff --git a/src/lib/LLVM/JIT.hs b/src/lib/LLVM/JIT.hs index 2075902f5..c73b396be 100644 --- a/src/lib/LLVM/JIT.hs +++ b/src/lib/LLVM/JIT.hs @@ -29,10 +29,11 @@ import qualified LLVM.OrcJIT as OrcJIT import qualified LLVM.Target as T import qualified LLVM.Linking as Linking -import qualified LLVM.Module as Mod -import qualified LLVM.AST as L -import qualified LLVM.AST.Global as L +import qualified LLVM.AST +import qualified LLVM.AST.Global as LLVM.AST import qualified LLVM.AST.Constant as C +import qualified LLVM.Module as LLVM +import qualified LLVM.Context as LLVM import LLVM.Shims @@ -71,24 +72,30 @@ data NativeModule = NativeModule { moduleJIT :: JIT , moduleKey :: OrcJIT.ModuleKey , moduleDtors :: [FunPtr (IO ())] + , llvmModule :: LLVM.Module + , llvmContext :: LLVM.Context } --- XXX: This destroys the passed in module! +type CompilationPipeline = LLVM.Module -> IO () + -- TODO: This leaks resources if we fail halfway -compileModule :: JIT -> (L.Module, Mod.Module) -> IO NativeModule -compileModule moduleJIT@JIT{..} (ast, m) = do +compileModule :: JIT -> LLVM.AST.Module -> CompilationPipeline -> IO NativeModule +compileModule moduleJIT@JIT{..} ast compilationPipeline = do + llvmContext <- LLVM.createContext + llvmModule <- LLVM.createModuleFromAST llvmContext ast + compilationPipeline llvmModule moduleKey <- OrcJIT.allocateModuleKey execSession resolver <- newSymbolResolver execSession (makeResolver compileLayer) modifyIORef resolvers (M.insert moduleKey resolver) - OrcJIT.addModule compileLayer moduleKey m - moduleDtors <- forM dtorNames $ \dtorName -> do + OrcJIT.addModule compileLayer moduleKey llvmModule + moduleDtors <- forM dtorNames \dtorName -> do dtorSymbol <- OrcJIT.mangleSymbol compileLayer (fromString dtorName) Right (OrcJIT.JITSymbol dtorAddr _) <- OrcJIT.findSymbol compileLayer dtorSymbol False return $ castPtrToFunPtr $ wordPtrToPtr dtorAddr return NativeModule{..} where makeResolver :: OrcJIT.IRCompileLayer OrcJIT.ObjectLinkingLayer -> OrcJIT.SymbolResolver - makeResolver cl = OrcJIT.SymbolResolver $ \sym -> do + makeResolver cl = OrcJIT.SymbolResolver \sym -> do rsym <- OrcJIT.findSymbol cl sym False -- We look up functions like malloc in the current process -- TODO: Use JITDylibs to avoid inlining addresses as constants: @@ -109,15 +116,16 @@ compileModule moduleJIT@JIT{..} (ast, m) = do -- Unfortunately the JIT layers we use here don't handle the destructors properly, -- so we have to find and call them ourselves. dtorNames = do - let dtorStructs = flip foldMap (L.moduleDefinitions ast) $ \case - L.GlobalDefinition - L.GlobalVariable{name="llvm.global_dtors", - initializer=Just (C.Array _ elems), - ..} -> elems + let dtorStructs = flip foldMap (LLVM.AST.moduleDefinitions ast) \case + LLVM.AST.GlobalDefinition + LLVM.AST.GlobalVariable{ + name="llvm.global_dtors", + initializer=Just (C.Array _ elems), + ..} -> elems _ -> [] -- Sort in the order of decreasing priority! fmap snd $ sortBy (flip compare) $ flip fmap dtorStructs $ - \(C.Struct _ _ [C.Int _ n, C.GlobalReference _ (L.Name dname), _]) -> + \(C.Struct _ _ [C.Int _ n, C.GlobalReference _ (LLVM.AST.Name dname), _]) -> (n, C8BS.unpack $ SBS.fromShort dname) foreign import ccall "dynamic" @@ -133,13 +141,15 @@ unloadNativeModule NativeModule{..} = do modifyIORef resolvers (M.delete moduleKey) OrcJIT.removeModule compileLayer moduleKey OrcJIT.releaseModuleKey execSession moduleKey + LLVM.disposeModule llvmModule + LLVM.disposeContext llvmContext -withNativeModule :: JIT -> (L.Module, Mod.Module) -> (NativeModule -> IO a) -> IO a -withNativeModule jit m = bracket (compileModule jit m) unloadNativeModule +withNativeModule :: JIT -> LLVM.AST.Module -> CompilationPipeline -> (NativeModule -> IO a) -> IO a +withNativeModule jit m p = bracket (compileModule jit m p) unloadNativeModule getFunctionPtr :: NativeModule -> String -> IO (FunPtr a) getFunctionPtr NativeModule{..} funcName = do let JIT{..} = moduleJIT symbol <- OrcJIT.mangleSymbol compileLayer $ fromString funcName - Right (OrcJIT.JITSymbol funcAddr _) <- OrcJIT.findSymbol compileLayer symbol False + Right (OrcJIT.JITSymbol funcAddr _) <- OrcJIT.findSymbolIn compileLayer moduleKey symbol False return $ castPtrToFunPtr $ wordPtrToPtr funcAddr diff --git a/src/lib/LLVM/Shims.hs b/src/lib/LLVM/Shims.hs index 9cbd02119..e509ac12d 100644 --- a/src/lib/LLVM/Shims.hs +++ b/src/lib/LLVM/Shims.hs @@ -7,7 +7,6 @@ module LLVM.Shims ( SymbolResolver (..), newSymbolResolver, disposeSymbolResolver, newTargetMachine, newHostTargetMachine, disposeTargetMachine, - newTargetOptions, disposeTargetOptions ) where import qualified Data.Map as M @@ -36,7 +35,7 @@ data SymbolResolver = SymbolResolver (FunPtr FFIResolver) (Ptr OrcJIT.FFI.Symbol -- | Create a `FFI.SymbolResolver` that can be used with the JIT. newSymbolResolver :: OrcJIT.ExecutionSession -> OrcJIT.SymbolResolver -> IO SymbolResolver newSymbolResolver (OrcJIT.ExecutionSession session) (OrcJIT.SymbolResolver resolverFn) = do - ffiResolverPtr <- wrapFFIResolver $ \sym res -> do + ffiResolverPtr <- wrapFFIResolver \sym res -> do f <- encodeM =<< resolverFn =<< decodeM sym f res lambdaResolver <- OrcJIT.FFI.createLambdaResolver session ffiResolverPtr @@ -61,10 +60,10 @@ newTargetMachine :: Target.Target newTargetMachine (Target.Target targetFFI) triple cpu features (Target.TargetOptions targetOptFFI) relocModel codeModel cgoLevel = do - SBS.useAsCString triple $ \tripleFFI -> do - BS.useAsCString cpu $ \cpuFFI -> do + SBS.useAsCString triple \tripleFFI -> do + BS.useAsCString cpu \cpuFFI -> do let featuresStr = BS.intercalate "," $ fmap encodeFeature $ M.toList features - BS.useAsCString featuresStr $ \featuresFFI -> do + BS.useAsCString featuresStr \featuresFFI -> do relocModelFFI <- encodeM relocModel codeModelFFI <- encodeM codeModel cgoLevelFFI <- encodeM cgoLevel @@ -73,24 +72,15 @@ newTargetMachine (Target.Target targetFFI) triple cpu features targetOptFFI relocModelFFI codeModelFFI cgoLevelFFI where encodeFeature (Target.CPUFeature f, on) = (if on then "+" else "-") <> f -newHostTargetMachine :: R.Model -> CM.Model -> CGO.Level -> IO (Target.TargetMachine, Target.TargetOptions) +newHostTargetMachine :: R.Model -> CM.Model -> CGO.Level -> IO Target.TargetMachine newHostTargetMachine relocModel codeModel cgoLevel = do Target.initializeAllTargets triple <- Target.getProcessTargetTriple (target, _) <- Target.lookupTarget Nothing triple cpu <- Target.getHostCPUName features <- Target.getHostCPUFeatures - targetOptions <- newTargetOptions - tm <- newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel - return (tm, targetOptions) + Target.withTargetOptions \targetOptions -> + newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel disposeTargetMachine :: Target.TargetMachine -> IO () disposeTargetMachine (Target.TargetMachine tmFFI) = Target.FFI.disposeTargetMachine tmFFI - --- llvm-hs doesn't expose any way to manage target options in a non-bracketed way - -newTargetOptions :: IO Target.TargetOptions -newTargetOptions = Target.TargetOptions <$> Target.FFI.createTargetOptions - -disposeTargetOptions :: Target.TargetOptions -> IO () -disposeTargetOptions (Target.TargetOptions optsFFI) = Target.FFI.disposeTargetOptions optsFFI diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index e6dcc2095..f4435dd19 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -10,6 +10,7 @@ module LLVMExec (LLVMKernel (..), ptxDataLayout, ptxTargetTriple, compileAndEval, compileAndBench, exportObjectFile, + standardCompilationPipeline, compileCUDAKernel, loadLitVal) where import qualified LLVM.Analysis as L @@ -27,9 +28,10 @@ import qualified LLVM.Internal.Module as Mod import qualified LLVM.PassManager as P import qualified LLVM.Transforms as P import qualified LLVM.Target as T -import qualified LLVM.Linking as Linking import LLVM.Context -import Data.Time.Clock (getCurrentTime, diffUTCTime) +import Data.Int +import GHC.IO.FD +import GHC.IO.Handle.FD import System.IO import System.IO.Unsafe import System.IO.Temp @@ -40,83 +42,108 @@ import System.Exit import Foreign.Marshal.Alloc import Foreign.Ptr +import Foreign.C.Types (CInt (..)) import Foreign.Storable hiding (alignment) import Control.Monad +import Control.Concurrent import Control.Exception hiding (throw) import Data.ByteString.Short (ShortByteString) import Data.ByteString.Char8 (unpack, pack) import qualified Data.ByteString.Char8 as B import qualified Data.Map as M import qualified Data.Set as S +import qualified Control.Exception as E import Logging import Syntax import Resources import CUDA (synchronizeCUDA) import LLVM.JIT +import Util (measureSeconds) -- === One-shot evaluation === -compileAndEval :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] -compileAndEval logger ast fname args resultTypes = do - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do - storeLitVals argsPtr args - evalTime <- compileOneOff logger ast fname $ checkedCallFunPtr False argsPtr resultPtr - logThis logger [EvalTime evalTime Nothing] - loadLitVals resultPtr resultTypes - -compileAndBench :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] -compileAndBench logger ast fname args resultTypes = do - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do - storeLitVals argsPtr args - compileOneOff logger ast fname $ \fPtr -> do - -- First warmup iteration, which we also use to get the results - void $ checkedCallFunPtr True argsPtr resultPtr fPtr - results <- loadLitVals resultPtr resultTypes - let run = do - time <- checkedCallFunPtr True argsPtr resultPtr fPtr - _benchResults <- loadLitVals resultPtr resultTypes - -- TODO: Free results! - return time - exampleDuration <- run - let timeBudget = 2 -- seconds - let benchRuns = (ceiling $ timeBudget / exampleDuration) :: Int - times <- forM [1..benchRuns] $ const run - let avgTime = sum times / (fromIntegral benchRuns) - logThis logger [EvalTime avgTime (Just benchRuns)] - return results - foreign import ccall "dynamic" - callFunPtr :: DexExecutable -> Ptr () -> Ptr () -> IO DexExitCode + callFunPtr :: DexExecutable -> Int32 -> Ptr () -> Ptr () -> IO DexExitCode -type DexExecutable = FunPtr (Ptr () -> Ptr () -> IO DexExitCode) +type DexExecutable = FunPtr (Int32 -> Ptr () -> Ptr () -> IO DexExitCode) type DexExitCode = Int -checkedCallFunPtr :: Bool -> Ptr () -> Ptr () -> DexExecutable -> IO Double -checkedCallFunPtr sync argsPtr resultPtr fPtr = do - t1 <- getCurrentTime - exitCode <- callFunPtr fPtr argsPtr resultPtr - when sync $ synchronizeCUDA - t2 <- getCurrentTime +compileAndEval :: Logger [Output] -> L.Module -> String + -> [LitVal] -> [BaseType] -> IO [LitVal] +compileAndEval logger ast fname args resultTypes = do + withPipeToLogger logger \fd -> + allocaBytes (length args * cellSize) \argsPtr -> + allocaBytes (length resultTypes * cellSize) \resultPtr -> do + storeLitVals argsPtr args + evalTime <- compileOneOff logger ast fname $ + checkedCallFunPtr fd argsPtr resultPtr + logThis logger [EvalTime evalTime Nothing] + loadLitVals resultPtr resultTypes + +compileAndBench :: Bool -> Logger [Output] -> L.Module -> String + -> [LitVal] -> [BaseType] -> IO [LitVal] +compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do + withPipeToLogger logger \fd -> + allocaBytes (length args * cellSize) \argsPtr -> + allocaBytes (length resultTypes * cellSize) \resultPtr -> do + storeLitVals argsPtr args + compileOneOff logger ast fname \fPtr -> do + ((avgTime, benchRuns, results), totalTime) <- measureSeconds $ do + -- First warmup iteration, which we also use to get the results + void $ checkedCallFunPtr fd argsPtr resultPtr fPtr + results <- loadLitVals resultPtr resultTypes + let run = do + let (CInt fd') = fdFD fd + exitCode <- callFunPtr fPtr fd' argsPtr resultPtr + unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" + -- TODO: Free results! + exampleDuration <- snd <$> measureSeconds run + let timeBudget = 2 -- seconds + let benchRuns = (ceiling $ timeBudget / exampleDuration) :: Int + totalTime <- liftM snd $ measureSeconds $ do + forM_ [1..benchRuns] $ const run + when shouldSyncCUDA $ synchronizeCUDA + let avgTime = totalTime / (fromIntegral benchRuns) + return (avgTime, benchRuns, results) + logThis logger [EvalTime avgTime (Just (benchRuns, totalTime))] + return results + +withPipeToLogger :: Logger [Output] -> (FD -> IO a) -> IO a +withPipeToLogger logger writeAction = do + result <- snd <$> withPipe + (\h -> readStream h \s -> logThis logger [TextOut s]) + (\h -> handleToFd h >>= writeAction) + case result of + Left e -> E.throw e + Right ans -> return ans + +checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO Double +checkedCallFunPtr fd argsPtr resultPtr fPtr = do + let (CInt fd') = fdFD fd + (exitCode, duration) <- measureSeconds $ do + exitCode <- callFunPtr fPtr fd' argsPtr resultPtr + return exitCode unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" - return $ t2 `secondsSince` t1 - where - secondsSince end start = realToFrac $ end `diffUTCTime` start + return duration compileOneOff :: Logger [Output] -> L.Module -> String -> (DexExecutable -> IO a) -> IO a compileOneOff logger ast name f = do - withContext $ \c -> do - Mod.withModuleFromAST c ast $ \m -> do - withHostTargetMachine $ \tm -> do - linkDexrt c m - let exports = [name] - internalize exports m - optimizeModule c logger tm m - withJIT tm $ \jit -> do - withNativeModule jit (ast, m) $ \compiled -> - f =<< getFunctionPtr compiled name + withHostTargetMachine \tm -> + withJIT tm \jit -> + withNativeModule jit ast (standardCompilationPipeline logger [name] tm) \compiled -> + f =<< getFunctionPtr compiled name + +standardCompilationPipeline :: Logger [Output] -> [String] -> T.TargetMachine -> Mod.Module -> IO () +standardCompilationPipeline logger exports tm m = do + linkDexrt m + internalize exports m + showModule m >>= logPass JitPass + L.verify m + runDefaultPasses tm m + showModule m >>= logPass LLVMOpt + showAsm tm m >>= logPass AsmPass + where logPass passName s = logThis logger [PassInfo passName s] -- === object file export === @@ -124,53 +151,40 @@ compileOneOff logger ast name f = do -- Each module comes with a list of exported functions exportObjectFile :: FilePath -> [(L.Module, [String])] -> IO () exportObjectFile objFile modules = do - withContext $ \c -> do - void $ Linking.loadLibraryPermanently Nothing - withHostTargetMachine $ \tm -> - withBrackets (fmap (toLLVM c tm) modules) $ \mods -> do - Mod.withModuleFromAST c L.defaultModule $ \exportMod -> do + withContext \c -> do + withHostTargetMachine \tm -> + withBrackets (fmap (toLLVM c) modules) \mods -> do + Mod.withModuleFromAST c L.defaultModule \exportMod -> do void $ foldM linkModules exportMod mods - linkDexrt c exportMod - internalize allExports exportMod + execLogger Nothing \logger -> + standardCompilationPipeline logger allExports tm exportMod Mod.writeObjectToFile tm (Mod.File objFile) exportMod where allExports = foldMap snd modules - toLLVM :: Context -> T.TargetMachine -> (L.Module, [String]) -> (Mod.Module -> IO a) -> IO a - toLLVM c tm (ast, exports) cont = do - Mod.withModuleFromAST c ast $ \m -> do - internalize exports m - execLogger Nothing $ \logger -> optimizeModule c logger tm m - cont m + toLLVM :: Context -> (L.Module, [String]) -> (Mod.Module -> IO a) -> IO a + toLLVM c (ast, exports) cont = do + Mod.withModuleFromAST c ast \m -> internalize exports m >> cont m linkModules a b = a <$ Mod.linkModules a b withBrackets :: [(a -> IO b) -> IO b] -> ([a] -> IO b) -> IO b withBrackets brackets f = go brackets [] where - go (h:t) args = h $ \arg -> go t (arg:args) + go (h:t) args = h \arg -> go t (arg:args) go [] args = f args -- === LLVM passes === -optimizeModule :: Context -> Logger [Output] -> T.TargetMachine -> Mod.Module -> IO () -optimizeModule ctx logger tm m = do - showModule m >>= logPass JitPass - L.verify m - runDefaultPasses tm m - showModule m >>= logPass LLVMOpt - showAsm ctx tm m >>= logPass AsmPass - where logPass passName s = logThis logger [PassInfo passName s] - runDefaultPasses :: T.TargetMachine -> Mod.Module -> IO () runDefaultPasses t m = do - P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m + P.withPassManager defaultPasses \pm -> void $ P.runPassManager pm m -- We are highly dependent on LLVM when it comes to some optimizations such as -- turning a sequence of scalar stores into a vector store, so we execute some -- extra passes to make sure they get simplified correctly. runPasses extraPasses (Just t) m - P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m + P.withPassManager defaultPasses \pm -> void $ P.runPassManager pm m where defaultPasses = P.defaultCuratedPassSetSpec {P.optLevel = Just 3} extraPasses = [ P.SuperwordLevelParallelismVectorize @@ -182,11 +196,12 @@ runPasses passes mt m = do Just t -> Just <$> T.getTargetMachineDataLayout t Nothing -> return Nothing let passSpec = P.PassSetSpec passes dl Nothing mt - P.withPassManager passSpec $ \pm -> void $ P.runPassManager pm m + P.withPassManager passSpec \pm -> void $ P.runPassManager pm m internalize :: [String] -> Mod.Module -> IO () internalize names m = runPasses [P.InternalizeFunctions names, P.GlobalDeadCodeElimination] Nothing m + -- === supported target machines === -- XXX: We need to use the large code model for macOS, because the libC functions @@ -204,7 +219,7 @@ withHostTargetMachine f = withGPUTargetMachine :: B.ByteString -> (T.TargetMachine -> IO a) -> IO a withGPUTargetMachine computeCapability next = do (tripleTarget, _) <- T.lookupTarget Nothing ptxTargetTriple - T.withTargetOptions $ \topt -> + T.withTargetOptions \topt -> T.withTargetMachine tripleTarget ptxTargetTriple @@ -222,11 +237,12 @@ withGPUTargetMachine computeCapability next = do showModule :: Mod.Module -> IO String showModule m = unpack <$> Mod.moduleLLVMAssembly m -showAsm :: Context -> T.TargetMachine -> Mod.Module -> IO String -showAsm ctx t m' = do +showAsm :: T.TargetMachine -> Mod.Module -> IO String +showAsm t m' = do + ctx <- Mod.moduleContext m' -- Uncomment this to dump assembly to a file that can be linked to a C benchmark suite: - -- withModuleClone ctx m' $ \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m - withModuleClone ctx m' $ \m -> unpack <$> Mod.moduleTargetAssembly t m + -- withModuleClone ctx m' \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m + withModuleClone ctx m' \m -> unpack <$> Mod.moduleTargetAssembly t m withModuleClone :: Context -> Mod.Module -> (Mod.Module -> IO a) -> IO a withModuleClone ctx m f = do @@ -272,10 +288,11 @@ ptrArray p = map (\i -> p `plusPtr` (i * cellSize)) [0..] -- === dex runtime === +{-# NOINLINE dexrtAST #-} dexrtAST :: L.Module dexrtAST = unsafePerformIO $ do - withContext $ \ctx -> do - Mod.withModuleFromBitcode ctx (("dexrt.c" :: String), dexrtBC) $ \m -> + withContext \ctx -> do + Mod.withModuleFromBitcode ctx (("dexrt.c" :: String), dexrtBC) \m -> stripFunctionAnnotations <$> Mod.moduleAST m where -- We strip the function annotations for dexrt functions, because clang @@ -289,13 +306,14 @@ dexrtAST = unsafePerformIO $ do _ -> L.GlobalDefinition $ f { L.functionAttributes = [] } stripDef def = def -linkDexrt :: Context -> Mod.Module -> IO () -linkDexrt ctx m = do +linkDexrt :: Mod.Module -> IO () +linkDexrt m = do + ctx <- Mod.moduleContext m dataLayout <- Mod.getDataLayout =<< Mod.readModule m targetTriple <- Mod.getTargetTriple =<< Mod.readModule m let dexrtTargetAST = dexrtAST { L.moduleDataLayout = dataLayout , L.moduleTargetTriple = targetTriple } - Mod.withModuleFromAST ctx dexrtTargetAST $ \dexrtm -> do + Mod.withModuleFromAST ctx dexrtTargetAST \dexrtm -> do Mod.linkModules m dexrtm runPasses [P.AlwaysInline True] Nothing m @@ -307,23 +325,21 @@ data LLVMKernel = LLVMKernel L.Module compileCUDAKernel :: Logger [Output] -> LLVMKernel -> IO CUDAKernel compileCUDAKernel logger (LLVMKernel ast) = do T.initializeAllTargets - withContext $ \ctx -> - Mod.withModuleFromAST ctx ast $ \m -> do - withGPUTargetMachine (pack arch) $ \tm -> do - linkLibdevice ctx m - linkDexrt ctx m - internalize ["kernel"] m - optimizeModule ctx logger tm m + withContext \ctx -> + Mod.withModuleFromAST ctx ast \m -> do + withGPUTargetMachine (pack arch) \tm -> do + linkLibdevice m + standardCompilationPipeline logger ["kernel"] tm m ptx <- Mod.moduleTargetAssembly tm m usePTXAS <- maybe False (=="1") <$> lookupEnv "DEX_USE_PTXAS" if usePTXAS then do - withSystemTempFile "kernel.ptx" $ \ptxPath ptxH -> do + withSystemTempFile "kernel.ptx" \ptxPath ptxH -> do B.hPut ptxH ptx hClose ptxH - withSystemTempFile "kernel.sass" $ \sassPath sassH -> do + withSystemTempFile "kernel.sass" \sassPath sassH -> do let cmd = proc ptxasPath [ptxPath, "-o", sassPath, "-arch=" ++ arch, "-O3"] - withCreateProcess cmd $ \_ _ _ ptxas -> do + withCreateProcess cmd \_ _ _ ptxas -> do code <- waitForProcess ptxas case code of ExitSuccess -> return () @@ -338,7 +354,7 @@ compileCUDAKernel logger (LLVMKernel ast) = do {-# NOINLINE libdevice #-} libdevice :: L.Module libdevice = unsafePerformIO $ do - withContext $ \ctx -> do + withContext \ctx -> do let libdeviceDirectory = "/usr/local/cuda/nvvm/libdevice" [libdeviceFileName] <- listDirectory libdeviceDirectory let libdevicePath = libdeviceDirectory ++ "/" ++ libdeviceFileName @@ -348,10 +364,11 @@ libdevice = unsafePerformIO $ do return $ m { L.moduleDataLayout = Just ptxDataLayout , L.moduleTargetTriple = Just ptxTargetTriple } -linkLibdevice :: Context -> Mod.Module -> IO () -linkLibdevice ctx m = - Mod.withModuleFromAST ctx zeroNVVMReflect $ \reflectm -> - Mod.withModuleFromAST ctx libdevice $ \ldm -> do +linkLibdevice :: Mod.Module -> IO () +linkLibdevice m = do + ctx <- Mod.moduleContext m + Mod.withModuleFromAST ctx zeroNVVMReflect \reflectm -> + Mod.withModuleFromAST ctx libdevice \ldm -> do Mod.linkModules m ldm Mod.linkModules m reflectm runPasses [P.AlwaysInline True] Nothing m @@ -394,3 +411,32 @@ ptxDataLayout = (L.defaultDataLayout L.LittleEndian) [ ((L.VectorAlign, w), L.AlignmentInfo w w) | w <- [16, 32, 64, 128] ] , L.nativeSizes = Just $ S.fromList [16, 32, 64] } + +-- ==== unix pipe utilities === + +type IOExcept a = Either SomeException a + +withPipe :: (Handle -> IO a) -> (Handle -> IO b) -> IO (IOExcept a, IOExcept b) +withPipe readAction writeAction = do + (readHandle, writeHandle) <- createPipe + waitForReader <- forkWithResult $ readAction readHandle + waitForWriter <- forkWithResult $ writeAction writeHandle + y <- waitForWriter `finally` hClose writeHandle + x <- waitForReader `finally` hClose readHandle + return (x, y) + +forkWithResult :: IO a -> IO (IO (IOExcept a)) +forkWithResult action = do + resultMVar <- newEmptyMVar + void $ forkIO $ catch (do result <- action + putMVar resultMVar $ Right result) + (\e -> putMVar resultMVar $ Left (e::SomeException)) + return $ takeMVar resultMVar + +readStream :: Handle -> (String -> IO ()) -> IO () +readStream h action = go + where + go :: IO () + go = do + eof <- hIsEOF h + unless eof $ hGetLine h >>= action >> go diff --git a/src/lib/LiveOutput.hs b/src/lib/LiveOutput.hs index 640fcdb8b..83a302b43 100644 --- a/src/lib/LiveOutput.hs +++ b/src/lib/LiveOutput.hs @@ -92,7 +92,7 @@ sourceBlockToDag block = do -- TODO: Stop forcing dependencies on all preceding blocks. This will require -- an improvement of the analysis above, such that all blocks depend on those -- that contain interface instance definitions. - extend $ (foldMap ((@>n) . Bind) $ envAsVars $ boundUVars block, [n]) + extend (foldMap ((@>n) . Bind) $ envAsVars $ boundUVars block, [n]) case sbContents block of IncludeSourceFile _ -> extend $ asSnd [n] _ -> return () @@ -145,7 +145,7 @@ oneSourceBlock k b = RFragment mempty (M.singleton k b) mempty serveResults :: StreamingBody -> Application serveResults results request respond = do - putStrLn (show $ pathInfo request) + print (pathInfo request) case pathInfo request of [] -> respondWith "static/index.html" "text/html" ["style.css"] -> respondWith "static/style.css" "text/css" @@ -204,7 +204,7 @@ displayResultsTerm reqChan = c <- myChan send reqChan $ subChan Left c void $ spawn Trap $ monitorKeyboard $ subChan Right c - forever $ termDisplayLoop + forever termDisplayLoop termDisplayLoop :: TermDisplayM () termDisplayLoop = do @@ -232,7 +232,7 @@ cropTrailingLines n s = unlines $ reverse $ drop n $ reverse $ lines s renderResults :: RFragment -> Maybe String renderResults (RFragment NotSet _ _) = Nothing renderResults (RFragment (Set ids) blocks results) = - liftM fold $ flip mapM ids $ \i -> do + liftM fold $ forM ids $ \i -> do b <- M.lookup i blocks r <- M.lookup i results return $ printLitBlock True b r @@ -241,7 +241,7 @@ monitorKeyboard :: PChan KeyboardCommand -> Actor () () monitorKeyboard chan = do liftIO $ hSetBuffering stdin NoBuffering forever $ do - c <- liftIO $ getChar + c <- liftIO getChar case c of 'k' -> send chan ScrollUp 'j' -> send chan ScrollDown @@ -274,10 +274,15 @@ onmod fname action = do -- === DAG utils === +-- | A pair of an @a@ and a list of neighbor node ids. type Node a = (a, [NodeId]) -data Dag a = Dag (M.Map NodeId (Node a)) (M.Map (a, [NodeId]) NodeId) --- returns the addition only, not the new DAG +-- | A directed acyclic graph, represented as a bidirectional map from node ids +-- to nodes. +data Dag a = Dag (M.Map NodeId (Node a)) (M.Map (Node a) NodeId) + +-- | Adds a node to a DAG, if it does not already exist. +-- Returns the added node id and a DAG representing the added node. addToDag :: Ord a => Dag a -> Node a -> (NodeId, Dag a) addToDag (Dag _ m) node = case M.lookup node m of diff --git a/src/lib/Logging.hs b/src/lib/Logging.hs index 1ee82ccdc..37d40fd8a 100644 --- a/src/lib/Logging.hs +++ b/src/lib/Logging.hs @@ -20,7 +20,7 @@ data Logger l = Logger (MVar l) (Maybe Handle) runLogger :: (Monoid l, MonadIO m) => Maybe FilePath -> (Logger l -> m a) -> m (a, l) runLogger maybePath m = do log <- liftIO $ newMVar mempty - logFile <- liftIO $ forM maybePath $ \path -> openFile path WriteMode + logFile <- liftIO $ forM maybePath \path -> openFile path WriteMode ans <- m $ Logger log logFile logged <- liftIO $ readMVar log return (ans, logged) @@ -30,10 +30,10 @@ execLogger maybePath m = fst <$> runLogger maybePath m logThis :: (Pretty l, Monoid l, MonadIO m) => Logger l -> l -> m () logThis (Logger log maybeLogHandle) x = liftIO $ do - forM_ maybeLogHandle $ \h -> do + forM_ maybeLogHandle \h -> do hPutStrLn h $ pprint x hFlush h - modifyMVar_ log $ \cur -> return (cur <> x) + modifyMVar_ log \cur -> return (cur <> x) readLog :: MonadIO m => Logger l -> m l readLog (Logger log _) = liftIO $ readMVar log diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index dd88dac68..44022ba70 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -95,13 +95,22 @@ inlineModule m = transformModuleAsBlock inlineBlock (computeInlineHints m) inlineTraverseDecl :: Decl -> InlineM SubstEnv inlineTraverseDecl decl = case decl of - -- This is not a super safe condition for inlining, because it might still duplicate work - -- unexpectedly (consider an `arr` that's only used as `for i. 2.0 .* arr`). Still, this is not - -- the way arrays are usually used, so it should be good enough for now. In the future we should - -- strengthen this check to better ensure that each element of the array is used at most once. Let _ b@(BindWithHint CanInline _) expr@(Hof (For _ body)) | isPure expr -> do ~(LamVal ib block) <- traverseAtom inlineTraversalDef body return $ b @> TabVal ib block + -- If `f` turns out to be an inlined table lambda, we expand its block and + -- call ourselves recursively on the block's result expression. This makes + -- it possible for us to e.g. discover that the result is a `for` loop, and + -- match the case above, to continue the inlining process. + Let letAnn letBinder (App f' x') -> do + f <- traverseAtom inlineTraversalDef f' + x <- traverseAtom inlineTraversalDef x' + case f of + TabVal b (Block body result) -> do + dropSub $ extendR (b@>x) $ do + blockEnv <- traverseDeclsOpen substTraversalDef body + extendR blockEnv $ inlineTraverseDecl $ Let letAnn letBinder result + _ -> (letBinder@>) <$> emitTo (binderNameHint letBinder) letAnn (App f x) _ -> traverseDecl inlineTraversalDef decl -- TODO: This is a bit overeager. We should count under how many loops are we. @@ -113,15 +122,12 @@ inlineTraverseExpr expr = case expr of Hof (For d body) -> do newBody <- traverseAtom inlineTraversalDef body case newBody of - -- Trivial bodies -- XXX: The trivial body might be a table lambda, and those could technically -- get quite expensive. But I think this should never be the case in practice. - -- XXX: This doesn't always have to end up being beneficial. If the result is - -- significantly smaller than the intermediates it refers to, then this - -- optimization will waste a bunch of memory by keeping the large intermediates alive. + -- Trivial bodies LamVal ib block@(Block Empty (Atom _)) -> return $ Atom $ TabVal ib block -- Pure broadcasts - LamVal ib@(Ignore _) block | blockEffs block == NoEffects -> do + LamVal ib@(Ignore _) block | blockEffs block == Pure -> do result <- dropSub $ evalBlockE inlineTraversalDef block Atom <$> buildLam ib TabArrow (\_ -> return $ result) _ -> return $ Hof $ For d newBody diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 6022fd00d..c09da8a2d 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -118,10 +118,6 @@ instance PrettyPrec BaseType where Vector sb -> atPrec ArgPrec $ "<" <> p vectorWidth <+> "x" <+> p sb <> ">" PtrType ty -> atPrec AppPrec $ "Ptr" <+> p ty -instance Pretty PtrOrigin where - pretty AllocatedPtr = "a" - pretty DerivedPtr = "d" - instance Pretty AddressSpace where pretty Stack = "stack" pretty (Heap d) = p (show d) @@ -138,12 +134,16 @@ instance PrettyPrec ScalarBaseType where printDouble :: Double -> Doc ann printDouble x = p (double2Float x) +printFloat :: Float -> Doc ann +printFloat x = p $ reverse $ dropWhile (=='0') $ reverse $ + showFFloat (Just 6) x "" + instance Pretty LitVal where pretty = prettyFromPrettyPrec instance PrettyPrec LitVal where prettyPrec (Int64Lit x) = atPrec ArgPrec $ p x prettyPrec (Int32Lit x) = atPrec ArgPrec $ p x prettyPrec (Float64Lit x) = atPrec ArgPrec $ printDouble x - prettyPrec (Float32Lit x) = atPrec ArgPrec $ p x + prettyPrec (Float32Lit x) = atPrec ArgPrec $ printFloat x prettyPrec (Word8Lit x) = atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x prettyPrec (PtrLit ty x) = atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) prettyPrec (VecLit l) = atPrec ArgPrec $ encloseSep "<" ">" ", " $ fmap p l @@ -478,9 +478,7 @@ instance Pretty ImpFunction where instance Pretty ImpInstr where pretty (IFor a i n block) = forStr (RegularFor a) <+> p i <+> "<" <+> p n <> nest 4 (hardline <> p block) - pretty (IWhile cond body) = "while" <+> - nest 2 (p cond) <+> "do" <> - nest 4 (hardline <> p body) + pretty (IWhile body) = "while" <+> nest 2 (p body) pretty (ICond predicate cons alt) = "if" <+> p predicate <+> "then" <> nest 2 (hardline <> p cons) <> hardline <> "else" <> nest 2 (hardline <> p alt) @@ -521,8 +519,8 @@ instance Pretty Output where benchName <> hardline <> "Compile time: " <> prettyDuration compileTime <> hardline <> "Run time: " <> prettyDuration runTime <+> - (case stats of Just runs -> "\t" <> parens ("based on" <+> p runs <+> "runs") - Nothing -> "") + (case stats of Just (runs, _) -> "\t" <> parens ("based on" <+> p runs <+> "runs") + Nothing -> "") where benchName = case name of "" -> "" _ -> "\n" <> p name pretty (PassInfo name s) = "===" <+> p name <+> "===" <> hardline <> p s @@ -592,7 +590,7 @@ instance PrettyPrec UExpr' where where kw = case dir of Fwd -> "for" Rev -> "rof" UPi binder arr ty -> atPrec LowestPrec $ - prettyUPiBinder binder <+> pretty arr <+> pLowest ty + prettyUBinder binder <+> pretty arr <+> pLowest ty UDecl decl body -> atPrec LowestPrec $ align $ p decl <> hardline <> pLowest body UHole -> atPrec ArgPrec "_" @@ -631,6 +629,13 @@ instance Pretty UDecl where align $ prettyUBinder b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) pretty (UData tyCon dataCons) = "data" <+> p tyCon <+> "where" <> nest 2 (hardline <> prettyLines dataCons) + pretty (UInterface cs def methods) = + "interface" <+> p cs <+> p def <> hardline <> prettyLines methods + pretty (UInstance bs ty methods) = + "instance" <+> p bs <+> p ty <> hardline <> prettyLines methods + +instance Pretty UMethodDef where + pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs instance Pretty UConDef where pretty (UConDef con bs) = p con <+> spaced bs @@ -653,26 +658,25 @@ prettyUBinder (pat, ann) = p pat <> annDoc where Just ty -> ":" <> pApp ty Nothing -> mempty -prettyUPiBinder :: UPiPatAnn -> Doc ann -prettyUPiBinder (pat, ann) = patDoc <> p ann where - patDoc = case pat of - Just pat' -> pApp pat' <> ":" - Nothing -> mempty - spaced :: (Foldable f, Pretty a) => f a -> Doc ann spaced xs = hsep $ map p $ toList xs instance Pretty EffectRow where - pretty (EffectRow [] Nothing) = mempty + pretty Pure = mempty pretty (EffectRow effs tailVar) = - braces $ hsep (punctuate "," (fmap prettyEff effs)) <> tailStr + braces $ hsep (punctuate "," (map p (toList effs))) <> tailStr where - prettyEff (effName, region) = p effName <+> p region tailStr = case tailVar of Nothing -> mempty Just v -> "|" <> p v -instance Pretty EffectName where +instance Pretty Effect where + pretty eff = case eff of + RWSEffect rws h -> p rws <+> p h + ExceptionEffect -> "Except" + IOEffect -> "IO" + +instance Pretty RWS where pretty eff = case eff of Reader -> "Read" Writer -> "Accum" diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 4595d1d4c..e11842020 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -70,18 +70,18 @@ parallelTraverseExpr expr = case expr of -- TODO: functionEffs is an overapproximation of the effects that really appear inside refs <- gets activeAccs let allowedRegions = foldMap (\(varType -> RefTy (Var reg) _) -> reg @> ()) refs - bodyEffs <- substEmbedR $ functionEffs fbody - let onlyAllowedEffects = flip all bodyEffs $ \(eff, reg) -> eff == Writer && reg `isin` allowedRegions - case onlyAllowedEffects of + (EffectRow bodyEffs t) <- substEmbedR $ functionEffs fbody + let onlyAllowedEffects = all (parallelizableEffect allowedRegions) $ toList bodyEffs + case t == Nothing && onlyAllowedEffects of True -> do b' <- substEmbedR b liftM Atom $ runLoopM $ withLoopBinder b' $ buildParallelBlock $ asABlock body False -> nothingSpecial Hof (RunWriter (BinaryFunVal h b _ body)) -> do ~(RefTy _ accTy) <- traverseAtom substTraversalDef $ binderType b - liftM Atom $ emitRunWriter (binderNameHint b) accTy $ \ref@(Var refVar) -> do + liftM Atom $ emitRunWriter (binderNameHint b) accTy \ref@(Var refVar) -> do let RefTy h' _ = varType refVar - modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } extendR (h @> h' <> b @> ref) $ evalBlockE parallelTrav body -- TODO: Do some alias analysis. This is not fundamentally hard, but it is a little annoying. -- We would have to track not only the base references, but also all the aliases, along @@ -95,7 +95,13 @@ parallelTraverseExpr expr = case expr of where nothingSpecial = traverseExpr parallelTrav expr disallowRef ~(Var refVar) = - modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } + +parallelizableEffect :: Env () -> Effect -> Bool +parallelizableEffect allowedRegions effect = case effect of + RWSEffect Writer h | h `isin` allowedRegions -> True + -- TODO: we should be able to parallelize the exception effect too + _ -> False -- Precondition: This is never called with no binders in the loop env buildParallelBlock :: ABlock -> LoopM Atom @@ -197,7 +203,7 @@ emitLoops buildPureLoop (ABlock decls result) = do let buildBody pari = do is <- unpackConsList pari extendR (newEnv lbs is) $ do - ctxEnv <- flip traverseNames dapps $ \_ (arr, idx) -> + ctxEnv <- flip traverseNames dapps \_ (arr, idx) -> -- XXX: arr is namespaced in the new program foldM appTryReduce arr =<< substEmbedR idx extendR ctxEnv $ evalBlockE appReduceTraversalDef $ Block decls $ Atom result @@ -205,18 +211,18 @@ emitLoops buildPureLoop (ABlock decls result) = do True -> buildPureLoop (Bind $ "pari" :> iterTy) buildBody False -> do body <- do - buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow $ \gtid -> do - buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow $ \nthr -> do + buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow \gtid -> do + buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow \nthr -> do let threadRange = TC $ ParIndexRange iterTy gtid nthr let accTys = mkConsListTy $ fmap (derefType . varType) newRefs - emitRunWriter "refsList" accTys $ \localRefsList -> do + emitRunWriter "refsList" accTys \localRefsList -> do localRefs <- unpackRefConsList localRefsList - buildFor Fwd (Bind $ "tidx" :> threadRange) $ \tidx -> do + buildFor Fwd (Bind $ "tidx" :> threadRange) \tidx -> do pari <- emitOp $ Inject tidx extendR (newEnv oldRefNames localRefs) $ buildBody pari (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce iterTy body) updates <- unpackConsList updateList - forM_ (zip newRefs updates) $ \(ref, update) -> + forM_ (zip newRefs updates) \(ref, update) -> emitOp $ PrimEffect (Var ref) $ MTell update return ans where diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index c253cc25b..cb6574b8b 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -12,12 +12,13 @@ import Control.Monad import Control.Monad.Combinators.Expr import Control.Monad.Reader import Text.Megaparsec hiding (Label, State) -import Text.Megaparsec.Char hiding (space) +import Text.Megaparsec.Char hiding (space, eol) +import qualified Text.Megaparsec.Char as MC import Data.Char (isLower) import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) -import qualified Data.Map.Strict as M import Data.Void +import qualified Data.Set as S import Data.String (fromString) import qualified Text.Megaparsec.Char.Lexer as L import qualified Text.Megaparsec.Debug @@ -87,22 +88,22 @@ logLevel :: Parser LogLevel logLevel = do void $ try $ lexeme $ char '%' >> string "passes" passes <- many passName - void eol + eol case passes of - [] -> return $ LogAll + [] -> return LogAll _ -> return $ LogPasses passes logTime :: Parser LogLevel logTime = do void $ try $ lexeme $ char '%' >> string "time" - void eol + eol return PrintEvalTime logBench :: Parser LogLevel logBench = do void $ try $ lexeme $ char '%' >> string "bench" benchName <- stringLiteral - void eol + eol return $ PrintBench benchName passName :: Parser PassName @@ -115,52 +116,25 @@ sourceBlock' :: Parser SourceBlock' sourceBlock' = proseBlock <|> topLevelCommand - <|> fmap (declsToModule . (:[])) (topDecl <* eolf) - <|> fmap (declsToModule . (:[])) (interfaceInstance <* eolf) - <|> fmap declsToModule (interfaceDef <* eolf) - <|> fmap (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) + <|> liftM declToModule (topDecl <* eolf) + <|> liftM declToModule (instanceDef <* eolf) + <|> liftM declToModule (interfaceDef <* eolf) + <|> liftM (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) <|> hidden (some eol >> return EmptyLines) <|> hidden (sc >> eol >> return CommentLine) - where declsToModule = RunModule . UModule . toNest + where + declsToModule = RunModule . UModule . toNest + declToModule = declsToModule . (:[]) proseBlock :: Parser SourceBlock' proseBlock = label "prose block" $ char '\'' >> fmap (ProseBlock . fst) (withSource consumeTillBreak) -loadData :: Parser SourceBlock' -loadData = do - symbol "load" - fmt <- dataFormat - s <- stringLiteral - symbol "as" - b <- patAnn - void eol - return $ LoadData b fmt s - topLevelCommand :: Parser SourceBlock' topLevelCommand = - (liftM IncludeSourceFile includeSourceFile) - <|> loadData - <|> dumpData + liftM IncludeSourceFile includeSourceFile <|> explicitCommand "top-level command" -dataFormat :: Parser DataFormat -dataFormat = do - s <- nameString - case s of - "dxo" -> return DexObject - "dxbo" -> return DexBinaryObject - _ -> fail $ show s ++ " not a recognized data format (one of dxo|dxbo)" - -dumpData :: Parser SourceBlock' -dumpData = do - symbol "dump" - fmt <- dataFormat - s <- stringLiteral - e <- blockOrExpr - void eol - return $ Command (Dump fmt s) (exprAsModule e) - explicitCommand :: Parser SourceBlock' explicitCommand = do cmdName <- char ':' >> nameString @@ -179,7 +153,7 @@ exprAsModule :: UExpr -> (Name, UModule) exprAsModule e = (asGlobal v, UModule (toNest [d])) where v = mkName "_ans_" - d = ULet PlainLet (WithSrc (srcPos e) (UPatBinder (Bind (v:>()))), Nothing) e + d = ULet PlainLet (WithSrc (srcPos e) (nameToPat v), Nothing) e -- === uexpr === @@ -196,6 +170,7 @@ leafExpr = parens (mayPair $ makeExprParser leafExpr ops) <|> uLit <|> uPiType <|> uLamExpr + <|> uViewExpr <|> uForExpr <|> caseExpr <|> ifExpr @@ -203,6 +178,7 @@ leafExpr = parens (mayPair $ makeExprParser leafExpr ops) <|> unitCon <|> (uLabeledExprs `fallBackTo` uVariantExpr) <|> uIsoSugar + <|> uDoSugar "expression" containedExpr :: Parser UExpr @@ -217,8 +193,9 @@ uType = expr uString :: Lexer UExpr uString = do (s, pos) <- withPos $ strLit - let cs = map (WithSrc (Just pos) . charExpr) s - return $ WithSrc (Just pos) $ UTabCon cs + let addSrc = WithSrc (Just pos) + let cs = map (addSrc . charExpr) s + return $ mkApp (addSrc "toList") $ addSrc $ UTabCon cs uLit :: Parser UExpr uLit = withSrc $ uLitParser @@ -231,8 +208,7 @@ charExpr :: Char -> UExpr' charExpr c = UPrimExpr $ ConExpr $ Lit $ Word8Lit $ fromIntegral $ fromEnum c uVarOcc :: Parser UExpr -uVarOcc = withSrc $ try $ (UVar . (:>())) <$> (occName <* notFollowedBy (sym ":")) - where occName = upperName <|> lowerName <|> symName +uVarOcc = withSrc $ try $ (UVar . (:>())) <$> (anyName <* notFollowedBy (sym ":")) uHole :: Parser UExpr uHole = withSrc $ underscore $> UHole @@ -247,9 +223,9 @@ topDecl = dataDef <|> topLet topLet :: Parser UDecl topLet = do - lAnn <- (char '@' >> letAnnStr <* (void eol <|> sc)) <|> return PlainLet - ~(ULet _ (p, ann) rhs, pos) <- withPos decl - let (ann', rhs') = addImplicitImplicitArgs pos ann rhs + lAnn <- (char '@' >> letAnnStr <* (eol <|> sc)) <|> return PlainLet + ~(ULet _ (p, ann) rhs) <- decl + let (ann', rhs') = addImplicitImplicitArgs ann rhs return $ ULet lAnn (p, ann') rhs' -- Given a type signature, find all "implicit implicit args": lower-case @@ -275,7 +251,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ -- recursive steps UVar _ -> mempty UPi (p, ann) _ ty -> - findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) + foldMap findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) UApp _ f x -> findVarsInAppLHS f <> findVarsInAppLHS x ULam (p, ann) _ x -> foldMap findVarsInAppLHS ann <> (findVarsInAppLHS x `envDiff` boundUVars p) @@ -283,7 +259,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UFor _ _ _ -> error "Unexpected for in type annotation" UHole -> mempty UTypeAnn v ty -> findVarsInAppLHS v <> findVarsInAppLHS ty - UTabCon _ -> error "Unexpected table in type annotation" + UTabCon _ -> error "Unexpected table constructor in type annotation" UIndexRange low high -> foldMap findVarsInAppLHS low <> foldMap findVarsInAppLHS high UPrimExpr prim -> foldMap findVarsInAppLHS prim @@ -298,77 +274,33 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UIntLit _ -> mempty UFloatLit _ -> mempty -addImplicitImplicitArgs :: SrcPos -> Maybe UType -> UExpr -> (Maybe UType, UExpr) -addImplicitImplicitArgs _ Nothing e = (Nothing, e) -addImplicitImplicitArgs sourcePos (Just typ) ex = - let (ty', e') = foldr (addImplicitArg sourcePos) (typ, ex) implicitVars +addImplicitImplicitArgs :: Maybe UType -> UExpr -> (Maybe UType, UExpr) +addImplicitImplicitArgs Nothing e = (Nothing, e) +addImplicitImplicitArgs (Just typ) ex = + let (ty', e') = foldr addImplicitArg (typ, ex) implicitVars in (Just ty', e') where implicitVars = findImplicitImplicitArgNames typ - addImplicitArg :: SrcPos -> Name -> (UType, UExpr) -> (UType, UExpr) - addImplicitArg pos v (ty, e) = - ( WithSrc (Just pos) $ UPi (Just uPat, uTyKind) ImplicitArrow ty - , WithSrc (Just pos) $ ULam (uPat, Just uTyKind) ImplicitArrow e) - where - uPat = WithSrc (Just pos) $ UPatBinder $ Bind $ v:>() - k = if v == mkName "eff" then EffectRowKind else TypeKind - uTyKind = WithSrc (Just pos) $ UPrimExpr $ TCExpr k + addImplicitArg :: Name -> (UType, UExpr) -> (UType, UExpr) + addImplicitArg v (ty, e) = + ( ns $ UPi (uPat, Nothing) ImplicitArrow ty + , ns $ ULam (uPat, Nothing) ImplicitArrow e) + where uPat = ns $ nameToPat v -interfaceDef :: Parser [UDecl] +superclassConstraints :: Parser [UType] +superclassConstraints = optionalMonoid $ brackets $ uType `sepBy` sym "," + +interfaceDef :: Parser UDecl interfaceDef = do keyWord InterfaceKW - (tyCon, pos) <- withPos tyConDef - keyWord WhereKW - recordFieldsWithSrc <- withSrc $ interfaceRecordFields ":" - let (UConDef interfaceName uAnnBinderNest) = tyCon - record = URecordTy . NoExt <$> recordFieldsWithSrc - consName = mkInterfaceConsName interfaceName - varNames = fmap (\(Bind v) -> varName v) uAnnBinderNest - (WithSrc _ recordFields) = recordFieldsWithSrc - funDefs = mkFunDefs (pos, varNames, interfaceName) recordFields - return $ UData tyCon [UConDef consName (toNest [Ignore record])] : funDefs - where - -- From an interface - -- interface I a:Type b:Type where - -- f : a -> b - -- mkFunDefs generates the equivalent of the following function definition: - -- def f (instance# : I a b) ?=> : a -> b = - -- (I# {f=f,...}) = instance# - -- f - -- where I# is an automatically generated constructor of I. - mkFunDefs - :: (SrcPos, Nest Name, Name) -> LabeledItems UExpr -> [UDecl] - mkFunDefs meta (LabeledItems items) = - fmap (\(name, ty :| []) -> mkOneFunDef meta (name, ty)) $ M.toList items - mkOneFunDef :: (SrcPos, Nest Name, Name) -> (Label, UExpr) -> UDecl - mkOneFunDef (pos, typeVarNames, interfaceName) (fLabel, fType) = - ULet PlainLet (p, ann') rhs' - where - uAnnPat = ( Just $ WithSrc (Just pos) $ UPatBinder $ Bind $ instanceName :> () - , foldl mkUApp (var interfaceName) typeVarNames) - p = patb fLabel - ann = Just $ ns $ UPi uAnnPat ClassArrow fType - - mkUApp func typeVarName = - ns $ UApp (PlainArrow ()) func (var typeVarName) - recordStr = "recordVar" - recordPat = ns $ UPatRecord $ Ext (labeledSingleton fLabel (patb - fLabel)) $ Just (ns (UPatBinder (Ignore ()))) - conPat = ns $ UPatCon (mkInterfaceConsName interfaceName) - $ toNest [patb recordStr] - - let1 = ULet PlainLet (conPat, Nothing) $ var instanceName - let2 = ULet PlainLet (recordPat, Nothing) $ var $ mkName recordStr - body = ns $ UDecl let1 (ns $ UDecl let2 (var (mkName fLabel))) - rhs = ns $ ULam (patb instanceStr, Nothing) ClassArrow body - (ann', rhs') = addImplicitImplicitArgs pos ann rhs - - ns = WithSrc Nothing - patb s = ns $ UPatBinder $ Bind $ mkName s :> () - instanceStr = mkNoShadowingStr "instance" - instanceName = mkName instanceStr - var name = ns $ UVar $ name :> () + superclasses <- superclassConstraints + tyCon <- tyConDef + methods <- onePerLine $ do + v <- anyName + ty <- annot uType + return $ Bind $ v:>ty + return $ UInterface superclasses tyCon methods dataDef :: Parser UDecl dataDef = do @@ -378,9 +310,15 @@ dataDef = do dataCons <- onePerLine dataConDef return $ UData tyCon dataCons --- TODO: default to `Type` if unannoted tyConDef :: Parser UConDef -tyConDef = UConDef <$> (upperName <|> symName) <*> manyNested namedBinder +tyConDef = do + con <- upperName <|> symName + bs <- manyNested $ label "type constructor parameter" $ do + v <- lowerName + ty <- annot containedExpr <|> return tyKind + return $ Bind $ v :> ty + return $ UConDef con bs + where tyKind = ns $ UPrimExpr $ TCExpr TypeKind -- TODO: dependent types dataConDef :: Parser UConDef @@ -395,52 +333,31 @@ decl = do rhs <- sym "=" >> blockOrExpr return $ lhs rhs -interfaceInstance :: Parser UDecl -interfaceInstance = do +instanceDef :: Parser UDecl +instanceDef = do keyWord InstanceKW - (p, pos) <- withPos letPat - ann <- annot uType - case mkConstructorNameVar ann of - Left err -> fail err - Right constructorNameVar -> do - keyWord WhereKW - record <- withSrc $ (URecord . NoExt) <$> interfaceRecordFields "=" - let constructorCall = constructorNameVar `mkApp` record - (ann', rhs') = addImplicitImplicitArgs pos (Just ann) constructorCall - return $ ULet InstanceLet (p, ann') rhs' + explicitArgs <- many defArg + constraints <- classConstraints + classTy <- uType + let implicitArgs = findImplicitImplicitArgNames $ + buildPiType explicitArgs Pure $ + foldr addClassConstraint classTy constraints + let argBinders = + [((ns (nameToPat v), Nothing), ImplicitArrow) | v <- implicitArgs] ++ + explicitArgs ++ + [((UnderscoreUPat, Just c) , ClassArrow ) | c <- constraints] + methods <- onePerLine instanceMethod + return $ UInstance (toNest argBinders) classTy methods where - -- Here, we are traversing the type annotation to retrieve the name of - -- the interface and generate its corresponding constructor. A valid type - -- annotation for an instance is composed of: - -- 1) implicit/class arguments - -- 2) a function whose name is the name of the interface applied to 0 or - -- more arguments - mkConstructorNameVar ann = - stripArrows ann >>= stripAppliedArgs >>= buildConstructor - - stripArrows (WithSrc _ (UPi _ arr typ)) - | arr `elem` [ClassArrow, ImplicitArrow] = stripArrows typ - | otherwise = Left ("Met invalid arrow '" ++ pprint arr ++ "' in type " ++ - "annotation of instance. Only class arrows and " ++ - "implicit arrows are allowed.") - stripArrows ann = Right ann - - stripAppliedArgs ann - | (WithSrc _ (UApp _ func _)) <- ann = stripAppliedArgs func - | otherwise = Right ann - - buildConstructor (WithSrc _ (UVar v)) = - Right $ (var . nameToStr . mkInterfaceConsName . varName) v - buildConstructor _ = Left ("Could not extract interface name from type " ++ - "annotation.") - var s = noSrc $ UVar $ mkName s :> () - -interfaceRecordFields :: String -> Parser (LabeledItems UExpr) -interfaceRecordFields bindwith = - fuse <$> onePerLine (do l <- fieldLabel - e <- symbol bindwith *> expr - return $ labeledSingleton l e) - where fuse = foldr (<>) NoLabeledItems + addClassConstraint :: UType -> UType -> UType + addClassConstraint c ty = ns $ UPi (UnderscoreUPat, Just c) ClassArrow ty + +instanceMethod :: Parser UMethodDef +instanceMethod = do + v <- anyName + sym "=" + rhs <- blockOrExpr + return $ UMethodDef (v:>()) rhs simpleLet :: Parser (UExpr -> UDecl) simpleLet = label "let binding" $ do @@ -449,36 +366,41 @@ simpleLet = label "let binding" $ do return $ ULet PlainLet (p, ann) letPat :: Parser UPat -letPat = nameAsPat $ upperName <|> lowerName <|> symName +letPat = withSrc $ nameToPat <$> anyName funDefLet :: Parser (UExpr -> UDecl) funDefLet = label "function definition" $ mayBreak $ do keyWord DefKW v <- letPat - bs <- many arg + cs <- classConstraints + argBinders <- many defArg (eff, ty) <- label "result type annotation" $ annot effectiveType - when (null bs && eff /= Pure) $ fail "Nullary def can't have effects" + when (null argBinders && eff /= Pure) $ fail "Nullary def can't have effects" + let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) - let lamBinders = flip map bs $ \(p,_, arr) -> ((p,Nothing), arr) - return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) + let lamBinders = flip map bs \((p,_), arr) -> ((p,Nothing), arr) + return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where - arg :: Parser (UPat, UType, UArrow) - arg = label "def arg" $ do - (p, ty) <-parens ((,) <$> pat <*> annot uType) - arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, ty, arr) + classAsBinder :: UType -> UPatAnnArrow + classAsBinder ty = ((UnderscoreUPat, Just ty), ClassArrow) + +defArg :: Parser UPatAnnArrow +defArg = label "def arg" $ do + (p, ty) <-parens ((,) <$> pat <*> annot uType) + arr <- arrow (return ()) <|> return (PlainArrow ()) + return ((p, Just ty), arr) -nameAsPat :: Parser Name -> Parser UPat -nameAsPat p = withSrc $ (UPatBinder . Bind . (:>())) <$> p +classConstraints :: Parser [UType] +classConstraints = label "class constraints" $ + optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," -buildPiType :: [(UPat, UType, UArrow)] -> EffectRow -> UType -> UType +buildPiType :: [UPatAnnArrow] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty buildPiType [] _ _ = error "shouldn't be possible" -buildPiType ((p, patTy, arr):bs) eff resTy = WithSrc pos $ case bs of - [] -> UPi (Just p, patTy) (fmap (const eff ) arr) resTy - _ -> UPi (Just p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy - where WithSrc pos _ = patTy +buildPiType (((p, patTy), arr):bs) eff resTy = ns case bs of + [] -> UPi (p, patTy) (fmap (const eff ) arr) resTy + _ -> UPi (p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy effectiveType :: Parser (EffectRow, UType) effectiveType = (,) <$> effects <*> uType @@ -487,15 +409,20 @@ effects :: Parser EffectRow effects = braces someEffects <|> return Pure where someEffects = do - effs <- liftM2 (,) effectName lowerName `sepBy` sym "," + effs <- effect `sepBy` sym "," v <- optional $ symbol "|" >> lowerName - return $ EffectRow effs v + return $ EffectRow (S.fromList effs) v -effectName :: Parser EffectName -effectName = (keyWord WriteKW $> Writer) - <|> (keyWord ReadKW $> Reader) - <|> (keyWord StateKW $> State) - "effect name (Accum|Read|State)" +effect :: Parser Effect +effect = (RWSEffect <$> rwsName <*> anyCaseName) + <|> (keyWord ExceptKW $> ExceptionEffect) + <|> (keyWord IOKW $> IOEffect) + "effect (Accum h | Read h | State h | Except | IO)" + +rwsName :: Parser RWS +rwsName = (keyWord WriteKW $> Writer) + <|> (keyWord ReadKW $> Reader) + <|> (keyWord StateKW $> State) uLamExpr :: Parser UExpr uLamExpr = do @@ -524,6 +451,14 @@ buildFor pos dir binders body = case binders of [] -> body b:bs -> WithSrc (Just pos) $ UFor dir b $ buildFor pos dir bs body +uViewExpr :: Parser UExpr +uViewExpr = do + keyWord ViewKW + bs <- some patAnn + argTerm + body <- blockOrExpr + return $ buildLam (zip bs (repeat TabArrow)) body + uForExpr :: Parser UExpr uForExpr = do ((dir, trailingUnit), pos) <- withPos $ @@ -533,23 +468,24 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ noSrc $ UDecl (ULet PlainLet underscorePat e) unit + then return $ ns $ UDecl (ULet PlainLet (UnderscoreUPat, Nothing) e) $ + ns unitExpr else return e - where - underscorePat :: UPatAnn - underscorePat = (noSrc $ UPatBinder $ Ignore (), Nothing) - unit :: UExpr - unit = noSrc $ UPrimExpr $ ConExpr UnitCon +nameToPat :: Name -> UPat' +nameToPat v = UPatBinder (Bind (v:>())) -noSrc :: a -> WithSrc a -noSrc = WithSrc Nothing +unitExpr :: UExpr' +unitExpr = UPrimExpr $ ConExpr UnitCon + +ns :: a -> WithSrc a +ns = WithSrc Nothing blockOrExpr :: Parser UExpr blockOrExpr = block <|> expr unitCon :: Parser UExpr -unitCon = withSrc $ symbol "()" $> (UPrimExpr $ ConExpr $ UnitCon) +unitCon = withSrc $ symbol "()" $> unitExpr uTabCon :: Parser UExpr uTabCon = withSrc $ do @@ -571,7 +507,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (WithSrc (Just pos) (UPatBinder (Ignore ())), Nothing) e + where d = ULet PlainLet (UnderscoreUPat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement @@ -585,16 +521,17 @@ uPiType = withSrc $ UPi <$> piBinderPat <*> arrow effects <*> uType b <- annBinder return $ case b of Bind (n:>a@(WithSrc pos _)) -> - (Just $ WithSrc pos $ UPatBinder $ Bind $ n:>(), a) - Ignore a -> (Nothing, a) + (WithSrc pos $ nameToPat n, Just a) + Ignore a -> (UnderscoreUPat, Just a) annBinder :: Parser UAnnBinder annBinder = try $ namedBinder <|> anonBinder namedBinder :: Parser UAnnBinder -namedBinder = label "named annoted binder" $ lowerName - >>= \v -> sym ":" >> containedExpr - >>= \ty -> return $ Bind (v:>ty) +namedBinder = label "named annoted binder" $ do + v <- lowerName + ty <- annot containedExpr + return $ Bind (v:>ty) anonBinder :: Parser UAnnBinder anonBinder = @@ -621,16 +558,34 @@ ifExpr :: Parser UExpr ifExpr = withSrc $ do keyWord IfKW e <- expr - withIndent $ mayNotBreak $ do - alt1 <- keyWord ThenKW >> blockOrExpr - nextLine - alt2 <- keyWord ElseKW >> blockOrExpr - return $ UCase e - [ UAlt (globalEnumPat "True") alt1 + (alt1, maybeAlt2) <- oneLineThenElse <|> blockThenElse + let alt2 = case maybeAlt2 of + Nothing -> ns unitExpr + Just alt -> alt + return $ UCase e + [ UAlt (globalEnumPat "True" ) alt1 , UAlt (globalEnumPat "False") alt2] +oneLineThenElse :: Parser (UExpr, Maybe UExpr) +oneLineThenElse = do + keyWord ThenKW + alt1 <- eitherP block expr + case alt1 of + Left e -> return (e, Nothing) + Right e -> do + alt2 <- optional $ keyWord ElseKW >> blockOrExpr + return (e, alt2) + +blockThenElse :: Parser (UExpr, Maybe UExpr) +blockThenElse = withIndent $ mayNotBreak $ do + alt1 <- keyWord ThenKW >> blockOrExpr + alt2 <- optional $ do + try $ nextLine >> keyWord ElseKW + blockOrExpr + return (alt1, alt2) + globalEnumPat :: Tag -> UPat -globalEnumPat s = noSrc $ UPatCon (GlobalName s) Empty +globalEnumPat s = ns $ UPatCon (GlobalName s) Empty onePerLine :: Parser a -> Parser [a] onePerLine p = liftM (:[]) p @@ -650,8 +605,8 @@ leafPat = <|> (variantPat `fallBackTo` recordPat) <|> brackets (UPatTable <$> leafPat `sepBy` sym ",") ) - where pun pos l = WithSrc (Just pos) $ UPatBinder $ Bind (mkName l:>()) - def pos = WithSrc (Just pos) $ UPatBinder $ Ignore () + where pun pos l = WithSrc (Just pos) $ nameToPat $ mkName l + def pos = WithSrc (Just pos) $ UPatBinder (Ignore ()) variantPat = parseVariant leafPat UPatVariant UPatVariantLift recordPat = UPatRecord <$> parseLabeledItems "," "=" leafPat (Just pun) (Just def) @@ -712,15 +667,20 @@ uLabeledExprs = withSrc $ varPun :: SrcPos -> Label -> UExpr varPun pos str = WithSrc (Just pos) $ UVar (mkName str :> ()) +uDoSugar :: Parser UExpr +uDoSugar = withSrc $ do + keyWord DoKW + body <- blockOrExpr + return $ ULam (WithSrc Nothing UPatUnit, Nothing) (PlainArrow ()) body + uIsoSugar :: Parser UExpr uIsoSugar = withSrc (char '#' *> options) where options = (recordFieldIso <$> fieldLabel) <|> char '?' *> (variantFieldIso <$> fieldLabel) <|> char '&' *> (recordZipIso <$> fieldLabel) <|> char '|' *> (variantZipIso <$> fieldLabel) - ns = WithSrc Nothing var s = ns $ UVar $ mkName s :> () - patb s = ns $ UPatBinder $ Bind $ mkName s :> () + patb s = ns $ nameToPat $ mkName s plain = PlainArrow () lam p b = ns $ ULam (p, Nothing) plain b recordFieldIso field = @@ -924,8 +884,8 @@ mkSymName s = mkName $ "(" <> s <> ")" prefixNegOp :: Operator Parser UExpr prefixNegOp = Prefix $ label "negation" $ do ((), pos) <- withPos $ sym "-" - let f = WithSrc (Just pos) $ UVar $ mkName "neg" :> () - return $ \case + let f = WithSrc (Just pos) "neg" + return \case -- Special case: negate literals directly WithSrc litpos (IntLitExpr i) -> WithSrc (joinPos (Just pos) litpos) (IntLitExpr (-i)) @@ -947,10 +907,10 @@ infixArrow :: Parser (UType -> UType -> UType) infixArrow = do notFollowedBy (sym "=>") -- table arrows have special fixity (arr, pos) <- withPos $ arrow effects - return $ \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b + return \a b -> WithSrc (Just pos) $ UPi (UnderscoreUPat, Just a) arr b mkArrow :: Arrow -> UExpr -> UExpr -> UExpr -mkArrow arr a b = joinSrc a b $ UPi (Nothing, a) arr b +mkArrow arr a b = joinSrc a b $ UPi (UnderscoreUPat, Just a) arr b withSrc :: Parser a -> Parser (WithSrc a) withSrc p = do @@ -992,24 +952,11 @@ inpostfix' :: Parser a -> Parser (a -> Maybe a -> a) -> Operator Parser a inpostfix' p op = Postfix $ do f <- op rest <- optional p - return $ \x -> f x rest + return \x -> f x rest mkName :: String -> Name mkName s = Name SourceName (fromString s) 0 -nameToStr :: Name -> String -nameToStr = tagToStr . nameTag - --- This function is used to generate a string that is guaranteed to never shadow --- any user-defined name, as "#" is an invalid identifier character in normal --- source code. -mkNoShadowingStr :: String -> String -mkNoShadowingStr = (++ "#") - -mkInterfaceConsName :: Name -> Name -mkInterfaceConsName = - GlobalName . fromString . mkNoShadowingStr . nameToStr - -- === lexemes === -- These `Lexer` actions must be non-overlapping and never consume input on failure @@ -1017,7 +964,8 @@ type Lexer = Parser data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ReadKW | WriteKW | StateKW | DataKW | InterfaceKW - | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW + | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW + | ExceptKW | IOKW | ViewKW upperName :: Lexer Name upperName = liftM mkName $ label "upper-case name" $ lexeme $ @@ -1027,6 +975,12 @@ lowerName :: Lexer Name lowerName = liftM mkName $ label "lower-case name" $ lexeme $ checkNotKeyword $ (:) <$> lowerChar <*> many nameTailChar +anyCaseName :: Lexer Name +anyCaseName = lowerName <|> upperName + +anyName :: Lexer Name +anyName = lowerName <|> upperName <|> symName + checkNotKeyword :: Parser String -> Parser String checkNotKeyword p = try $ do s <- p @@ -1050,15 +1004,19 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar ReadKW -> "Read" WriteKW -> "Accum" StateKW -> "State" + ExceptKW -> "Except" + IOKW -> "IO" DataKW -> "data" InterfaceKW -> "interface" InstanceKW -> "instance" WhereKW -> "where" + DoKW -> "do" + ViewKW -> "view" keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", - "Read", "Write", "Accum", "data", "interface", - "instance", "where", "if", "then", "else"] + "Read", "Write", "Accum", "Except", "IO", "data", "interface", + "instance", "where", "if", "then", "else", "do", "view"] fieldLabel :: Lexer Label fieldLabel = label "field label" $ lexeme $ @@ -1172,6 +1130,9 @@ mayPair p = local (\ctx -> ctx { canPair = True }) p mayNotPair :: Parser a -> Parser a mayNotPair p = local (\ctx -> ctx { canPair = False }) p +optionalMonoid :: Monoid a => Parser a -> Parser a +optionalMonoid p = p <|> return mempty + nameString :: Parser String nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar @@ -1214,7 +1175,7 @@ withPos p = do nextLine :: Parser () nextLine = do - void eol + eol n <- asks curIndent void $ mayNotBreak $ many $ try (sc >> eol) void $ replicateM n (char ' ') @@ -1231,8 +1192,11 @@ withIndent p = do indent <- liftM length $ some (char ' ') local (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p +eol :: Parser () +eol = void MC.eol + eolf :: Parser () -eolf = void eol <|> eof +eolf = eol <|> eof failIf :: Bool -> String -> Parser () failIf True s = fail s diff --git a/src/lib/PipeRPC.hs b/src/lib/PipeRPC.hs deleted file mode 100644 index 4369160f2..000000000 --- a/src/lib/PipeRPC.hs +++ /dev/null @@ -1,60 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module PipeRPC (PipeServer, startPipeServer, callPipeServer, psPop) where - -import Control.Concurrent.MVar -import Control.Monad -import Control.Monad.IO.Class -import Data.Aeson -import Data.ByteString.Lazy.Char8 (pack, unpack) -import GHC.IO.Handle.FD -import System.IO -import System.Process - -data PipeServer f = PipeServer { _psLock :: MVar () - , _psSendHandle :: Handle - , _psReceiveHandle :: Handle - , psFunctionIndex :: Int} - -startPipeServer :: MonadIO m => FilePath -> [String] -> m (PipeServer f) -startPipeServer cmd args = liftIO $ do - ((clientRead, _), (_, serverWrite)) <- createPipeWithNames - ((_, serverRead), (clientWrite, _)) <- createPipeWithNames - void $ createProcess $ proc cmd $ args ++ [serverRead, serverWrite] - lock <- newMVar () - return $ PipeServer lock clientWrite clientRead 0 - -psPop :: PipeServer (head, tail) -> PipeServer tail -psPop server = server { psFunctionIndex = 1 + psFunctionIndex server } - -callPipeServer :: (MonadIO m, ToJSON a, FromJSON b) - => PipeServer (a -> b, tail) -> a -> m b -callPipeServer (PipeServer lock sendHandle receiveHandle fIdx) arg = liftIO $ do - void $ takeMVar lock - let request = unpack $ encode (fIdx, arg) - hPutStrLn sendHandle request - response <- hGetLine receiveHandle - putMVar lock () - case eitherDecode (pack response) of - Right x -> case x of - Right x' -> return x' - Left s -> error $ "Error thrown by server:\n" ++ s - Left s -> error $ s ++ "\nDecoding error. Full response:\n" ++ response - -createPipeWithNames :: IO ((Handle, String), (Handle, String)) -createPipeWithNames = do - (r, w) <- createPipe - hSetBuffering r LineBuffering - hSetBuffering w LineBuffering - rName <- unixHandleName r - wName <- unixHandleName w - return ((r,rName), (w, wName)) - -unixHandleName :: Handle -> IO String -unixHandleName h = do - fd <- handleToFd h - return $ "/dev/fd/" ++ show fd diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index cfd03ad15..6a2308a66 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -19,10 +19,11 @@ import Control.Monad import Text.Megaparsec hiding (chunk) import Text.Megaparsec.Char as C +import Resources (cssSource) import Syntax import PPrint import Parser -import Serialize() +import Serialize () pprintHtml :: ToMarkup a => a -> String pprintHtml x = renderHtml $ toMarkup x @@ -34,7 +35,7 @@ progHtml blocks = renderHtml $ wrapBody $ map toHtmlBlock blocks wrapBody :: [Html] -> Html wrapBody blocks = docTypeHtml $ do H.head $ do - H.link ! rel "stylesheet" ! href "style.css" ! type_ "text/css" + H.style ! type_ "text/css" $ toHtml cssSource H.meta ! charset "UTF-8" H.body $ H.div inner ! At.id "main-output" where inner = foldMap (cdiv "cell") blocks diff --git a/src/lib/Serialize.hs b/src/lib/Serialize.hs index 7275c6968..602fdeed2 100644 --- a/src/lib/Serialize.hs +++ b/src/lib/Serialize.hs @@ -22,7 +22,6 @@ import Interpreter import Syntax import Type import PPrint -import Interpreter (indices) pprintVal :: Val -> IO String pprintVal val = asStr <$> prettyVal val @@ -32,7 +31,7 @@ getDexString :: Val -> IO String getDexString (DataCon _ _ 0 [_, xs]) = do let (TabTy b _) = getType xs idxs <- indices $ getType b - forM idxs $ \i -> do + forM idxs \i -> do ~(Con (Lit (Word8Lit c))) <- evalBlock mempty (Block Empty (App xs i)) return $ toEnum $ fromIntegral c getDexString x = error $ "Not a string: " ++ pprint x @@ -50,7 +49,7 @@ prettyVal val = case val of _ -> "@" <> pretty idxSet -- Otherwise, show explicit index set -- Pretty-print elements. idxs <- indices idxSet - elems <- forM idxs $ \idx -> do + elems <- forM idxs \idx -> do atom <- evalBlock mempty $ snd $ applyAbs abs idx case atom of Con (Lit (Word8Lit c)) -> diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 05a87b191..d1c03e4fd 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -9,7 +9,6 @@ module Simplify (simplifyModule, simplifyCase, splitSimpModule) where import Control.Monad -import Control.Monad.Identity import Control.Monad.Reader import Data.Maybe import Data.Foldable (toList) @@ -17,6 +16,7 @@ import Data.Functor import Data.List (partition, elemIndex) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M +import qualified Data.Set as S import Autodiff import Env @@ -27,7 +27,7 @@ import Type import PPrint import Util -type SimplifyM = SubstEmbedT Identity +type SimplifyM = SubstEmbed simplifyModule :: TopEnv -> Module -> Module simplifyModule scope (Module Core decls bindings) = do @@ -61,7 +61,7 @@ hoistDepDataCons scope (Module Simp decls bindings) = where (bindings', (_, decls')) = flip runEmbed scope $ do mapM_ emitDecl decls - forM bindings $ \(ty, info) -> case info of + forM bindings \(ty, info) -> case info of LetBound ann x | isData ty -> do x' <- emit x return (ty, LetBound ann $ Atom x') _ -> return (ty, info) @@ -89,7 +89,7 @@ simplifyDecl (Let ann b expr) = do simplifyStandalone :: Expr -> SimplifyM Atom simplifyStandalone (Atom (LamVal b body)) = do b' <- mapM substEmbedR b - buildLam b' PureArrow $ \x -> + buildLam b' PureArrow \x -> extendR (b@>x) $ simplifyBlock body simplifyStandalone block = error $ "@noinline decorator applied to non-function" ++ pprint block @@ -139,9 +139,9 @@ simplifyAtom atom = case atom of case simplifyCase e' alts of Just (env, result) -> extendR env $ simplifyAtom result Nothing -> do - alts' <- forM alts $ \(Abs bs a) -> do + alts' <- forM alts \(Abs bs a) -> do bs' <- mapM (mapM substEmbedR) bs - (Abs bs'' b) <- buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ simplifyAtom a + (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error $ "Nontrivial block in ACase simplification" @@ -192,7 +192,7 @@ simplifyLams numArgs lam = do Left res -> (res, Nothing) Right (dat, (ctx, recon), atomf) -> ( mkConsList $ (toList dat) ++ (toList ctx) - , Just $ \vals -> do + , Just \vals -> do (datEls', ctxEls') <- splitAt (length dat) <$> unpackConsList vals let dat' = restructure datEls' dat let ctx' = restructure ctxEls' ctx @@ -200,7 +200,7 @@ simplifyLams numArgs lam = do ) go n scope ~(Block Empty (Atom (Lam (Abs b (arr, body))))) = do b' <- mapM substEmbedR b - buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) $ \x@(Var v) -> do + buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) \x@(Var v) -> do let scope' = scope <> v @> (varType v, LamBound (void arr)) extendR (b@>x) $ go (n-1) scope' body @@ -278,7 +278,7 @@ separateDataComponent localVars v = do True -> nubCtx t False -> h : (nubCtx t) result = nubCtx $ toList ll - inv ctx' result' = for ll $ \x -> case elemIndex x (toList ctx) of + inv ctx' result' = for ll \x -> case elemIndex x (toList ctx) of Just i -> (toList ctx') !! i Nothing -> result' !! (fromJust $ elemIndex x result) @@ -299,7 +299,7 @@ simplifyExpr expr = case expr of case all isCurriedFun alts of True -> return $ ACase e (fmap appAlt alts) rty' False -> do - let alts' = for alts $ \(Abs bs a) -> Abs bs $ Block Empty (App a x') + let alts' = for alts \(Abs bs a) -> Abs bs $ Block Empty (App a x') dropSub $ simplifyExpr $ Case e alts' rty' where isCurriedFun alt = case alt of @@ -321,16 +321,16 @@ simplifyExpr expr = case expr of Nothing -> do if isData resultTy' then do - alts' <- forM alts $ \(Abs bs body) -> do + alts' <- forM alts \(Abs bs body) -> do bs' <- mapM (mapM substEmbedR) bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ simplifyBlock body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyBlock body emit $ Case e' alts' resultTy' else do -- Construct the blocks of new cases. The results will only get replaced -- later, once we learn the closures of the non-data component of each case. - (alts', facs) <- liftM unzip $ forM alts $ \(Abs bs body) -> do + (alts', facs) <- liftM unzip $ forM alts \(Abs bs body) -> do bs' <- mapM (mapM substEmbedR) bs - buildNAbsAux bs' $ \xs -> do + buildNAbsAux bs' \xs -> do ~(Right fac@(dat, (ctx, _), _)) <- extendR (newEnv bs' xs) $ defunBlock (boundVars bs') body -- NB: The return value here doesn't really matter as we're going to replace it afterwards. return (mkConsList $ toList dat ++ toList ctx, fac) @@ -361,9 +361,9 @@ simplifyExpr expr = case expr of -- a single output. This can probably be made quite a bit faster. -- NB: All the non-data trees have the same structure, so we pick an arbitrary one. nondatTree <- (\(_, (ctx, rec), _) -> rec dat ctx) $ head facs - nondat <- forM (enumerate nondatTree) $ \(i, _) -> do - aalts <- forM facs $ \(_, (ctx, rec), _) -> do - Abs bs' b <- buildNAbs (toNest $ toList $ fmap (Ignore . getType) ctx) $ \ctxVals -> + nondat <- forM (enumerate nondatTree) \(i, _) -> do + aalts <- forM facs \(_, (ctx, rec), _) -> do + Abs bs' b <- buildNAbs (toNest $ toList $ fmap (Ignore . getType) ctx) \ctxVals -> ((!! i) . toList) <$> rec dat (restructure ctxVals ctx) case b of Block Empty (Atom r) -> return $ Abs bs' r @@ -441,16 +441,15 @@ simplifyHof hof = case hof of ans <- emit $ Hof $ For d lam' case recon of Nothing -> return ans - Just f -> buildLam i TabArrow $ \i' -> app ans i' >>= f + Just f -> buildLam i TabArrow \i' -> app ans i' >>= f Tile d fT fS -> do ~(fT', Nothing) <- simplifyLam fT ~(fS', Nothing) <- simplifyLam fS emit $ Hof $ Tile d fT' fS' PTileReduce _ _ -> error "Unexpected PTileReduce" - While cond body -> do - ~(cond', Nothing) <- simplifyLam cond + While body -> do ~(body', Nothing) <- simplifyLam body - emit $ Hof $ While cond' body' + emit $ Hof $ While body' Linearize lam -> do ~(lam', Nothing) <- simplifyLam lam scope <- getScope @@ -475,6 +474,83 @@ simplifyHof hof = case hof of (ans, sOut) <- fromPair =<< (emit $ Hof $ RunState s' lam') ans' <- applyRecon recon ans return $ PairVal ans' sOut + RunIO lam -> do + ~(lam', recon) <- simplifyLam lam + ans <- emit $ Hof $ RunIO lam' + applyRecon recon ans + CatchException lam -> do + ~(Lam (Abs _ (_, body)), Nothing) <- simplifyLam lam + dropSub $ exceptToMaybeBlock body where applyRecon Nothing x = return x applyRecon (Just f) x = f x + +exceptToMaybeBlock :: Block -> SubstEmbed Atom +exceptToMaybeBlock (Block Empty result) = exceptToMaybeExpr result +exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do + a <- substEmbedR $ getType result + maybeResult <- exceptToMaybeExpr expr + case maybeResult of + -- These two cases are just an optimization + JustAtom _ x -> extendR (b@>x) $ exceptToMaybeBlock $ Block decls result + NothingAtom _ -> return $ NothingAtom a + _ -> do + emitMaybeCase maybeResult (return $ NothingAtom a) \x -> do + extendR (b@>x) $ exceptToMaybeBlock $ Block decls result + +exceptToMaybeExpr :: Expr -> SubstEmbed Atom +exceptToMaybeExpr expr = do + a <- substEmbedR $ getType expr + case expr of + Case e alts resultTy -> do + e' <- substEmbedR e + resultTy' <- substEmbedR $ MaybeTy resultTy + alts' <- forM alts \(Abs bs body) -> do + bs' <- substEmbedR bs + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body + emit $ Case e' alts' resultTy' + Atom x -> substEmbedR $ JustAtom (getType x) x + Op (ThrowException _) -> return $ NothingAtom a + Hof (For ann ~(Lam (Abs b (_, body)))) -> do + b' <- substEmbedR b + maybes <- buildForAnn ann b' \i -> extendR (b@>i) $ exceptToMaybeBlock body + catMaybesE maybes + Hof (RunState s lam) -> do + s' <- substEmbedR s + let BinaryFunVal _ b _ body = lam + result <- emitRunState "ref" s' \ref -> + extendR (b@>ref) $ exceptToMaybeBlock body + (maybeAns, newState) <- fromPair result + emitMaybeCase maybeAns (return $ NothingAtom a) \ans -> + return $ JustAtom a $ PairVal ans newState + Hof (While ~(Lam (Abs _ (_, body)))) -> do + eff <- getAllowedEffects + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> + exceptToMaybeBlock body + runMaybeWhile lam + _ | not (hasExceptions expr) -> do + x <- substEmbedR expr >>= emit + return $ JustAtom (getType x) x + | otherwise -> + error $ "Unexpected exception-throwing expression: " ++ pprint expr + +hasExceptions :: Expr -> Bool +hasExceptions expr = case t of + Nothing -> ExceptionEffect `S.member` effs + Just _ -> error "Shouldn't have tail left" + where (EffectRow effs t) = exprEffs expr + +catMaybesE :: MonadEmbed m => Atom -> m Atom +catMaybesE maybes = simplifyEmbed $ do + let (TabTy b (MaybeTy a)) = getType maybes + applyPreludeFunction "seqMaybes" [binderAnn b, a, maybes] + +runMaybeWhile :: MonadEmbed m => Atom -> m Atom +runMaybeWhile lam = simplifyEmbed $ do + let (Pi (Abs _ (PlainArrow eff, _))) = getType lam + applyPreludeFunction "whileMaybe" [Eff eff, lam] + +simplifyEmbed :: MonadEmbed m => m Atom -> m Atom +simplifyEmbed m = do + block <- buildScoped m + liftEmbed $ runReaderT (simplifyBlock block) mempty diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 0120cc14c..4aa7d8d3c 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -10,17 +10,15 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE Rank2Types #-} module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), - Effect, EffectName (..), EffectRow (..), + Effect (..), RWS (..), EffectRow (..), ClassName (..), TyQual (..), SrcPos, Var, Binder, Block (..), Decl (..), Expr (..), Atom (..), ArrowP (..), Arrow, PrimTC (..), Abs (..), - PrimExpr (..), PrimCon (..), LitVal (..), - PrimEffect (..), PrimOp (..), EffectSummary, pattern NoEffects, + PrimExpr (..), PrimCon (..), LitVal (..), PrimEffect (..), PrimOp (..), PrimHof (..), LamExpr, PiType, WithSrc (..), srcPos, LetAnn (..), BinOp (..), UnOp (..), CmpOp (..), SourceBlock (..), ReachedEOF, SourceBlock' (..), SubstEnv, ScopedSubstEnv, @@ -31,29 +29,32 @@ module Syntax ( IPrimOp, IVar, IBinder, IType, SetVal (..), MonMap (..), LitProg, IFunType (..), IFunVar, CallingConvention (..), IsCUDARequired (..), UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, - reflectLabels, withLabels, ExtLabeledItems (..), prefixExtLabeledItems, + lookupLabelHead, reflectLabels, withLabels, ExtLabeledItems (..), + prefixExtLabeledItems, getLabels, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, - SrcCtx, Result (..), Output (..), OutFormat (..), DataFormat (..), + SrcCtx, Result (..), Output (..), OutFormat (..), Err (..), ErrType (..), Except, throw, throwIf, modifyErr, addContext, addSrcContext, catchIOExcept, liftEitherIO, (-->), (--@), (==>), boundUVars, PassName (..), boundVars, renamingSubst, bindingsAsVars, freeVars, freeUVars, Subst, HasVars, BindsVars, Ptr, PtrType, - AddressSpace (..), PtrOrigin (..), showPrimName, strToPrimName, primNameToStr, + AddressSpace (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), - UExpr, UExpr' (..), UType, UPatAnn, UPiPatAnn, UAnnBinder, UVar, + UExpr, UExpr' (..), UType, UPatAnn, UAnnBinder, UVar, + UMethodDef (..), UPatAnnArrow, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, - getProjection, + getProjection, outputStreamPtrName, initTopEnv, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, + pattern MaybeTy, pattern JustAtom, pattern NothingAtom, pattern IdxRepTy, pattern IdxRepVal, pattern IIdxRepVal, pattern IIdxRepTy, pattern TagRepTy, pattern TagRepVal, pattern Word8Ty, pattern IntLitExpr, pattern FloatLitExpr, - pattern UnitTy, pattern PairTy, pattern FunTy, + pattern UnitTy, pattern PairTy, pattern FunTy, pattern PiTy, pattern FixedIntRange, pattern Fin, pattern RefTy, pattern RawRefTy, pattern BaseTy, pattern PtrTy, pattern UnitVal, pattern PairVal, pattern PureArrow, @@ -61,24 +62,25 @@ module Syntax ( pattern TabTy, pattern TabTyAbs, pattern TabVal, pattern TabValA, pattern Pure, pattern BinaryFunTy, pattern BinaryFunVal, pattern Unlabeled, pattern NoExt, pattern LabeledRowKind, - pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind) + pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind, + pattern NestOne, pattern NewTypeCon, pattern BinderAnn, + pattern ClassDictDef, pattern ClassDictCon, pattern UnderscoreUPat) where import qualified Data.Map.Strict as M import Control.Exception hiding (throw) -import Control.Monad.Fail import Control.Monad.Identity import Control.Monad.Writer hiding (Alt) import Control.Monad.Except hiding (Except) import qualified Data.ByteString.Char8 as B -import Data.List (sort) import qualified Data.List.NonEmpty as NE import qualified Data.Set as S import Data.Store (Store) import Data.Tuple (swap) -import Data.Foldable (toList) +import Data.Foldable (toList, fold) import Data.Int import Data.Word +import Data.String (IsString, fromString) import Foreign.Ptr import GHC.Generics @@ -186,11 +188,19 @@ labeledSingleton label value = LabeledItems $ M.singleton label (value NE.:|[]) reflectLabels :: LabeledItems a -> LabeledItems (Label, Int) reflectLabels (LabeledItems items) = LabeledItems $ - flip M.mapWithKey items $ \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) + flip M.mapWithKey items \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) + +getLabels :: LabeledItems a -> [Label] +getLabels labeledItems = map fst $ toList $ reflectLabels labeledItems withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ - flip M.mapWithKey items $ \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) + flip M.mapWithKey items \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) + +lookupLabelHead :: LabeledItems a -> Label -> Maybe a +lookupLabelHead (LabeledItems items) l = case M.lookup l items of + Nothing -> Nothing + Just (x NE.:| _) -> Just x instance Semigroup (LabeledItems a) where LabeledItems items <> LabeledItems items' = @@ -216,7 +226,7 @@ prefixExtLabeledItems items (Ext items' rest) = Ext (items <> items') rest type UExpr = WithSrc UExpr' data UExpr' = UVar UVar | ULam UPatAnn UArrow UExpr - | UPi UPiPatAnn Arrow UType + | UPi UPatAnn Arrow UType | UApp UArrow UExpr UExpr | UDecl UDecl UExpr | UFor Direction UPatAnn UExpr @@ -236,17 +246,21 @@ data UExpr' = UVar UVar deriving (Show, Generic) data UConDef = UConDef Name (Nest UAnnBinder) deriving (Show, Generic) -data UDecl = ULet LetAnn UPatAnn UExpr - | UData UConDef [UConDef] - deriving (Show, Generic) +data UDecl = + ULet LetAnn UPatAnn UExpr + | UData UConDef [UConDef] + | UInterface [UType] UConDef [UAnnBinder] -- superclasses, constructor, methods + | UInstance (Nest UPatAnnArrow) UType [UMethodDef] -- args, type, methods + deriving (Show, Generic) type UType = UExpr type UArrow = ArrowP () type UVar = VarP () type UBinder = BinderP () +data UMethodDef = UMethodDef UVar UExpr deriving (Show, Generic) -type UPatAnn = (UPat, Maybe UType) -type UPiPatAnn = (Maybe UPat, UType) +type UPatAnn = (UPat, Maybe UType) +type UPatAnnArrow = (UPatAnn, UArrow) type UAnnBinder = BinderP UType data UAlt = UAlt UPat UExpr deriving (Show, Generic) @@ -271,6 +285,12 @@ data WithSrc a = WithSrc SrcCtx a srcPos :: WithSrc a -> SrcCtx srcPos (WithSrc pos _) = pos +instance IsString UExpr' where + fromString s = UVar $ Name SourceName (fromString s) 0 :> () + +pattern UnderscoreUPat :: UPat +pattern UnderscoreUPat = WithSrc Nothing (UPatBinder (Ignore ())) + -- === primitive constructors and operators === data PrimExpr e = @@ -322,12 +342,14 @@ data PrimOp e = | SndRef e | FFICall String e [e] | Inject e - | PtrOffset e e - | PtrLoad e - | GetPtr e - | MakePtrType e | SliceOffset e e -- Index slice first, inner index second | SliceCurry e e -- Index slice first, curried index second + -- Low-level memory operations + | IOAlloc BaseType e + | IOFree e + | PtrOffset e e + | PtrLoad e + | PtrStore e e -- SIMD operations | VectorBinOp BinOp e e | VectorPack [e] -- List should have exactly vectorWidth elements @@ -337,6 +359,7 @@ data PrimOp e = | ToOrdinal e | IdxSetSize e | ThrowError e + | ThrowException e -- Catchable exceptions (unlike `ThrowError`) | CastOp e e -- Type, then value. See Type.hs for valid coercions. -- Extensible record and variant operations: -- Add fields to a record (on the left). Left arg contains values to add. @@ -360,10 +383,12 @@ data PrimOp e = data PrimHof e = For ForAnn e | Tile Int e e -- dimension number, tiled body, scalar body - | While e e + | While e | RunReader e e | RunWriter e | RunState e e + | RunIO e + | CatchException e | Linearize e | Transpose e | PTileReduce e e -- index set, thread body @@ -418,33 +443,39 @@ showPrimName prim = primNameToStr $ fmap (const ()) prim -- === effects === -type Effect = (EffectName, Name) -data EffectRow = EffectRow [Effect] (Maybe Name) - deriving (Show, Generic) -data EffectName = Reader | Writer | State deriving (Show, Eq, Ord, Generic) +data EffectRow = EffectRow (S.Set Effect) (Maybe Name) + deriving (Show, Eq, Generic) -type EffectSummary = S.Set Effect +data RWS = Reader | Writer | State deriving (Show, Eq, Ord, Generic) +data Effect = RWSEffect RWS Name | ExceptionEffect | IOEffect + deriving (Show, Eq, Ord, Generic) -instance HasVars EffectSummary where - freeVars effs = foldMap (\(_, reg) -> reg @> (TyKind, UnknownBinder)) effs +pattern Pure :: EffectRow +pattern Pure <- ((\(EffectRow effs t) -> (S.null effs, t)) -> (True, Nothing)) + where Pure = mempty -instance Subst EffectSummary where - subst (env, _) effs = S.map substEff effs - where - substEff (eff, name) = case envLookup env name of - Just ~(Var (name':>_)) -> (eff, name') - Nothing -> (eff, name) +outputStreamPtrName :: Name +outputStreamPtrName = GlobalName "OUT_STREAM_PTR" -pattern Pure :: EffectRow -pattern Pure = EffectRow [] Nothing +initTopEnv :: TopEnv +initTopEnv = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- + [(outputStreamPtrName , BaseTy $ hostPtrTy $ hostPtrTy $ Scalar Word8Type)]] + +hostPtrTy :: BaseType -> BaseType +hostPtrTy ty = PtrType (Heap CPU, ty) -pattern NoEffects :: EffectSummary -pattern NoEffects <- ((S.null) -> True) - where NoEffects = mempty +instance Semigroup EffectRow where + EffectRow effs t <> EffectRow effs' t' = + EffectRow (S.union effs effs') newTail + where + newTail = case (t, t') of + (Nothing, effTail) -> effTail + (effTail, Nothing) -> effTail + _ | t == t' -> t + | otherwise -> error "Can't combine effect rows with mismatched tails" -instance Eq EffectRow where - EffectRow effs t == EffectRow effs' t' = - sort effs == sort effs' && t == t' +instance Monoid EffectRow where + mempty = EffectRow mempty Nothing -- === top-level constructs === @@ -462,14 +493,13 @@ data SourceBlock' = RunModule UModule | Command CmdName (Name, UModule) | GetNameType Name | IncludeSourceFile String - | LoadData UPatAnn DataFormat String | ProseBlock String | CommentLine | EmptyLines | UnParseable ReachedEOF String deriving (Show, Generic) -data CmdName = GetType | EvalExpr OutFormat | ExportFun String | Dump DataFormat String +data CmdName = GetType | EvalExpr OutFormat | ExportFun String deriving (Show, Generic) data LogLevel = LogNothing | PrintEvalTime | PrintBench String @@ -514,7 +544,7 @@ data ImpFunction = ImpFunction IFunVar [IBinder] ImpBlock data ImpBlock = ImpBlock (Nest ImpDecl) [IExpr] deriving (Show) data ImpDecl = ImpLet [IBinder] ImpInstr deriving (Show) data ImpInstr = IFor Direction IBinder Size ImpBlock - | IWhile ImpBlock ImpBlock -- cond block, body block + | IWhile ImpBlock | ICond IExpr ImpBlock ImpBlock | IQueryParallelism IFunVar IExpr -- returns the number of available concurrent threads | ISyncWorkgroup @@ -529,7 +559,7 @@ data ImpInstr = IFor Direction IBinder Size ImpBlock | IPrimOp IPrimOp deriving (Show) -data Backend = LLVM | LLVMCUDA | LLVMMC | Interp deriving (Show, Eq) +data Backend = LLVM | LLVMCUDA | LLVMMC | Interpreter deriving (Show, Eq) newtype CUDAKernel = CUDAKernel B.ByteString deriving (Show) -- === base types === @@ -553,8 +583,7 @@ data BaseType = Scalar ScalarBaseType data Device = CPU | GPU deriving (Show, Eq, Ord, Generic) data AddressSpace = Stack | Heap Device deriving (Show, Eq, Ord, Generic) -data PtrOrigin = DerivedPtr | AllocatedPtr deriving (Show, Eq, Ord, Generic) -type PtrType = (PtrOrigin, AddressSpace, BaseType) +type PtrType = (AddressSpace, BaseType) sizeOf :: BaseType -> Int sizeOf t = case t of @@ -597,7 +626,7 @@ monMapLookup (MonMap m) k = case M.lookup k m of Nothing -> mempty -- === passes === data PassName = Parse | TypePass | SynthPass | SimpPass | ImpPass | JitPass - | Flops | LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval + | LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval | ResultPass | JaxprAndHLO | OptimPass deriving (Ord, Eq, Bounded, Enum) @@ -605,7 +634,7 @@ instance Show PassName where show p = case p of Parse -> "parse" ; TypePass -> "typed" ; SynthPass -> "synth" SimpPass -> "simp" ; ImpPass -> "imp" ; JitPass -> "llvm" - Flops -> "flops" ; LLVMOpt -> "llvmopt" ; AsmPass -> "asm" + LLVMOpt -> "llvmopt" ; AsmPass -> "asm" JAXPass -> "jax" ; JAXSimpPass -> "jsimp"; ResultPass -> "result" LLVMEval -> "llvmeval" ; JaxprAndHLO -> "jaxprhlo"; OptimPass -> "optimized" @@ -615,7 +644,7 @@ type LitProg = [(SourceBlock, Result)] type SrcCtx = Maybe SrcPos data Result = Result [Output] (Except ()) deriving (Show, Eq) -type BenchStats = Int -- number of runs +type BenchStats = (Int, Double) -- number of runs, total benchmarking time data Output = TextOut String | HtmlOut String | PassInfo PassName String @@ -627,7 +656,6 @@ data Output = TextOut String deriving (Show, Eq, Generic) data OutFormat = Printed | RenderHtml deriving (Show, Eq, Generic) -data DataFormat = DexObject | DexBinaryObject deriving (Show, Eq, Generic) data Err = Err ErrType SrcCtx String deriving (Show, Eq) instance Exception Err @@ -657,10 +685,10 @@ throwIf True e s = throw e s throwIf False _ _ = return () modifyErr :: MonadError e m => m a -> (e -> e) -> m a -modifyErr m f = catchError m $ \e -> throwError (f e) +modifyErr m f = catchError m \e -> throwError (f e) addContext :: MonadError Err m => String -> m a -> m a -addContext s m = modifyErr m $ \(Err e p s') -> Err e p (s' ++ "\n" ++ s) +addContext s m = modifyErr m \(Err e p s') -> Err e p (s' ++ "\n" ++ s) addSrcContext :: MonadError Err m => SrcCtx -> m a -> m a addSrcContext ctx m = modifyErr m updateErr @@ -671,9 +699,9 @@ addSrcContext ctx m = modifyErr m updateErr catchIOExcept :: (MonadIO m , MonadError Err m) => IO a -> m a catchIOExcept m = (liftIO >=> liftEither) $ (liftM Right m) `catches` - [ Handler $ \(e::Err) -> return $ Left e - , Handler $ \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e - , Handler $ \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e + [ Handler \(e::Err) -> return $ Left e + , Handler \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e + , Handler \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e ] liftEitherIO :: (Exception e, MonadIO m) => Either e a -> m a @@ -771,11 +799,23 @@ instance BindsUVars UPat' where instance HasUVars UDecl where freeUVars (ULet _ p expr) = freeUVars p <> freeUVars expr freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons + freeUVars (UInterface superclasses tc methods) = + freeUVars $ Abs tc (superclasses, methods) + freeUVars (UInstance bsArrows ty methods) = freeUVars $ Abs bs (ty, methods) + where bs = fmap fst bsArrows + +instance HasUVars UMethodDef where + freeUVars (UMethodDef _ def) = freeUVars def + +instance BindsUVars UPatAnn where + boundUVars (p, _) = boundUVars p instance BindsUVars UDecl where boundUVars decl = case decl of - ULet _ (p,_) _ -> boundUVars p - UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + ULet _ (p,_) _ -> boundUVars p + UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + UInterface _ _ _ -> mempty + UInstance _ _ _ -> mempty instance HasUVars UModule where freeUVars (UModule decls) = freeUVars decls @@ -799,7 +839,12 @@ instance BindsUVars SourceBlock where instance HasUVars EffectRow where freeUVars (EffectRow effs tailVar) = - foldMap (nameAsEnv . snd) effs <> foldMap nameAsEnv tailVar + foldMap freeUVars effs <> foldMap nameAsEnv tailVar + +instance HasUVars Effect where + freeUVars (RWSEffect _ h) = nameAsEnv h + freeUVars ExceptionEffect = mempty + freeUVars IOEffect = mempty instance HasUVars a => HasUVars (LabeledItems a) where freeUVars (LabeledItems items) = foldMap freeUVars items @@ -988,8 +1033,9 @@ applyNaryAbs (Abs (Nest b bs) body) (x:xs) = applyNaryAbs ab xs applyNaryAbs _ _ = error "wrong number of arguments" applyDataDefParams :: DataDef -> [Type] -> [DataConDef] -applyDataDefParams (DataDef _ paramBs cons) params = - applyNaryAbs (Abs paramBs cons) params +applyDataDefParams (DataDef _ bs cons) params + | length params == length (toList bs) = applyNaryAbs (Abs bs cons) params + | otherwise = error $ "Wrong number of parameters: " ++ show (length params) makeAbs :: HasVars a => Binder -> a -> Abs Binder a makeAbs b body | b `isin` freeVars body = Abs b body @@ -1115,13 +1161,24 @@ instance Subst Module where where Abs decls' bindings' = subst env $ Abs decls bindings instance HasVars EffectRow where - freeVars (EffectRow row t) = - foldMap (\(_,v) -> v@>(TyKind , UnknownBinder)) row - <> foldMap (\v -> v@>(EffKind, UnknownBinder)) t + freeVars (EffectRow row t) = foldMap freeVars row + <> foldMap (\v -> v@>(EffKind, UnknownBinder)) t instance Subst EffectRow where - subst (env, _) (EffectRow row t) = extendEffRow - (fmap (\(effName, v) -> (effName, substName env v)) row) - (substEffTail env t) + subst env (EffectRow row t) = extendEffRow row' t' + where + row' = S.map (subst env) row + t' = substEffTail (fst env) t + +instance HasVars Effect where + freeVars eff = case eff of + RWSEffect _ v -> v@>(TyKind , UnknownBinder) + ExceptionEffect -> mempty + IOEffect -> mempty +instance Subst Effect where + subst (env,_) eff = case eff of + RWSEffect rws v -> RWSEffect rws (substName env v) + ExceptionEffect -> ExceptionEffect + IOEffect -> IOEffect instance HasVars BinderInfo where freeVars binfo = case binfo of @@ -1155,10 +1212,10 @@ instance Subst (ExtLabeledItems Type Name) where prefixExtLabeledItems (subst env items) (substExtLabeledItemsTail env' rest) substEffTail :: SubstEnv -> Maybe Name -> EffectRow -substEffTail _ Nothing = EffectRow [] Nothing +substEffTail _ Nothing = EffectRow mempty Nothing substEffTail env (Just v) = case envLookup env (v:>()) of - Nothing -> EffectRow [] (Just v) - Just (Var (v':>_)) -> EffectRow [] (Just v') + Nothing -> EffectRow mempty (Just v) + Just (Var (v':>_)) -> EffectRow mempty (Just v') Just (Eff r) -> r _ -> error "Not a valid effect substitution" @@ -1168,7 +1225,7 @@ substName env v = case envLookup env (v:>()) of Just (Var (v':>_)) -> v' _ -> error "Should only substitute with a name" -extendEffRow :: [Effect] -> EffectRow -> EffectRow +extendEffRow :: S.Set Effect -> EffectRow -> EffectRow extendEffRow effs (EffectRow effs' t) = EffectRow (effs <> effs') t substExtLabeledItemsTail :: SubstEnv -> Maybe Name -> ExtLabeledItems Type Name @@ -1238,7 +1295,7 @@ instance HasIVars ImpBlock where instance HasIVars ImpInstr where freeIVars i = case i of IFor _ b n p -> freeIVars n <> (freeIVars p `envDiff` (b @> ())) - IWhile c p -> freeIVars c <> freeIVars p + IWhile p -> freeIVars p ICond c t f -> freeIVars c <> freeIVars t <> freeIVars f IQueryParallelism _ s -> freeIVars s ISyncWorkgroup -> mempty @@ -1429,6 +1486,9 @@ fromConsList xs = case xs of pattern FunTy :: Binder -> EffectRow -> Type -> Type pattern FunTy b eff bodyTy = Pi (Abs b (PlainArrow eff, bodyTy)) +pattern PiTy :: Binder -> Arrow -> Type -> Type +pattern PiTy b arr bodyTy = Pi (Abs b (arr, bodyTy)) + pattern BinaryFunTy :: Binder -> Binder -> EffectRow -> Type -> Type pattern BinaryFunTy b1 b2 eff bodyTy = FunTy b1 Pure (FunTy b2 eff bodyTy) @@ -1463,7 +1523,48 @@ pattern Unlabeled as <- (_getUnlabeled -> Just as) Just ne -> LabeledItems (M.singleton InternalSingletonLabel ne) Nothing -> NoLabeledItems - -- TODO: Enable once https://gitlab.haskell.org//ghc/ghc/issues/13363 is fixed... +maybeDataDef :: DataDef +maybeDataDef = DataDef (GlobalName "Maybe") (Nest (Bind ("a":>TyKind)) Empty) + [ DataConDef (GlobalName "Nothing") Empty + , DataConDef (GlobalName "Just" ) (Nest (Ignore (Var ("a":>TyKind))) Empty)] + +pattern MaybeTy :: Type -> Type +pattern MaybeTy a = TypeCon MaybeDataDef [a] + +pattern MaybeDataDef :: DataDef +pattern MaybeDataDef <- ((\def -> def == maybeDataDef) -> True) + where MaybeDataDef = maybeDataDef + +pattern NothingAtom :: Type -> Atom +pattern NothingAtom ty = DataCon MaybeDataDef [ty] 0 [] + +pattern JustAtom :: Type -> Atom -> Atom +pattern JustAtom ty x = DataCon MaybeDataDef [ty] 1 [x] + +pattern NestOne :: a -> Nest a +pattern NestOne x = Nest x Empty + +pattern BinderAnn :: a -> BinderP a +pattern BinderAnn x <- ((\case Ignore ann -> ann + Bind (_:>ann) -> ann) -> x) + where BinderAnn x = Ignore x + +pattern NewTypeCon :: Name -> Type -> [DataConDef] +pattern NewTypeCon con ty = [DataConDef con (NestOne (BinderAnn ty))] + +pattern ClassDictDef :: Name + -> LabeledItems Type -> LabeledItems Type -> [DataConDef] +pattern ClassDictDef conName superclasses methods = + [DataConDef conName + (Nest (BinderAnn (RecordTy (NoExt superclasses))) + (Nest (BinderAnn (RecordTy (NoExt methods))) Empty))] + +pattern ClassDictCon :: DataDef -> [Type] + -> LabeledItems Atom -> LabeledItems Atom -> Atom +pattern ClassDictCon def params superclasses methods = + DataCon def params 0 [Record superclasses, Record methods] + +-- TODO: Enable once https://gitlab.haskell.org//ghc/ghc/issues/13363 is fixed... -- {-# COMPLETE TypeVar, ArrowType, TabTy, Forall, TypeAlias, Effect, NoAnn, TC #-} -- TODO: Can we derive these generically? Or use Show/Read? @@ -1492,7 +1593,8 @@ builtinNames = M.fromList , ("idxSetSize" , OpExpr $ IdxSetSize ()) , ("unsafeFromOrdinal", OpExpr $ UnsafeFromOrdinal () ()) , ("toOrdinal" , OpExpr $ ToOrdinal ()) - , ("throwError" , OpExpr $ ThrowError ()) + , ("throwError" , OpExpr $ ThrowError ()) + , ("throwException" , OpExpr $ ThrowException ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) , ("tell" , OpExpr $ PrimEffect () $ MTell ()) , ("get" , OpExpr $ PrimEffect () $ MGet) @@ -1500,12 +1602,14 @@ builtinNames = M.fromList , ("indexRef" , OpExpr $ IndexRef () ()) , ("inject" , OpExpr $ Inject ()) , ("select" , OpExpr $ Select () () ()) - , ("while" , HofExpr $ While () ()) + , ("while" , HofExpr $ While ()) , ("linearize" , HofExpr $ Linearize ()) , ("linearTranspose" , HofExpr $ Transpose ()) , ("runReader" , HofExpr $ RunReader () ()) , ("runWriter" , HofExpr $ RunWriter ()) , ("runState" , HofExpr $ RunState () ()) + , ("runIO" , HofExpr $ RunIO ()) + , ("catchException" , HofExpr $ CatchException ()) , ("tiled" , HofExpr $ Tile 0 () ()) , ("tiledd" , HofExpr $ Tile 1 () ()) , ("TyKind" , TCExpr $ TypeKind) @@ -1514,6 +1618,9 @@ builtinNames = M.fromList , ("Int64" , TCExpr $ BaseType $ Scalar Int64Type) , ("Int32" , TCExpr $ BaseType $ Scalar Int32Type) , ("Word8" , TCExpr $ BaseType $ Scalar Word8Type) + , ("Int32Ptr", TCExpr $ BaseType $ ptrTy $ Scalar Int32Type) + , ("Word8Ptr", TCExpr $ BaseType $ ptrTy $ Scalar Word8Type) + , ("PtrPtr" , TCExpr $ BaseType $ ptrTy $ ptrTy $ Scalar Word8Type) , ("IntRange", TCExpr $ IntRange () ()) , ("Ref" , TCExpr $ RefType (Just ()) ()) , ("PairType", TCExpr $ PairType () ()) @@ -1531,11 +1638,11 @@ builtinNames = M.fromList , ("cast", OpExpr $ CastOp () ()) , ("sliceOffset", OpExpr $ SliceOffset () ()) , ("sliceCurry", OpExpr $ SliceCurry () ()) + , ("alloc", OpExpr $ IOAlloc (Scalar Word8Type) ()) + , ("free" , OpExpr $ IOFree ()) , ("ptrOffset", OpExpr $ PtrOffset () ()) , ("ptrLoad" , OpExpr $ PtrLoad ()) - , ("getPtr" , OpExpr $ GetPtr () ) - , ("makePtrType", OpExpr $ MakePtrType ()) - , ("CharPtr" , ptrTy Word8Type) + , ("ptrStore" , OpExpr $ PtrStore () ()) , ("dataConTag", OpExpr $ DataConTag ()) , ("toEnum" , OpExpr $ ToEnum () ()) ] @@ -1543,8 +1650,7 @@ builtinNames = M.fromList vbinOp op = OpExpr $ VectorBinOp op () () binOp op = OpExpr $ ScalarBinOp op () () unOp op = OpExpr $ ScalarUnOp op () - ptrTy ty = TCExpr $ BaseType $ PtrType $ - (AllocatedPtr, Heap CPU, Scalar ty) + ptrTy ty = PtrType (Heap CPU, ty) instance Store a => Store (PrimOp a) instance Store a => Store (PrimCon a) @@ -1562,7 +1668,8 @@ instance Store Atom instance Store Expr instance Store Block instance Store Decl -instance Store EffectName +instance Store RWS +instance Store Effect instance Store EffectRow instance Store Direction instance Store UnOp @@ -1576,6 +1683,5 @@ instance Store LitVal instance Store ScalarBaseType instance Store BaseType instance Store AddressSpace -instance Store PtrOrigin instance Store Device instance Store DataConRefBinding diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 07e9a8e1e..63b203099 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -1,25 +1,22 @@ --- Copyright 2019 Google LLC +-- Copyright 2020 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RecordWildCards #-} -module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, - exportFunctions, EvalConfig (..)) where +module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, EvalConfig (..)) where import Control.Monad.State.Strict import Control.Monad.Reader -import Control.Monad.Writer hiding (pass) import Control.Monad.Except hiding (Except) import Data.Text.Prettyprint.Doc import Data.String -import Data.List (partition, nub) -import Data.Time.Clock (getCurrentTime, diffUTCTime) +import Data.List (partition) import qualified Data.Map.Strict as M -import Algebra import Syntax import Embed import Cat @@ -36,7 +33,7 @@ import Logging import LLVMExec import PPrint import Parser -import Util (highlightRegion) +import Util (highlightRegion, measureSeconds) import Optimize import Parallelize @@ -80,7 +77,7 @@ evalSourceBlock opts env block = do Right env' -> return (env' , Result outs' (Right ())) runTopPassM :: Bool -> EvalConfig -> TopPassM a -> IO (Except a, [Output]) -runTopPassM bench opts m = runLogger (logFile opts) $ \logger -> +runTopPassM bench opts m = runLogger (logFile opts) \logger -> runExceptT $ catchIOExcept $ runReaderT m $ TopPassEnv logger bench opts evalSourceBlockM :: TopEnv -> SourceBlock -> TopPassM TopEnv @@ -99,7 +96,7 @@ evalSourceBlockM env block = case sbContents block of logTop $ HtmlOut s ExportFun name -> do f <- evalUModuleVal env v m - void $ traverseLiterals f $ \val -> case val of + void $ traverseLiterals f \val -> case val of PtrLit _ _ -> liftEitherIO $ throw CompilerErr $ "Can't export functions with captured pointers (not implemented)." _ -> return $ Con $ Lit val @@ -107,12 +104,12 @@ evalSourceBlockM env block = case sbContents block of GetType -> do -- TODO: don't actually evaluate it val <- evalUModuleVal env v m logTop $ TextOut $ pprint $ getType val - Dump _ _ -> error "Not implemented" GetNameType v -> case envLookup env (v:>()) of Just (ty, _) -> logTop (TextOut $ pprint ty) >> return mempty _ -> liftEitherIO $ throw UnboundVarErr $ pprint v IncludeSourceFile fname -> do - source <- liftIO $ readFile fname + fullPath <- liftIO $ findSourceFile fname + source <- liftIO $ readFile fullPath evalSourceBlocks env $ parseProg source UnParseable _ s -> liftEitherIO $ throw ParseErr s _ -> return mempty @@ -121,7 +118,7 @@ processLogs :: LogLevel -> [Output] -> [Output] processLogs logLevel logs = case logLevel of LogAll -> logs LogNothing -> [] - LogPasses passes -> flip filter logs $ \l -> case l of + LogPasses passes -> flip filter logs \case PassInfo pass _ | pass `elem` passes -> True | otherwise -> False _ -> False @@ -131,16 +128,18 @@ processLogs logLevel logs = case logLevel of where (compileTime, runTime, benchStats) = timesFromLogs logs timesFromLogs :: [Output] -> (Double, Double, Maybe BenchStats) -timesFromLogs logs = (totalTime - evalTime, evalTime, benchStats) +timesFromLogs logs = (totalTime - totalEvalTime, singleEvalTime, benchStats) where - (evalTime, benchStats) = case [(t, stats) | EvalTime t stats <- logs] of - [] -> (0.0, Nothing) - [(t, stats)] -> (t, stats) - _ -> error "Expect at most one result" + (totalEvalTime, singleEvalTime, benchStats) = + case [(t, stats) | EvalTime t stats <- logs] of + [] -> (0.0 , 0.0, Nothing) + [(t, stats)] -> (total, t , stats) + where total = maybe t snd stats + _ -> error "Expect at most one result" totalTime = case [tTotal | TotalTime tTotal <- logs] of - [] -> 0.0 - [t] -> t - _ -> error "Expect at most one result" + [] -> 0.0 + [t] -> t + _ -> error "Expect at most one result" isLogInfo :: Output -> Bool isLogInfo out = case out of @@ -200,8 +199,8 @@ evalBackend env block = do let (ptrBinders, ptrVals, block') = abstractPtrLiterals block let funcName = "entryFun" let mainName = Name TopFunctionName (fromString funcName) 0 - let cc = case backend of LLVMCUDA -> EntryFun CUDARequired - _ -> EntryFun CUDANotRequired + let (cc, needsSync) = case backend of LLVMCUDA -> (EntryFun CUDARequired , True ) + _ -> (EntryFun CUDANotRequired, False) let (mainFunc, impModuleUnoptimized, reconAtom) = toImpModule env backend cc mainName ptrBinders Nothing block' -- TODO: toImpModule might generate invalid Imp code, because GPU allocations @@ -214,17 +213,15 @@ evalBackend env block = do checkPass ImpPass impModule llvmAST <- liftIO $ impToLLVM logger impModule let IFunType _ _ resultTypes = impFunType $ mainFunc - let llvmEvaluate = if bench then compileAndBench else compileAndEval + let llvmEvaluate = if bench then compileAndBench needsSync else compileAndEval resultVals <- liftM (map (Con . Lit)) $ liftIO $ llvmEvaluate logger llvmAST funcName ptrVals resultTypes return $ applyNaryAbs reconAtom resultVals withCompileTime :: TopPassM a -> TopPassM a withCompileTime m = do - t1 <- liftIO $ getCurrentTime - ans <- m - t2 <- liftIO $ getCurrentTime - logTop $ TotalTime $ realToFrac $ t2 `diffUTCTime` t1 + (ans, t) <- measureSeconds m + logTop $ TotalTime t return ans checkPass :: (Pretty a, Checkable a) => PassName -> a -> TopPassM () @@ -249,97 +246,9 @@ logTop x = do logger <- asks logService logThis logger [x] -type CArgM = WriterT [IBinder] (CatT CArgEnv Embed) -type CArgEnv = (Env IBinder, Env ()) - -runCArg :: CArgEnv -> CArgM a -> Embed (a, [IBinder], CArgEnv) -runCArg initEnv m = repack <$> runCatT (runWriterT m) initEnv - where repack ((ans, cargs), env) = (ans, cargs, env) - -exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> EvalConfig -> IO () -exportFunctions objPath funcs env opts = do - let names = fmap fst funcs - unless (length (nub names) == length names) $ liftEitherIO $ - throw CompilerErr "Duplicate export names" - modules <- forM funcs $ \(nameStr, func) -> do - -- Create a module that simulates an application of arguments to the function - let ((dest, cargs), (_, decls)) = flip runEmbed (freeVars func) $ do - (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func - resultAtom <- naryApp func args - (resultDest, cdestArgs, _) <- runCArg cargEnv $ createDest mempty $ getType resultAtom - void $ emitTo outputName PlainLet $ Atom resultAtom - return (resultDest, cargArgs <> cdestArgs) - - let coreModule = Module Core decls mempty - let defunctionalized = simplifyModule env coreModule - let Module _ optDecls optBindings = optimizeModule defunctionalized - let (_, LetBound PlainLet outputExpr) = optBindings ! outputName - let block = Block optDecls outputExpr - - let backend = backendName opts - let name = Name TopFunctionName (fromString nameStr) 0 - let (_, impModule, _) = toImpModule env backend CEntryFun name cargs (Just dest) block - llvmAST <- execLogger Nothing $ flip impToLLVM impModule - return (llvmAST, [nameStr]) - exportObjectFile objPath modules - where - outputName = GlobalName "_ans_" - - createArgs :: Type -> CArgM [Atom] - createArgs ty = case ty of - FunTy b Pure result -> do - argSubst <- fmap (\(Bind (n:>bt)) -> Var $ n :> BaseTy bt) <$> looks fst - arg <- createArg $ subst (argSubst, mempty) $ b - (arg:) <$> createArgs result - FunTy _ _ _ -> error $ "Unexpected type for an exported function: " ++ pprint ty - _ -> return [] - - createArg :: Binder -> CArgM Atom - createArg b = case ty of - BaseTy bt@(Scalar _) -> do - ~v@(Var (name:>_)) <- newCVar bt - extend $ asFst $ b @> (Bind $ name :> bt) - return v - TabTy _ _ -> createTabArg mempty ty - _ -> error $ "Unsupported arg type: " ++ pprint ty - where ty = binderType b - - createTabArg :: IndexStructure -> Type -> CArgM Atom - createTabArg idx ty = case ty of - BaseTy bt@(Scalar _) -> do - ptrLoad =<< flip applyIdxs idx =<< newCVar (ptrTy bt) - TabTy b elemTy -> do - buildLam b TabArrow $ \(Var i) -> do - elemTy' <- substEmbed (b@>Var i) elemTy - createTabArg (idx <> Nest (Bind i) Empty) elemTy' - _ -> unsupported - where unsupported = error "Unsupported table type" - - createDest :: IndexStructure -> Type -> CArgM Atom - createDest idx ty = case ty of - BaseTy bt@(Scalar _) -> do - liftM (Con . BaseTypeRef) $ flip applyIdxs idx =<< newCVar (ptrTy bt) - TabTy b elemTy -> do - liftM (Con . TabRef) $ buildLam b TabArrow $ \(Var i) -> do - elemTy' <- substEmbed (b@>Var i) elemTy - createDest (idx <> Nest (Bind i) Empty) elemTy' - _ -> unsupported - where unsupported = error "Unsupported table type" - - -- TODO: I guess that the address space depends on the backend? - -- TODO: Have an ExternalPtr tag? - ptrTy ty = PtrType (DerivedPtr, Heap CPU, ty) - - newCVar :: BaseType -> CArgM Atom - newCVar bt = do - name <- genFresh (Name CArgName "arg" 0) <$> looks snd - extend $ asSnd $ name @> () - tell [Bind $ name :> bt] - return $ Var $ name :> BaseTy bt - abstractPtrLiterals :: Block -> ([IBinder], [LitVal], Block) abstractPtrLiterals block = flip evalState mempty $ do - block' <- traverseLiterals block $ \val -> case val of + block' <- traverseLiterals block \val -> case val of PtrLit ty ptr -> do ptrName <- gets $ M.lookup (ty, ptr) . fst case ptrName of @@ -373,3 +282,7 @@ traverseLiterals block f = traverseAtomLiterals atom = case atom of Con (Lit x) -> lift $ lift $ f x _ -> traverseAtom def atom + +-- TODO: use something like a `DEXPATH` env var for finding source files +findSourceFile :: FilePath -> IO FilePath +findSourceFile fpath = return $ "lib/" ++ fpath diff --git a/src/lib/Type.hs b/src/lib/Type.hs index c2c4faeff..29248533a 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -11,9 +11,9 @@ module Type ( getType, checkType, HasType (..), Checkable (..), litType, isPure, functionEffs, exprEffs, blockEffs, extendEffect, isData, checkBinOp, checkUnOp, - checkIntBaseType, checkFloatBaseType, withBinder, isDependent, + checkIntBaseType, checkFloatBaseType, withBinder, isDependent, checkExtends, indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength, - typeReduceBlock, typeReduceAtom, typeReduceExpr) where + typeReduceBlock, typeReduceAtom, typeReduceExpr, oneEffect) where import Prelude hiding (pi) import Control.Monad @@ -74,7 +74,7 @@ instance Checkable Module where checkValid m@(Module ir decls bindings) = addContext ("Checking module:\n" ++ pprint m) $ asCompilerErr $ do let env = freeVars m - forM_ (envNames env) $ \v -> when (not $ isGlobal $ v:>()) $ + forM_ (envNames env) \v -> when (not $ isGlobal $ v:>()) $ throw CompilerErr $ "Non-global free variable in module: " ++ pprint v addContext "Checking IR variant" $ checkModuleVariant m addContext "Checking body types" $ do @@ -92,7 +92,7 @@ checkBindings env ir bs = void $ runTypeCheck (CheckWith (env <> bs, Pure)) $ mapM_ (checkBinding ir) $ envPairs bs checkBinding :: IRVariant -> (Name, (Type, BinderInfo)) -> TypeM () -checkBinding ir (GlobalName v, b@(ty, info)) = +checkBinding ir (v, b@(ty, info)) | isGlobal (v:>()) = addContext ("binding: " ++ pprint (v, b)) $ do ty |: TyKind when (ir >= Evaluated && not (all isGlobal (envAsVars $ freeVars b))) $ @@ -152,21 +152,21 @@ instance HasType Atom where ACase e alts resultTy -> checkCase e alts resultTy DataConRef ~def@(DataDef _ paramBs [DataConDef _ argBs]) params args -> do checkEq (length paramBs) (length params) - forM_ (zip (toList paramBs) (toList params)) $ \(b, param) -> + forM_ (zip (toList paramBs) (toList params)) \(b, param) -> param |: binderAnn b let argBs' = applyNaryAbs (Abs paramBs argBs) params checkDataConRefBindings argBs' args return $ RawRefTy $ TypeCon def params BoxedRef b ptr numel body -> do - PtrTy (_, _, t) <- typeCheck ptr + PtrTy (_, t) <- typeCheck ptr checkEq (binderAnn b) (BaseTy t) numel |: IdxRepTy void $ typeCheck b withBinder b $ typeCheck body ProjectElt (i NE.:| is) v -> do ty <- typeCheck $ case NE.nonEmpty is of - Nothing -> Var v - Just is' -> ProjectElt is' v + Nothing -> Var v + Just is' -> ProjectElt is' v case ty of TypeCon def params -> do [DataConDef _ bs'] <- return $ applyDataDefParams def params @@ -184,7 +184,8 @@ instance HasType Atom where PairTy x _ | i == 0 -> return x PairTy _ y | i == 1 -> return y Var _ -> throw CompilerErr $ "Tried to project value of unreduced type " <> pprint ty - _ -> throw TypeErr $ "Only single-member ADTs and record types can be projected. Got " <> pprint ty + _ -> throw TypeErr $ + "Only single-member ADTs and record types can be projected. Got " <> pprint ty <> " " <> pprint v checkDataConRefBindings :: Nest Binder -> Nest DataConRefBinding -> TypeM () @@ -202,7 +203,7 @@ typeCheckVar v@(name:>annTy) = do annTy |: TyKind when (annTy == EffKind) $ throw CompilerErr "Effect variables should only occur in effect rows" - checkWithEnv $ \(env, _) -> case envLookup env v of + checkWithEnv \(env, _) -> case envLookup env v of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq annTy ty $ "Annotation on var: " ++ pprint name return annTy @@ -226,19 +227,19 @@ instance HasType Expr where checkCase :: HasType b => Atom -> [AltP b] -> Type -> TypeM Type checkCase e alts resultTy = do - checkWithEnv $ \_ -> do + checkWithEnv \_ -> do ety <- typeCheck e case ety of TypeCon def params -> do let cons = applyDataDefParams def params checkEq (length cons) (length alts) - forM_ (zip cons alts) $ \((DataConDef _ bs'), (Abs bs body)) -> do + forM_ (zip cons alts) \((DataConDef _ bs'), (Abs bs body)) -> do checkEq bs' bs resultTy' <- flip (foldr withBinder) bs $ typeCheck body checkEq resultTy resultTy' VariantTy (NoExt types) -> do checkEq (length types) (length alts) - forM_ (zip (toList types) alts) $ \(ty, (Abs bs body)) -> do + forM_ (zip (toList types) alts) \(ty, (Abs bs body)) -> do [b] <- pure $ toList bs checkEq (getType b) ty resultTy' <- flip (foldr withBinder) bs $ typeCheck body @@ -257,7 +258,7 @@ checkApp fTy x = do return resultTy -- TODO: replace with something more precise (this is too cautious) -blockEffs :: Block -> EffectSummary +blockEffs :: Block -> EffectRow blockEffs (Block decls result) = foldMap declEffs decls <> exprEffs result where declEffs (Let _ _ expr) = exprEffs expr @@ -265,37 +266,46 @@ blockEffs (Block decls result) = isPure :: Expr -> Bool isPure expr = exprEffs expr == mempty -exprEffs :: Expr -> EffectSummary +exprEffs :: Expr -> EffectRow exprEffs expr = case expr of - Atom _ -> NoEffects + Atom _ -> Pure App f _ -> functionEffs f Op op -> case op of PrimEffect ref m -> case m of - MGet -> S.singleton (State, h) - MPut _ -> S.singleton (State, h) - MAsk -> S.singleton (Reader, h) - MTell _ -> S.singleton (Writer, h) + MGet -> oneEffect (RWSEffect State h) + MPut _ -> oneEffect (RWSEffect State h) + MAsk -> oneEffect (RWSEffect Reader h) + MTell _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref - _ -> NoEffects + ThrowException _ -> oneEffect ExceptionEffect + IOAlloc _ _ -> oneEffect IOEffect + IOFree _ -> oneEffect IOEffect + PtrLoad _ -> oneEffect IOEffect + PtrStore _ _ -> oneEffect IOEffect + FFICall _ _ _ -> oneEffect IOEffect + _ -> Pure Hof hof -> case hof of For _ f -> functionEffs f Tile _ _ _ -> error "not implemented" - While cond body -> functionEffs cond <> functionEffs body + While body -> functionEffs body Linearize _ -> mempty -- Body has to be a pure function Transpose _ -> mempty -- Body has to be a pure function - RunReader _ f -> handleRunner Reader f - RunWriter f -> handleRunner Writer f - RunState _ f -> handleRunner State f + RunReader _ f -> handleRWSRunner Reader f + RunWriter f -> handleRWSRunner Writer f + RunState _ f -> handleRWSRunner State f PTileReduce _ _ -> mempty + RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> + EffectRow (S.delete IOEffect effs) t + CatchException ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> + EffectRow (S.delete ExceptionEffect effs) t Case _ alts _ -> foldMap (\(Abs _ block) -> blockEffs block) alts where - handleRunner effName ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs Nothing) _) = - S.delete (effName, h) $ S.fromList effs + handleRWSRunner rws ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs t) _) = + EffectRow (S.delete (RWSEffect rws h) effs) t -functionEffs :: Atom -> EffectSummary +functionEffs :: Atom -> EffectRow functionEffs f = case getType f of - Pi (Abs _ (arr, _)) -> S.fromList effs - where EffectRow effs Nothing = arrowEff arr + Pi (Abs _ (arr, _)) -> arrowEff arr _ -> error "Expected a function type" instance HasType Block where @@ -311,7 +321,7 @@ instance HasType Block where instance HasType Binder where typeCheck b = do - checkWithEnv $ \(env, _) -> checkNoShadow env b + checkWithEnv \(env, _) -> checkNoShadow env b let ty = binderType b ty |: TyKind return ty @@ -336,7 +346,7 @@ infixr 7 |: checkEq reqTy ty checkEq :: (Show a, Pretty a, Eq a) => a -> a -> TypeM () -checkEq reqTy ty = checkWithEnv $ \_ -> assertEq reqTy ty "" +checkEq reqTy ty = checkWithEnv \_ -> assertEq reqTy ty "" withBinder :: Binder -> TypeM a -> TypeM a withBinder b m = typeCheck b >> extendTypeEnv (boundVars b) m @@ -399,7 +409,7 @@ instance CoreVariant Expr where Hof e -> checkVariant e >> forM_ e checkVariant Case e alts _ -> do checkVariant e - forM_ alts $ \(Abs _ body) -> checkVariant body + forM_ alts \(Abs _ body) -> checkVariant body instance CoreVariant Decl where -- let annotation restrictions? @@ -421,6 +431,7 @@ instance CoreVariant (PrimTC a) where instance CoreVariant (PrimOp a) where checkVariant e = case e of + ThrowException _ -> goneBy Simp Select _ _ _ -> alwaysAllowed -- TODO: only scalar select after Simp _ -> alwaysAllowed @@ -432,14 +443,16 @@ instance CoreVariant (PrimCon a) where instance CoreVariant (PrimHof a) where checkVariant e = case e of For _ _ -> alwaysAllowed - While _ _ -> alwaysAllowed + While _ -> alwaysAllowed RunReader _ _ -> alwaysAllowed RunWriter _ -> alwaysAllowed RunState _ _ -> alwaysAllowed + RunIO _ -> alwaysAllowed Linearize _ -> goneBy Simp Transpose _ -> goneBy Simp Tile _ _ _ -> alwaysAllowed PTileReduce _ _ -> absentUntil Simp -- really absent until parallelization + CatchException _ -> goneBy Simp -- TODO: namespace restrictions? alwaysAllowed :: VariantM () @@ -459,7 +472,7 @@ goneBy ir = do when (curIR >= ir) $ throw IRVariantErr $ "shouldn't appear after " ++ show ir addExpr :: (Pretty e, MonadError Err m) => e -> m a -> m a -addExpr x m = modifyErr m $ \e -> case e of +addExpr x m = modifyErr m \e -> case e of Err IRVariantErr ctx s -> Err CompilerErr ctx (s ++ ": " ++ pprint x) _ -> e @@ -467,19 +480,20 @@ addExpr x m = modifyErr m $ \e -> case e of checkEffRow :: EffectRow -> TypeM () checkEffRow (EffectRow effs effTail) = do - forM_ effs $ \(_, v) -> Var (v:>TyKind) |: TyKind - forM_ effTail $ \v -> do - checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of + forM_ effs \eff -> case eff of + RWSEffect _ v -> Var (v:>TyKind) |: TyKind + ExceptionEffect -> return () + IOEffect -> return () + forM_ effTail \v -> do + checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq EffKind ty "Effect var" -declareEff :: (EffectName, Maybe Name) -> TypeM () -declareEff (effName, Just h) = - declareEffs $ EffectRow [(effName, h)] Nothing -declareEff (_, Nothing) = return () +declareEff :: Effect -> TypeM () +declareEff eff = declareEffs $ oneEffect eff declareEffs :: EffectRow -> TypeM () -declareEffs effs = checkWithEnv $ \(_, allowedEffects) -> +declareEffs effs = checkWithEnv \(_, allowedEffects) -> checkExtends allowedEffects effs checkExtends :: MonadError Err m => EffectRow -> EffectRow -> m () @@ -488,20 +502,23 @@ checkExtends allowed (EffectRow effs effTail) = do case effTail of Just _ -> assertEq allowedEffTail effTail "" Nothing -> return () - forM_ effs $ \eff -> unless (eff `elem` allowedEffs) $ + forM_ effs \eff -> unless (eff `elem` allowedEffs) $ throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ "\nAllowed: " ++ pprint allowed extendEffect :: Effect -> EffectRow -> EffectRow -extendEffect eff (EffectRow effs t) = EffectRow (eff:effs) t +extendEffect eff (EffectRow effs t) = EffectRow (S.insert eff effs) t + +oneEffect :: Effect -> EffectRow +oneEffect eff = EffectRow (S.singleton eff) Nothing -- === labeled row types === checkLabeledRow :: ExtLabeledItems Type Name -> TypeM () checkLabeledRow (Ext items rest) = do mapM_ (|: TyKind) items - forM_ rest $ \v -> do - checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of + forM_ rest \v -> do + checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq LabeledRowKind ty "Labeled row var" @@ -511,7 +528,7 @@ labeledRowDifference :: ExtLabeledItems Type Name labeledRowDifference (Ext (LabeledItems items) rest) (Ext (LabeledItems subitems) subrest) = do -- Check types in the right. - _ <- flip M.traverseWithKey subitems $ \label subtypes -> + _ <- flip M.traverseWithKey subitems \label subtypes -> case M.lookup label items of Just types -> assertEq subtypes (NE.fromList $ NE.take (length subtypes) types) $ @@ -539,7 +556,7 @@ checkWithEnv check = do CheckWith env -> check env updateTypeEnv :: (TypeEnv -> TypeEnv) -> TypeM a -> TypeM a -updateTypeEnv f m = flip local m $ fmap $ \(env, eff) -> (f env, eff) +updateTypeEnv f m = flip local m $ fmap \(env, eff) -> (f env, eff) extendTypeEnv :: TypeEnv -> TypeM a -> TypeM a extendTypeEnv new m = updateTypeEnv (<> new) m @@ -551,7 +568,7 @@ extendAllowedEffect :: Effect -> TypeM () -> TypeM () extendAllowedEffect eff m = updateAllowedEff (extendEffect eff) m updateAllowedEff :: (EffectRow -> EffectRow) -> TypeM a -> TypeM a -updateAllowedEff f m = flip local m $ fmap $ \(env, eff) -> (env, f eff) +updateAllowedEff f m = flip local m $ fmap \(env, eff) -> (env, f eff) withAllowedEff :: EffectRow -> TypeM a -> TypeM a withAllowedEff eff m = updateAllowedEff (const eff) m @@ -583,7 +600,7 @@ typeCheckCon con = case con of IndexRangeVal t l h i -> i|:IdxRepTy >> return (TC $ IndexRange t l h) IndexSliceVal _ _ _ -> error "not implemented" BaseTypeRef p -> do - (PtrTy (_, _, b)) <- typeCheck p + (PtrTy (_, b)) <- typeCheck p return $ RawRefTy $ BaseTy b TabRef tabTy -> do TabTy b (RawRefTy a) <- typeCheck tabTy @@ -636,7 +653,11 @@ checkFloatBaseType allowVector t = case t of "floating-point type, but found: " ++ pprint t checkValidCast :: Type -> Type -> TypeM () -checkValidCast sourceTy destTy = checkScalarType sourceTy >> checkScalarType destTy +checkValidCast (BaseTy (PtrType _)) (BaseTy (PtrType _)) = return () +checkValidCast (BaseTy (PtrType _)) (BaseTy (Scalar Int64Type)) = return () +checkValidCast (BaseTy (Scalar Int64Type)) (BaseTy (PtrType _)) = return () +checkValidCast sourceTy destTy = + checkScalarType sourceTy >> checkScalarType destTy where checkScalarType ty = case ty of BaseTy (Scalar Int64Type ) -> return () @@ -666,12 +687,13 @@ typeCheckOp op = case op of ToOrdinal i -> typeCheck i $> IdxRepTy IdxSetSize i -> typeCheck i $> IdxRepTy FFICall _ ansTy args -> do - forM_ args $ \arg -> do + forM_ args \arg -> do argTy <- typeCheck arg case argTy of BaseTy _ -> return () _ -> throw TypeErr $ "All arguments of FFI calls have to be " ++ "fixed-width base types, but got: " ++ pprint argTy + declareEff IOEffect return ansTy Inject i -> do TC tc <- typeCheck i @@ -680,15 +702,12 @@ typeCheckOp op = case op of ParIndexRange ty _ _ -> return ty _ -> throw TypeErr $ "Unsupported inject argument type: " ++ pprint (TC tc) PrimEffect ref m -> do - TC (RefType h s) <- typeCheck ref - let h'' = case h of - Just ~(Var (h':>TyKind)) -> Just h' - Nothing -> Nothing + TC (RefType ~(Just (Var (h':>TyKind))) s) <- typeCheck ref case m of - MGet -> declareEff (State , h'') $> s - MPut x -> x|:s >> declareEff (State , h'') $> UnitTy - MAsk -> declareEff (Reader, h'') $> s - MTell x -> x|:s >> declareEff (Writer, h'') $> UnitTy + MGet -> declareEff (RWSEffect State h') $> s + MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy + MAsk -> declareEff (RWSEffect Reader h') $> s + MTell x -> x|:s >> declareEff (RWSEffect Writer h') $> UnitTy IndexRef ref i -> do RefTy h (TabTyAbs a) <- typeCheck ref i |: absArgType a @@ -699,17 +718,27 @@ typeCheckOp op = case op of SndRef ref -> do RefTy h (PairTy _ b) <- typeCheck ref return $ RefTy h b + IOAlloc t n -> do + n |: IdxRepTy + declareEff IOEffect + return $ PtrTy (Heap CPU, t) + IOFree ptr -> do + PtrTy _ <- typeCheck ptr + declareEff IOEffect + return UnitTy PtrOffset arr off -> do - PtrTy (_, a, b) <- typeCheck arr + PtrTy (a, b) <- typeCheck arr off |: IdxRepTy - return $ PtrTy (DerivedPtr, a, b) + return $ PtrTy (a, b) PtrLoad ptr -> do - PtrTy (_, _, t) <- typeCheck ptr + PtrTy (_, t) <- typeCheck ptr + declareEff IOEffect return $ BaseTy t - GetPtr tab -> do - TabTy _ (BaseTy a) <- typeCheck tab - return $ BaseTy $ PtrType (AllocatedPtr, Heap CPU, a) - MakePtrType ty -> ty|:TyKind >> return TyKind + PtrStore ptr val -> do + PtrTy (_, t) <- typeCheck ptr + val |: BaseTy t + declareEff IOEffect + return $ UnitTy SliceOffset s i -> do TC (IndexSlice n l) <- typeCheck s l' <- typeCheck i @@ -733,7 +762,9 @@ typeCheckOp op = case op of i |: TC (IntRange (IdxRepVal 0) (IdxRepVal $ fromIntegral vectorWidth)) return $ BaseTy $ Scalar sb ThrowError ty -> ty|:TyKind $> ty - -- TODO: this should really be a 32 bit integer for unicode code point: but for now is 8 bit ASCII code point + ThrowException ty -> do + declareEff ExceptionEffect + ty|:TyKind $> ty CastOp t@(Var _) _ -> t |: TyKind $> t CastOp destTy e -> do sourceTy <- typeCheck e @@ -784,7 +815,7 @@ typeCheckOp op = case op of t |: TyKind x |: Word8Ty (TypeCon (DataDef _ _ dataConDefs) _) <- return t - forM_ dataConDefs $ \(DataConDef _ binders) -> + forM_ dataConDefs \(DataConDef _ binders) -> assertEq binders Empty "Not an enum" return t @@ -833,13 +864,10 @@ typeCheckHof hof = case hof of checkEq threadRange (binderType threadRange') -- PTileReduce n mapping : (n=>a, ro) return $ PairTy (TabTy (Ignore n) tileElemTy) accTy - While cond body -> do - Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck cond - Pi (Abs (Ignore UnitTy) (arr', bodyTy)) <- typeCheck body + While body -> do + Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck body declareEffs $ arrowEff arr - declareEffs $ arrowEff arr' checkEq (BaseTy $ Scalar Word8Type) condTy - checkEq UnitTy bodyTy return UnitTy Linearize f -> do Pi (Abs (Ignore a) (PlainArrow Pure, b)) <- typeCheck f @@ -848,21 +876,31 @@ typeCheckHof hof = case hof of Pi (Abs (Ignore a) (LinArrow, b)) <- typeCheck f return $ b --@ a RunReader r f -> do - (resultTy, readTy) <- checkAction Reader f + (resultTy, readTy) <- checkRWSAction Reader f r |: readTy return resultTy - RunWriter f -> uncurry PairTy <$> checkAction Writer f + RunWriter f -> uncurry PairTy <$> checkRWSAction Writer f RunState s f -> do - (resultTy, stateTy) <- checkAction State f + (resultTy, stateTy) <- checkRWSAction State f s |: stateTy return $ PairTy resultTy stateTy - -checkAction :: EffectName -> Atom -> TypeM (Type, Type) -checkAction effName f = do + RunIO f -> do + FunTy b eff resultTy <- typeCheck f + checkEq (binderAnn b) UnitTy + extendAllowedEffect IOEffect $ declareEffs eff + return resultTy + CatchException f -> do + FunTy b eff resultTy <- typeCheck f + checkEq (binderAnn b) UnitTy + extendAllowedEffect ExceptionEffect $ declareEffs eff + return $ MaybeTy resultTy + +checkRWSAction :: RWS -> Atom -> TypeM (Type, Type) +checkRWSAction rws f = do BinaryFunTy (Bind regionBinder) refBinder eff resultTy <- typeCheck f regionName:>_ <- return regionBinder let region = Var regionBinder - extendAllowedEffect (effName, regionName) $ declareEffs eff + extendAllowedEffect (RWSEffect rws regionName) $ declareEffs eff checkEq (varAnn regionBinder) TyKind RefTy region' referentTy <- return $ binderAnn refBinder checkEq region' region @@ -1042,9 +1080,4 @@ typeReduceExpr scope expr = case expr of typeReduceBlock scope $ subst (b@>x', scope) block TypeCon con xs -> Just $ TypeCon con $ xs ++ [x'] _ -> Nothing - Op (MakePtrType ty) -> do - let ty' = typeReduceAtom scope ty - case ty' of - BaseTy b -> return $ PtrTy (AllocatedPtr, Heap CPU, b) - _ -> Nothing _ -> Nothing diff --git a/src/lib/Util.hs b/src/lib/Util.hs index b85fee9fd..eb405a1c0 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -12,6 +12,7 @@ module Util (IsBool (..), group, ungroup, pad, padLeft, delIdx, replaceIdx, scanM, composeN, mapMaybe, uncons, repeated, transitiveClosure, showErr, listDiff, splitMap, enumerate, restructure, onSnd, onFst, highlightRegion, findReplace, swapAt, uncurry3, + measureSeconds, bindM2, foldMapM, lookupWithIdx, (...), zipWithT, for) where import Data.Functor.Identity (Identity(..)) @@ -21,6 +22,7 @@ import Prelude import qualified Data.Set as Set import qualified Data.Map.Strict as M import Control.Monad.State.Strict +import System.CPUTime import Cat @@ -232,3 +234,10 @@ transitiveClosure getParents seeds = unless (x `Set.member` visited) $ do extend $ Set.singleton x mapM_ go $ getParents x + +measureSeconds :: MonadIO m => m a -> m (a, Double) +measureSeconds m = do + t1 <- liftIO $ getCPUTime + ans <- m + t2 <- liftIO $ getCPUTime + return (ans, (fromIntegral $ t2 - t1) / 1e12) diff --git a/src/lib/dexrt.cpp b/src/lib/dexrt.cpp index c3ef5780b..be89d4028 100644 --- a/src/lib/dexrt.cpp +++ b/src/lib/dexrt.cpp @@ -38,6 +38,10 @@ void free_dex(char* ptr) { free(ptr); } +void* fdopen_w(int fd) { + return fdopen(fd, "w"); +} + uint32_t rotate_left(uint32_t x, uint32_t d) { return (x << d) | (x >> (32 - d)); } @@ -144,9 +148,9 @@ double randunif(uint64_t keypair) { return out - 1; } -void showHex(char **resultPtr, int x) { +void showHex(char **resultPtr, char x) { auto p = reinterpret_cast(malloc_dex(100)); // TODO: something better! - auto n = sprintf(p, "%02x", x); + auto n = sprintf(p, "%02hhX", x); auto result1Ptr = reinterpret_cast(resultPtr[0]); auto result2Ptr = reinterpret_cast( resultPtr[1]); *result1Ptr = n; diff --git a/src/resources/Resources.hs b/src/resources/Resources.hs index cb9c8cd3b..d834767e1 100644 --- a/src/resources/Resources.hs +++ b/src/resources/Resources.hs @@ -1,6 +1,6 @@ {-# LANGUAGE TemplateHaskell #-} -module Resources (dexrtBC, preludeSource, curResourceVersion) where +module Resources (dexrtBC, preludeSource, cssSource, curResourceVersion) where import qualified Data.ByteString.Char8 as B import Data.FileEmbed @@ -11,5 +11,10 @@ curResourceVersion = __TIME__ dexrtBC :: B.ByteString dexrtBC = $(embedFile "src/lib/dexrt.bc") +-- The Dex prelude source code. preludeSource :: String -preludeSource = B.unpack $ $(embedFile "prelude.dx") +preludeSource = B.unpack $(embedFile "lib/prelude.dx") + +-- The source code of the CSS used for rendering Dex programs as HTML. +cssSource :: String +cssSource = B.unpack $(embedFile "static/style.css") \ No newline at end of file diff --git a/stack-macos.yaml b/stack-macos.yaml index fbc7107e1..c14681f7f 100644 --- a/stack-macos.yaml +++ b/stack-macos.yaml @@ -10,8 +10,11 @@ packages: - . extra-deps: - - llvm-hs-9.0.1 - - llvm-hs-pure-9.0.0 + - github: apaszke/llvm-hs + commit: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + subdirs: + - llvm-hs + - llvm-hs-pure - megaparsec-8.0.0 - prettyprinter-1.6.2 - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 diff --git a/stack.yaml b/stack.yaml index 1d5bae6ae..445dd9ffd 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -# Copyright 2019 Google LLC +# Copyright 2020 Google LLC # # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file or at @@ -10,8 +10,11 @@ packages: - . extra-deps: - - llvm-hs-9.0.1 - - llvm-hs-pure-9.0.0 + - github: apaszke/llvm-hs + commit: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + subdirs: + - llvm-hs + - llvm-hs-pure - megaparsec-8.0.0 - prettyprinter-1.6.2 - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 diff --git a/static/dynamic.js b/static/dynamic.js index 41f5a8634..6e6efcb13 100644 --- a/static/dynamic.js +++ b/static/dynamic.js @@ -4,6 +4,18 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd +var katexOptions = { + delimiters: [ + {left: "$$", right: "$$", display: true}, + {left: "\\[", right: "\\]", display: true}, + {left: "$", right: "$", display: false}, + {left: "\\(", right: "\\)", display: false} + ], + // Enable commands that load resources or change HTML attributes + // (e.g. hyperlinks): https://katex.org/docs/security.html. + trust: true +}; + var cells = {}; function append_contents(key, contents) { @@ -65,4 +77,6 @@ source.onmessage = function(event) { } Object.assign(cells, new_cells); } + // Render LaTeX equations via KaTeX. + renderMathInElement(body, katexOptions); }; diff --git a/static/index.html b/static/index.html index 5084094db..d1774f2ec 100644 --- a/static/index.html +++ b/static/index.html @@ -4,7 +4,12 @@ Dex Output + + + + + diff --git a/static/style.css b/static/style.css index 77a7ce208..f978675d4 100644 --- a/static/style.css +++ b/static/style.css @@ -11,7 +11,6 @@ body { font-family: Helvetica, sans-serif; font-size: 100%; color: #333; - padding-bottom: 500px; } .cell { diff --git a/tests/GenExpr.hs b/tests/GenExpr.hs deleted file mode 100644 index 53e8a0aef..000000000 --- a/tests/GenExpr.hs +++ /dev/null @@ -1,366 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module GenExpr (sampleExpr, defaultGenOptions, GenOptions (..) - , testSampleExpr, testSample, makeGenEnv, genSourceBlock, - genUTopDecl, TypeEnv) where - -import Control.Monad -import Control.Monad.Reader -import GHC.Float -import Hedgehog hiding (Var, Command) -import Hedgehog.Internal.Shrink (towards) -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range -import qualified Data.Map.Strict as M -import qualified Data.Set as S -import Lens.Micro.Platform -import Data.Text.Prettyprint.Doc -import Data.String - -import Record -import Env -import Syntax -import PPrint - -testSample :: (Pretty a) => TypeEnv -> GenM a -> Range.Size -> IO () -testSample env m s = - Gen.sample (runReaderT (Gen.resize s (pprint <$> m)) - (makeGenEnv env defaultGenOptions)) - >>= putStrLn - -testSampleExpr :: Int -> IO () -testSampleExpr n = testSample mempty sampleExpr (fromIntegral n) - - --- Variable names associated with a type -type TypeEnv = M.Map SigmaType [Name] - --- Variable names in scope -type ScopeEnv = S.Set Name - -data GenOptions = GenOptions { - tableSize :: Int - , numberSize :: Int - , tupleSize :: Int - , returnTypePref :: Int - } - deriving (Show, Eq, Ord) - -defaultGenOptions :: GenOptions -defaultGenOptions = GenOptions { - tableSize = 10 - , numberSize = 10 - , tupleSize = 5 - , returnTypePref = 2 - } - -data GenEnv = GenEnv { - typeEnv :: TypeEnv - , scopeEnv :: ScopeEnv - , optsEnv :: GenOptions} - deriving (Show, Eq, Ord) - - -makeGenEnv :: TypeEnv -> GenOptions -> GenEnv -makeGenEnv te opts = GenEnv te (S.fromList (concat (M.elems te))) opts - --- lens -typeEnvL :: Lens' GenEnv TypeEnv -typeEnvL = lens typeEnv (\e t -> e{typeEnv = t}) -scopeEnvL :: Lens' GenEnv ScopeEnv -scopeEnvL = lens scopeEnv (\e s -> e{scopeEnv = s}) -optsEnvL :: Lens' GenEnv GenOptions -optsEnvL = lens optsEnv (\e s -> e{optsEnv = s}) -tableSizeL :: Lens' GenEnv Int -tableSizeL = optsEnvL . lens tableSize (\e t -> e{tableSize=t}) -numberSizeL :: Lens' GenEnv Int -numberSizeL = optsEnvL . lens numberSize (\e t -> e{numberSize = t}) -tupleSizeL :: Lens' GenEnv Int -tupleSizeL = optsEnvL . lens tupleSize (\e t -> e{tupleSize = t}) -returnTypePrefL :: Lens' GenEnv Int -returnTypePrefL = optsEnvL . lens returnTypePref (\e t -> e{returnTypePref = t}) - --- utils -setBinding :: (Name, SigmaType) -> GenEnv -> GenEnv -setBinding (v, ty) = - over (typeEnvL . at ty) ((Just [v]) `mappend`) . over scopeEnvL (S.insert v) -setBinding' :: (Name, Type) -> GenEnv -> GenEnv -setBinding' (v, ty) = setBinding (v, Forall [] ty) -setBindings' :: [(Name, Type)] -> GenEnv -> GenEnv -setBindings' vs = foldl (.) id (setBinding' <$> vs) -withBindings :: [(Name, Type)] -> GenM a -> GenM a -withBindings vs g = local (setBindings' vs) g -notShadowed :: Name -> GenM Bool -notShadowed n = view (scopeEnvL . to (S.notMember n)) - -genUntil :: MonadGen m => (a -> m Bool) -> m a -> m a -genUntil f gen = do - x <- gen - isValid <- f x - if isValid then return x else genUntil f gen - -small :: MonadGen m => m a -> m a -small = Gen.scale (`div` 2) - -genSized :: MonadGen m => m a -> m a -> m a -genSized leaf tree = Gen.sized (\n -> if n == 0 then leaf else tree) - -element :: MonadGen m => [a] -> m a -element = Gen.prune . Gen.element - -prefer :: MonadGen m => Int -> m a -> m a -> m a -prefer w p r = Gen.prune (Gen.frequency [(w, p), (1, r)]) - -type GenM a = ReaderT GenEnv Gen a - -allTypes :: Type -> [Type] -allTypes ty = ty : case ty of - ArrType _ t1 t2 -> allTypes t1 ++ allTypes t2 - TabType _ t -> allTypes t - RecType _ ~(Tup ts) -> concatMap allTypes ts - _ -> [] - -preferReturnType :: Type -> GenM Type -> GenM Type -preferReturnType ty b = view returnTypePrefL >>= (\n -> prefer n (element (allTypes ty)) b) - - --- type utils -checkData :: Type -> Bool -checkData ty = case ty of - BaseType _ -> True - TabType _ a -> checkData a - RecType _ r -> all checkData r - IdxSetLit _ -> True - _ -> False - - --- | TODO: StrType -genBaseType :: GenM BaseType -genBaseType = element [IntType, BoolType, RealType] - - -genRecTypeWith :: GenM a -> GenM (Record a) -genRecTypeWith g = Tup <$> record - where - record = view tupleSizeL >>= \n -> Gen.list (Range.linear 2 n) (small g) - -genRecType :: GenM (Record Type) -genRecType = genRecTypeWith genType - -genTabTypeWith :: GenM Type -> GenM Type -genTabTypeWith g = liftM2 TabType genIdxSet (small g) - -genTabType :: GenM Type -genTabType = genTabTypeWith genType - --- types belong to the Data class -genDataType :: GenM Type -genDataType = genSized leaf tree - where - leaf = BaseType <$> genBaseType - tree = Gen.frequency [ - (1, leaf) - , (2, (RecType Cart) <$> genRecTypeWith genDataType) - , (2, genTabTypeWith genDataType) - ] - - -genIdxSet :: GenM IdxSet -genIdxSet = genSized leaf tree - where - lit = view tableSizeL >>= \n -> Gen.integral_ (Range.constant 1 n) - leaf = IdxSetLit <$> lit - tree = Gen.frequency [ - (1, leaf) - -- Tuple index has not been implemented in JIT - -- , (2, RecType <$> genRecTypeWith genIdxSet) - ] - --- TODO: TypeVar, Exists, BoundTVar. -genLeafType :: GenM Type -genLeafType = BaseType <$> genBaseType - --- TODO: Linear type, Tens -genTreeType :: GenM Type -genTreeType = Gen.choice [ - genLeafType - , arr - , genTabType - , (RecType Cart) <$> genRecType - ] - where - sub = small genType - arr = liftM2 (ArrType (Mult NonLin)) sub sub - - -genType :: GenM Type -genType = Gen.shrink shrinkType $ Gen.prune (genSized genLeafType genTreeType) - -shrinkType :: Type -> [Type] -shrinkType = tail . shrinkLis - where - shrinkLis :: Type -> [Type] - shrinkLis ty = case ty of - ArrType lin t1 t2 -> - -- TODO: generate smaller list - liftM2 (ArrType lin) (shrinkLis t1) (shrinkLis t2) ++ shrinkType t1 - TabType idx t -> - liftM2 TabType (shrinkLis idx) (shrinkLis t) ++ shrinkType t - (IdxSetLit n) -> IdxSetLit <$> towards n 1 - _ -> [ty] - - -genPatP :: (Type -> Ann) -> Type -> GenM (UPat, [(Name, Type)]) -genPatP ann ty = case ty of - (RecType _ (Tup as)) -> Gen.frequency [(1, genLeafPat), (2, genTupPat as)] - _ -> genLeafPat - where - genLeafPat = do - n <- genName - return (RecLeaf (n :> (ann ty)), [(n, ty)]) - genTreePat :: [Type] -> GenM ([UPat], [(Name, Type)]) - genTreePat [] = return ([], []) - genTreePat (t:ts) = do - (p1, vs1) <- genPatP ann t - (restp, restv) <- withBindings vs1 (genTreePat ts) -- make sure names are unique - return (p1:restp, vs1 ++ restv) - genTupPat :: [Type] -> GenM (UPat, [(Name, Type)]) - genTupPat ts = do - (ps, vs) <- genTreePat ts - return (RecTree (Tup ps), vs) - - --- | variable or literal value --- -genLit :: BaseType -> GenM (ExprP b) -genLit ty = Lit <$> case ty of - IntType -> - view numberSizeL >>= \n -> IntLit <$> Gen.integral_ (Range.constant (negate n) n) - BoolType -> BoolLit <$> Gen.bool_ - RealType -> do - n <- view (numberSizeL . to fromIntegral) - (RealLit . roundTripDouble) <$> Gen.realFrac_ (Range.constant (negate n) n) - StrType -> error "Str type not implemented" - --- TODO: remove this once we have more control over precision of printed floats -roundTripDouble :: Double -> Double -roundTripDouble x = read (show (double2Float x)) - -genName :: GenM Name -genName = Gen.prune (genUntil notShadowed (fromString <$> str)) - where - strLen = Range.constant 0 5 - strTail = Gen.frequency [(10, Gen.alphaNum), (1, return '\'')] - str = liftM2 (:) Gen.lower (Gen.list strLen strTail) - -genVars :: Type -> GenM [ExprP b] -genVars t = view (typeEnvL . at (Forall [] t) . to (maybe [] id) . to (map (flip Var []))) - -withVars :: Type -> GenM (ExprP b) -> GenM (ExprP b) -withVars t g = do - vs <- genVars t - e <- g - if null vs - then return e - else prefer 3 (Gen.element vs) (return e) -- preference to variable - --- TODO: Linear type -genLam :: Type -> Type -> GenM UExpr -genLam a b = do - (pat, env) <- genPatP Ann a - body <- withBindings env (genExpr b) - return (Lam (Ann (Mult NonLin)) pat body) - - --- table -genTabCon :: Int -> Type -> [GenM UExpr] -genTabCon n ty - | checkData ty = [TabCon NoAnn <$> replicateM n (small (genExpr ty))] - | otherwise = [] - -genFor :: Type -> Type -> GenM UExpr -genFor a b = do - (pat, env) <- small (genPatP Ann a) - body <- withBindings env (small (genExpr b)) - return (For pat body) - -genTable :: IdxSet -> Type -> GenM UExpr -genTable ty@(IdxSetLit n) b = Gen.choice (genFor ty b : genTabCon n b) -genTable ty b = genFor ty b - --- TODO: LetPoly, TAlias, Unpack -genDecl :: Type -> GenM UExpr -genDecl ty = do - -- preference over return type to increase variable usage - declTy <- small (preferReturnType ty genType) - declExpr <- small (genExpr declTy) - (declPat, env) <- small (genPatP (const NoAnn) declTy) - body <- small (withBindings env (genExpr ty)) - return (Decl (LetMono declPat declExpr) body) - -genGet :: Type -> GenM UExpr -genGet ty = do - idxty <- small genIdxSet - idx <- small (genExpr idxty) - body <- small (genExpr (TabType idxty ty)) - return (Get body idx) - - -genApp :: Type -> GenM UExpr -genApp ty = do - argty <- small (preferReturnType ty genType) - fun <- small (genExpr (ArrType (Mult NonLin) argty ty)) - arg <- small (genExpr argty) - return (App fun arg) - --- TODO: Tens -genRecCon :: Record Type -> GenM UExpr -genRecCon ~(Tup ts) = RecCon Cart <$> Tup <$> traverse (small . genExpr) ts - - -genLeafExpr :: Type -> GenM UExpr -genLeafExpr ty = withVars ty $ case ty of - BaseType t -> genLit t - ArrType _ t1 t2 -> genLam t1 t2 - TabType i t -> genTable i t - RecType _ rt -> genRecCon rt - IdxSetLit n -> do - val <- Gen.integral_ (Range.constant 0 (n - 1)) - return $ Annot (PrimOp IntAsIndex [] [Lit (IntLit val)]) ty - _ -> undefined - -genTreeExpr :: Type -> GenM UExpr -genTreeExpr ty = Gen.choice $ case ty of - BaseType{} -> commons - ArrType _ t1 t2 -> genLam t1 t2 : commons - TabType i t -> genTable i t : commons - RecType _ rt -> genRecCon rt : commons - _ -> commons - where - commons = [ - genDecl ty - , genApp ty - , genGet ty - ] - -genExpr :: Type -> GenM UExpr -genExpr ty = genSized (genLeafExpr ty) (genTreeExpr ty) - -sampleExpr :: GenM UExpr -sampleExpr = do - ty <- genDataType - genExpr ty - - -genUTopDecl :: GenM UTopDecl -genUTopDecl = (EvalCmd . Command (EvalExpr Printed)) <$> sampleExpr - -genSourceBlock :: GenM SourceBlock -genSourceBlock = do - topdecl <- UTopDecl <$> genUTopDecl - case topdecl of - ~(UTopDecl (EvalCmd (Command _ e))) -> return (SourceBlock 0 0 (pprint e) topdecl) diff --git a/tests/PropTests.hs b/tests/PropTests.hs deleted file mode 100644 index a149f787a..000000000 --- a/tests/PropTests.hs +++ /dev/null @@ -1,52 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE OverloadedStrings #-} - -import Control.Monad -import qualified Hedgehog as H -import Control.Monad.Reader -import qualified Data.Map.Strict as M - -import Syntax hiding (Result) -import Parser -import PPrint -import GenExpr -import TestPass - -main :: IO () -main = void tests - -prop_jitEval :: TypeEnv -> Evaluator -> Evaluator -> H.Property -prop_jitEval tenv jit interp = - H.property $ do - srcBlk <- H.forAllWith pprint (runReaderT genSourceBlock (makeGenEnv tenv defaultGenOptions)) - interres <- H.evalIO (interp srcBlk) - H.annotate ("Interpreter result: " ++ pprint interres) - jitres <- H.evalIO (jit srcBlk) - pprint interres H.=== pprint jitres - - -getExpr :: TopDeclP b -> ExprP b -getExpr ~(EvalCmd (Command _ e)) = e - -prop_pprint :: H.Property -prop_pprint = - H.property $ do - expr <- H.forAllWith pprint (runReaderT sampleExpr (makeGenEnv mempty defaultGenOptions)) - H.tripping expr pprintEsc (\s -> (getExpr . stripSrcAnnotTopDecl) <$> parseTopDecl s) - -tests :: IO Bool -tests = do - let prelude = "prelude.dx" - jit <- runTestPass prelude fullPassJit - interp <- runTestPass prelude fullPassInterp - preludeEnv <- loadTypeEnv prelude - let tyEnv = M.fromListWith (++) [(ty, [name]) | (ty, name) <- preludeEnv] - H.checkParallel $ H.Group "TypeCheck" [ - ("prop_jitEval", prop_jitEval tyEnv jit interp) - , ("prop_pprint", prop_pprint) - ] diff --git a/tests/TestPass.hs b/tests/TestPass.hs deleted file mode 100644 index c677709e1..000000000 --- a/tests/TestPass.hs +++ /dev/null @@ -1,76 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module TestPass (typeCheckPass, fullPassInterp, fullPassJit, - runTestPass, Evaluator, loadTypeEnv) where - -import Data.Void -import Control.Monad.State.Strict -import qualified Data.Map.Strict as M -import Unsafe.Coerce - -import Pass -import DeShadow -import Inference -import Imp -import Syntax -import Type -import JIT -import Flops -import Normalize -import Simplify -import Interpreter -import Parser -import Env - -typeCheckPass :: TopPass SourceBlock TopDecl -typeCheckPass = sourcePass >+> deShadowPass >+> typePass >+> checkTyped - -fullPassInterp :: TopPass SourceBlock Void -fullPassInterp = typeCheckPass >+> interpPass - -fullPassJit :: TopPass SourceBlock Void -fullPassJit = typeCheckPass >+> normalizePass >+> checkNExpr - >+> derivPass >+> checkNExpr - >+> simpPass >+> checkNExpr - >+> impPass >+> checkImp - >+> flopsPass - >+> jitPass - - -type TestFullPass env b = SourceBlock -> TopPassM env b - -evalDecl :: Monoid env => TestFullPass env b -> SourceBlock -> StateT env IO () -evalDecl pass block = do - env <- get - (_, env') <- liftIO (runTopPassM env (pass block)) - modify (<> env') - -loadFile :: (Monoid env) => String -> TestFullPass env b -> IO env -loadFile fname pass = do - source <- readFile fname - let sourceBlocks = parseProg source - execStateT (mapM (evalDecl pass) sourceBlocks) mempty - -type Evaluator = SourceBlock -> IO Result' - -runTestPass :: String -> TopPass SourceBlock Void -> IO Evaluator -runTestPass fname (TopPass pass) = do - env <- loadFile fname pass - let eval source = do - ~(Left res, _) <- runTopPassM env (pass source) - return res - return eval - - -loadTypeEnv :: String -> IO [(SigmaType, Name)] -loadTypeEnv fname = - case sourcePass >+> deShadowPass >+> typePass of - TopPass pass -> do - envs <- loadFile fname pass - let env = (snd (unsafeCoerce envs)) :: TypeEnv - return $ case env of - Env m -> [(ty, name) | (name, L ty) <- M.toList m] diff --git a/tests/actor-test.hs b/tests/actor-test.hs deleted file mode 100644 index 4f9277298..000000000 --- a/tests/actor-test.hs +++ /dev/null @@ -1,69 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -import Actor -import Control.Monad -import Control.Monad.State -import System.IO -import qualified Data.Map.Strict as M - -type Key = String -type Val = String -type StoreID = Int -type ServerMsg = Either Val ClientToServer -type Server a = StateT (M.Map Key (Proc ())) (Actor ServerMsg) a - -data ClientToServer = Write Key Val | Read Key - -inputDriver :: Proc ServerMsg -> Actor () () -inputDriver server = do - command <- liftIO $ getLine - case words command of - ["write", s1, s2] -> server `send` (Right (Write s1 s2)) - ["read" , s ] -> server `send` (Right (Read s)) - _ -> liftIO $ putStrLn "didn't understand command" - loop - where loop = inputDriver server - -outputDriver :: Actor String () -outputDriver = do - receive $ \_ msg -> liftIO $ putStrLn msg - outputDriver - -serverProc :: Server () -serverProc = do - self <- getSelf - input <- spawn NoTrap (inputDriver self) - client <- spawn NoTrap outputDriver - forever $ mainServerLoop client - -storeProc :: Proc ServerMsg -> Val -> Actor () () -storeProc server val = receive $ \_ _ -> do - if length val > 5 then error "oops!" - else server `send` (Left val) >> loop - where loop = storeProc server val - -mainServerLoop :: Proc String -> Server () -mainServerLoop client = receiveAny handleMsg handleErr - where - handleMsg :: UProc -> ServerMsg -> Server () - handleMsg _ msg = case msg of - Left val -> send client val - Right req -> case req of - Write key val -> do - self <- getSelf - store <- spawnLink NoTrap (storeProc self val) - modify $ M.insert key store - Read key -> do - ans <- gets (M.lookup key) - case ans of Nothing -> sorry key - Just store -> store `send` () - handleErr err = client `send` ("Store " ++ show err ++ " down") - sorry key = client `send` ("Store " ++ key ++ " doesn't exist") - - -main :: IO () -main = runActor Trap (evalStateT serverProc mempty) diff --git a/examples/ad-tests.dx b/tests/ad-tests.dx similarity index 86% rename from examples/ad-tests.dx rename to tests/ad-tests.dx index 5d5b4e7c1..6affc69f6 100644 --- a/examples/ad-tests.dx +++ b/tests/ad-tests.dx @@ -1,54 +1,54 @@ -- TODO: use prelude sum instead once we can differentiate state effect -def sum' (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs.i +def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i :p f : Float -> Float = \x. x jvp f 3.0 2.0 -> 2.0 +> 2. :p f = \x. x * x jvp f 3.0 1.5 -> 9.0 +> 9. :p f = \x. x + x jvp f 3.0 2.0 -> 4.0 +> 4. :p f = \x. x * x * x jvp f 2.0 1.5 -> 18.0 +> 18. :p f : Float --o Float = \x. x transposeLinear f 2.0 -> 2.0 +> 2. :p f : Float --o Float = \x. x + x transposeLinear f 1.0 -> 2.0 +> 2. :p f : Float --o Float = \x. x + (x + x) * 2.0 transposeLinear f 1.0 -> 5.0 +> 5. :p f : Float --o Float = \x. x * 2.0 transposeLinear f 1.0 -> 2.0 +> 2. :p f : Float --o Float = \x. 2.0 * x transposeLinear f 1.0 -> 2.0 +> 2. :p grad (\x. x * x) 1.0 -> 2.0 +> 2. :p deriv (\x. 3.0 / x) 2.0 > -0.75 @@ -61,49 +61,49 @@ def sum' (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs.i \xs. for i. xs.i * xs.i jvp f [1.,2.] [3.,4.] -> [6.0, 16.0] +> [6., 16.] :p jvp transpose [[1.,2.], [3.,4.]] [[10.,20.], [30.,40.]] -> [[10.0, 30.0], [20.0, 40.0]] +> [[10., 30.], [20., 40.]] :p jvp sum' [1., 2.] [10.0, 20.0] -> 30.0 +> 30. -f : Float -> Float = \x. snd $ withAccum \ref. ref += x +f : Float -> Float = \x. yieldAccum \ref. ref += x :p jvp f 1.0 1.0 -> 1.0 +> 1. :p f = \x. x * x * x jvp (\x. jvp f x 1.0) 2.0 1.0 -> 12.0 +> 12. :p f = \x. 4.0 * x * x * x deriv (deriv (deriv f)) 1.234 -> 24.0 +> 24. :p f : Float --o (Float & Float) = \x. (x, 2.0 * x) transposeLinear f (1.0, 3.0) -> 7.0 +> 7. :p f : (Float & Float) --o Float = \(x,y). x + 2.0 * y transposeLinear f 1.0 -> (1.0, 2.0) +> (1., 2.) :p deriv cos 0.0 -> 0.0 +> 0. :p deriv sin 0.0 -> 1.0 +> 1. :p (sin 1.0, deriv (deriv sin) 1.0) -> (0.84147096, -0.84147096) +> (0.841471, -0.841471) :p (cos 1.0, deriv (deriv (deriv sin)) 1.0) -> (0.5403023, -0.5403023) +> (0.540302, -0.540302) :p checkDeriv sin 1.0 > True @@ -141,33 +141,33 @@ f : Float -> Float = \x. snd $ withAccum \ref. ref += x -- Perturbation confusion test suggested by Barak Pearlmutter -- https://github.com/HIPS/autograd/issues/4 :p deriv (\x. x * deriv (\y. x * y) 2.0) 1.0 -> 2.0 +> 2. tripleit : Float --o Float = \x. x + x + x :p tripleit 1.0 -> 3.0 +> 3. :p transposeLinear tripleit 1.0 -> 3.0 +> 3. :p transposeLinear (transposeLinear tripleit) 1.0 -> 3.0 +> 3. :p f : n:Type ?-> Float --o n=>Float = \x. for i. x transposeLinear f [1.0, 2.0] -> 3.0 +> 3. :p f : n:Type ?-> n=>Float --o n=>Float = \x. for i. x.i * 2.0 transposeLinear f [1.0, 2.0] -> [2.0, 4.0] +> [2., 4.] myOtherSquare : Float -> Float = - \x. snd $ withAccum \w. w += x * x + \x. yieldAccum \w. w += x * x :p checkDeriv myOtherSquare 3.0 > True @@ -177,59 +177,59 @@ myOtherSquare : Float -> Float = \x. fst (x * x, 2 + 1) jvp f 1.0 3.0 -> 6.0 +> 6. :p f : Float -> Float = \x. x * IToF (1 + 1) jvp f 1.0 2.0 -> 4.0 +> 4. :p f : (Fin 2)=>Float -> Float = \xs. xs.(0 @ Fin 2) * xs.(1 @ Fin 2) jvp f [1., 2.] [3.0, 4.0] -> 10.0 +> 10. :p f : (Float & Float) -> Float = \(x,y). x * y jvp f (1., 2.) (3.0, 4.0) -> 10.0 +> 10. :p f : n:Type ?-> n=>Float -> n=>Float = \xs. for i. xs.i * xs.i jvp f [1.,2.] [3.,4.] -> [6.0, 16.0] +> [6., 16.] :p jvp sum' [1., 2.] [3.0, 4.0] -> 7.0 +> 7. :p grad sum' [1.,2.] -> [1.0, 1.0] +> [1., 1.] vec = [1.] :p jvp (\x. vec) [1.] [1.] -> [0.0] +> [0.] :p grad (\(x, y). vdot x y) ([1.,2.], [3.,4.]) -> ([3.0, 4.0], [1.0, 2.0]) +> ([3., 4.], [1., 2.]) :p f : Float -> Float = \x. y = x * 2.0 - snd $ withAccum \a. + yieldAccum \a. a += x * 2.0 a += y grad f 1.0 -> 4.0 +> 4. :p f : Float -> Float = \x. @@ -242,7 +242,7 @@ vec = [1.] :p f : Float -> Float = \x. - snd $ withState x \xr. + yieldState x \xr. for i:(Fin 2). xr := get xr * get xr checkDeriv f 2.0 @@ -269,7 +269,7 @@ vec = [1.] :p f = \x. for i:(Fin 4). { x=x * x * (IToF $ ordinal i) } jvp f 2.0 1.0 -> [{x = 0.0}, {x = 4.0}, {x = 8.0}, {x = 12.0}] +> [{x = 0.}, {x = 4.}, {x = 8.}, {x = 12.}] :p s : { a : Float | b : Float } = case 2 == 2 of @@ -297,7 +297,7 @@ vec = [1.] :p f = \c. v = for i:(Fin 2). 2.0 - (c, v) = snd $ withState (c, v) \r. for i:(Fin 2). + (c, v) = yieldState (c, v) \r. for i:(Fin 2). (c, v) = get r r := (c + sum v, v) c diff --git a/examples/adt-tests.dx b/tests/adt-tests.dx similarity index 86% rename from examples/adt-tests.dx rename to tests/adt-tests.dx index 364b73d25..08a28c6d0 100644 --- a/examples/adt-tests.dx +++ b/tests/adt-tests.dx @@ -97,11 +97,12 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p for i. case myTab.i of MyLeft val -> val - MyRight _ -> todo + MyRight _ -> error "nope" +> nope > Runtime error :p - snd $ withAccum \ref. + yieldAccum \ref. for i. case myTab.i of MyLeft tmp -> () MyRight val -> ref += 1.0 + val @@ -109,7 +110,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p -- check that the order of the case alternatives doesn't matter - snd $ withAccum \ref. + yieldAccum \ref. for i. case myTab.i of MyRight val -> ref += 1.0 + val MyLeft tmp -> () @@ -127,7 +128,7 @@ threeCaseTab : (Fin 4)=>ThreeCases = > [(TheIntCase 3), TheEmptyCase, (ThePairCase 2 0.1), TheEmptyCase] :p - snd $ withAccum \ref. + yieldAccum \ref. for i. case threeCaseTab.i of TheEmptyCase -> ref += 1000.0 ThePairCase x y -> ref += 100.0 + y + IToF x @@ -242,14 +243,14 @@ def zerosLikeList (l : List a) : (Fin (listLength l))=>Float = for i:(Fin $ listLength l). 0.0 :p zerosLikeList l2 -> [0.0, 0.0, 0.0] +> [0., 0., 0.] data Graph a:Type = MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n)) def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = init = for i j. False - snd $ withState init \mRef. + yieldState init \mRef. for i:m. (from, to) = edges.i mRef!from!to := True @@ -288,3 +289,32 @@ def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = :p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 + +data MySum = + Foo Float + Bar String + +-- bug #348 +:p + xs = for i:(Fin 3). + if ordinal i < 2 + then Foo 2.0 + else Foo 1.0 + (xs, xs) +> ([(Foo 2.), (Foo 2.), (Foo 1.)], [(Foo 2.), (Foo 2.), (Foo 1.)]) + +data MySum2 = + Foo2 + Bar2 (Fin 3 => Int) + +-- bug #348 +:p concat for i:(Fin 4). AsList _ [(Foo2, Foo2)] +> (AsList 4 [(Foo2, Foo2), (Foo2, Foo2), (Foo2, Foo2), (Foo2, Foo2)]) + +-- reproducer for a shadowing bug (PR #440) +:p concat $ for i:(Fin 2). toList [(Just [0,0,0], Just [0,0,0]), + (Just [0,0,0], Just [0,0,0])] +> (AsList 4 [ ((Just [0, 0, 0]), (Just [0, 0, 0])) +> , ((Just [0, 0, 0]), (Just [0, 0, 0])) +> , ((Just [0, 0, 0]), (Just [0, 0, 0])) +> , ((Just [0, 0, 0]), (Just [0, 0, 0])) ]) diff --git a/examples/complex-tests.dx b/tests/complex-tests.dx similarity index 81% rename from examples/complex-tests.dx rename to tests/complex-tests.dx index 7af6ea331..bf799cb39 100644 --- a/examples/complex-tests.dx +++ b/tests/complex-tests.dx @@ -1,11 +1,11 @@ :p complex_floor $ MkComplex 0.3 0.6 -> (MkComplex 0.0 0.0) +> (MkComplex 0. 0.) :p complex_floor $ MkComplex 0.6 0.8 -> (MkComplex 0.0 1.0) +> (MkComplex 0. 1.) :p complex_floor $ MkComplex 0.8 0.6 -> (MkComplex 1.0 0.0) +> (MkComplex 1. 0.) :p complex_floor $ MkComplex 0.6 0.3 -> (MkComplex 0.0 0.0) +> (MkComplex 0. 0.) a = MkComplex 2.1 0.4 b = MkComplex (-1.1) 1.3 @@ -48,10 +48,10 @@ b = MkComplex (-1.1) 1.3 > True :p sinh (MkComplex 1.2 3.2) -> (MkComplex -1.5068874 -0.10569556) +> (MkComplex -1.506887 -0.105696) :p cosh (MkComplex 1.2 3.2) -> (MkComplex -1.807568 8.811359e-2) +> (MkComplex -1.807568 0.088114) :p tanh (MkComplex 1.1 0.1) -> (MkComplex 0.80337524 3.5809334e-2) +> (MkComplex 0.803375 0.035809) :p tan (MkComplex 1.2 3.2) -> (MkComplex 2.2501666e-3 1.002451) +> (MkComplex 0.00225 1.002451) diff --git a/examples/eval-tests.dx b/tests/eval-tests.dx similarity index 90% rename from examples/eval-tests.dx rename to tests/eval-tests.dx index 4dd02dfe8..f853dceab 100644 --- a/examples/eval-tests.dx +++ b/tests/eval-tests.dx @@ -1,10 +1,10 @@ :p 1.0 + 2.0 -> 3.0 +> 3. :p double = \x. x * 2.0 double 10.0 -> 20.0 +> 20. :p sum (iota (Fin 10)) > 45 @@ -25,14 +25,14 @@ x = iota (Fin 3) y = map IToF x vdot' y y -> 10.0 +> 10. :p x = iota $ Fin 3 y = iota $ Fin 4 z = for i j. IToF x.i * IToF y.j sum (for i. sum z.i) -> 18.0 +> 18. -- :p randint (hash 0 0) 10 -- :p let x = unpack range 10000 @@ -51,7 +51,7 @@ arr = iota NArr fun = \y. sum (map IToF arr) + y :p fun 3.0 -> 24.0 +> 24. :p arr > [0, 1, 2, 3, 4, 5, 6] @@ -60,10 +60,10 @@ fun = \y. sum (map IToF arr) + y > 21 :p 6.0 - 10.0 -> -4.0 +> -4. :p (\(x, y). x + y) (1.0, 2.0) -> 3.0 +> 3. :p f : a:Type ?-> b:Type ?-> (a -> b & a) -> b = @@ -75,40 +75,40 @@ fun = \y. sum (map IToF arr) + y (x,y) = ((1.0,2.0),3.0) (x1, x2) = x x1 + x2 + y -> 6.0 +> 6. :p x = (1.0,2.0) (y,z) = x y + z -> 3.0 +> 3. -- :p let f (x, y) = x + 2 * y; -- z.i = (x.i, x.i * x.i) -- in sum (for i. f z.i) :p exp 1.0 -> 2.7182817 +> 2.718282 :p exp2 3.0 -> 8.0 +> 8. :p log 1.0 -> 0.0 +> 0. :p log2 8.0 -> 3.0 +> 3. :p log10 100.0 -> 2.0 +> 2. :p sqrt 2.0 -> 1.4142135 +> 1.414214 :p sin 3.14159 -> 2.5351817e-6 +> 0.000003 :p cos 0.0 -> 1.0 +> 1. :p tan 1.57079 > 159378.27 @@ -125,7 +125,7 @@ fun = \y. sum (map IToF arr) + y s = 1.0 :p s -> 1.0 +> 1. :p [2, 4, 8] > [2, 4, 8] @@ -141,7 +141,7 @@ cumsumplus : n=>Float -> n=>Float = (ans, 1.0 + ans) :p cumsumplus [1.0, 2.0, 3.0] -> [2.0, 4.0, 7.0] +> [2., 4., 7.] :p [False, False, True] > [False, False, True] @@ -213,15 +213,15 @@ litArr = [10, 5, 3] :p k = newKey 0 mean for i:(Fin 100). randn (ixkey k i) -> -0.1157995 +> -0.1158 :p k = newKey 0 mean for i:(Fin 100). sq $ randn (ixkey k i) -> 1.2581898 +> 1.25819 :p for i:(Fin 3) j:(Fin 2). rand $ ixkey2 (newKey 11) i j -> [[0.47415292, 0.9145164], [0.7944602, 0.27679908], [0.58958626, 0.7116251]] +> [[0.474153, 0.914516], [0.79446, 0.276799], [0.589586, 0.711625]] :p x = for i:(Fin 3). 0 @@ -234,7 +234,7 @@ litArr = [10, 5, 3] > [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] :p fold (for i:(Fin 3). 0.0) $ \i:(Fin 2) c. (for j. c.j + IToF (ordinal j)) -> [0.0, 2.0, 4.0] +> [0., 2., 4.] :p mat2 = for i:(Fin 4) j:(Fin 4) . ordinal i @@ -275,7 +275,7 @@ litArr = [10, 5, 3] > 1 :p select False 1.0 2.0 -> 2.0 +> 2. :p select True [1,2,3] [10,20,30] > [1, 2, 3] @@ -296,7 +296,7 @@ litArr = [10, 5, 3] > (False, (True, (True, True))) :p [(for i:(Fin 1). (False, for j:(Fin 2). 1.0)), [(True, for k:(Fin 2) . 2.0)]] -> [[(False, [1.0, 1.0])], [(True, [2.0, 2.0])]] +> [[(False, [1., 1.])], [(True, [2., 2.])]] -- TODO: parse negative integer literals -- :p (mod 5 3, mod 7 3, mod (-1) 3, mod -5 3) @@ -324,15 +324,15 @@ litArr = [10, 5, 3] > 2 :p fold (1.0, 2.0) \i:(Fin 2) (x, y). (y, x) -> (1.0, 2.0) +> (1., 2.) :p fold (1.0, 2.0) \i:(Fin 3) (x, y). (y, x) -> (2.0, 1.0) +> (2., 1.) :p id 2 > 2 :p min 2.0 3.0 -> 2.0 +> 2. :p minBy sq 0.5 (-2.0) > 0.5 @@ -344,16 +344,16 @@ litArr = [10, 5, 3] > (1.5, 15) :p max 2.0 3.0 -> 3.0 +> 3. :p maxBy sq 0.5 (-2.0) -> -2.0 +> -2. :p maximum [2.0, 4.0, 1.5, 7.0] -> 7.0 +> 7. :p maximumBy fst [(2.0, 20), (1.5, 15), (10.0, 100)] -> (10.0, 100) +> (10., 100) :p (1 == 2, (-1) == (-1), 1 < 2, -1 < 2, 2 < (-1)) > (False, (True, (True, (True, False)))) @@ -367,7 +367,7 @@ litArr = [10, 5, 3] σ = 1.0 + 2.0 :p σ -> 3.0 +> 3. δ : Int -> Int = \x. x @@ -399,59 +399,59 @@ litArr = [10, 5, 3] -- line comment should be ok here 2.0 * x f 1.0 -> 2.0 +> 2. -- Not sure why the ordinary `sum/for` version doesn't work anymore :p n = 3 + 7 - fsum \i:(Fin n). 1.0 -> 10.0 + fsum view i:(Fin n). 1.0 +> 10. :p n = 4 - fsum \i:(Fin n). 1.0 -> 4.0 + fsum view i:(Fin n). 1.0 +> 4. :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one..i). 1.0 -> [0.0, 1.0, 2.0, 3.0] +> [0., 1., 2., 3.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one<..i). 1.0 -> [0.0, 0.0, 1.0, 2.0] +> [0., 0., 1., 2.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one<..i). 1.0 -> [0.0, 0.0, 1.0, 2.0] +> [0., 0., 1., 2.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one.. [0.0, 0.0, 1.0, 2.0] +> [0., 0., 1., 2.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one<.. [0.0, 0.0, 0.0, 1.0] +> [0., 0., 0., 1.] :p for i:(Fin 4). sum for j:(..i). 1.0 -> [1.0, 2.0, 3.0, 4.0] +> [1., 2., 3., 4.] :p for i:(Fin 4). sum for j:(.. [0.0, 1.0, 2.0, 3.0] +> [0., 1., 2., 3.] :p for i:(Fin 4). sum for j:(i..). 1.0 -> [4.0, 3.0, 2.0, 1.0] +> [4., 3., 2., 1.] :p for i:(Fin 4). sum for j:(i<..). 1.0 -> [3.0, 2.0, 1.0, 0.0] +> [3., 2., 1., 0.] :p idiv 10 3 > 3 @@ -464,7 +464,7 @@ litArr = [10, 5, 3] ys = [1.,2.,3.] xys = for (i,j). xs.i + ys.j sum xys -> 102.0 +> 102. :p xs = [10.,20.] @@ -472,7 +472,7 @@ litArr = [10, 5, 3] zs = [1.] xys = for (i,(j,k)). xs.i + ys.j + zs.k sum xys -> 108.0 +> 108. :p xs = [[1,2],[3,4]] @@ -502,18 +502,18 @@ litArr = [10, 5, 3] -- > [2.0, 2.0, 2.0] :p - withState 0.0 \ref. for i:(Fin 4). + runState 0.0 \ref. for i:(Fin 4). c = get ref ref := c + 1.0 c -> ([0.0, 1.0, 2.0, 3.0], 4.0) +> ([0., 1., 2., 3.], 4.) :p - withState 0.0 \ref. rof i:(Fin 4). + runState 0.0 \ref. rof i:(Fin 4). c = get ref ref := c + 1.0 c -> ([3.0, 2.0, 1.0, 0.0], 4.0) +> ([3., 2., 1., 0.], 4.) def eitherFloor (x:(Int|Float)) : Int = case x of Left i -> i @@ -582,7 +582,7 @@ def unflatten (params:Params n m) : (Weights n m & Biases n) = -- TODO: within-module version of this (currently fails in Imp checking) upperBound = sum $ for i:(Fin 4). 1 :p for j:(Fin upperBound). 1.0 -> [1.0, 1.0, 1.0, 1.0] +> [1., 1., 1., 1.] :p (for i:(Fin upperBound). 1, for j:(Fin 2). 2) > ([1, 1, 1, 1], [2, 2]) @@ -609,7 +609,7 @@ for i:(Range 0 x). 1.0 -- Make sure that we can construct and print an array using a pair index set for i:(Fin 2 & Fin 2). 1.0 -> [1.0, 1.0, 1.0, 1.0]@(Fin 2 & Fin 2) +> [1., 1., 1., 1.]@(Fin 2 & Fin 2) 1@(Fin 2 & Fin 2) > ((0@Fin 2), (1@Fin 2)) @@ -618,11 +618,11 @@ for i:(Fin 5). for j:(..i). ir = IToF $ ordinal i jr = IToF $ ordinal j ir * (ir + 1.0) / 2.0 + jr -> [ [0.0]@(..(0@Fin 5)) -> , [1.0, 2.0]@(..(1@Fin 5)) -> , [3.0, 4.0, 5.0]@(..(2@Fin 5)) -> , [6.0, 7.0, 8.0, 9.0]@(..(3@Fin 5)) -> , [10.0, 11.0, 12.0, 13.0, 14.0]@(..(4@Fin 5)) ] +> [ [0.]@(..(0@Fin 5)) +> , [1., 2.]@(..(1@Fin 5)) +> , [3., 4., 5.]@(..(2@Fin 5)) +> , [6., 7., 8., 9.]@(..(3@Fin 5)) +> , [10., 11., 12., 13., 14.]@(..(4@Fin 5)) ] -- TODO: fix! -- -- Exercise the use of free variables in the sum solver @@ -640,24 +640,28 @@ for i:(Fin 5). for j:(..i). > [-1, -8] :p 2.0 .* [[1.0, 2.0], [3.0, 4.0]] -> [[2.0, 4.0], [6.0, 8.0]] +> [[2., 4.], [6., 8.]] def newtonIter (f: Float -> Float) (x:Float) : Float = x - (f x / deriv f x) def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = - snd $ withState x0 \x. - while (\(). abs (f $ get x) > tol) \(). - x := newtonIter f $ get x + yieldState x0 \x. + iter \i. + if abs (f $ get x) <= tol + then Done () + else + x := newtonIter f $ get x + Continue :p newtonSolve 0.001 (\x. sq x - 2.0) 1.0 -> 1.4142157 +> 1.414216 -- :p -- x = for i:(Fin 3). for j:(Fin 200). 1.0 -- -- Last dimension split to allow for vector loads -- y = for i:(Fin 200). for j:(Fin 4). for h:(Fin VectorWidth). IToF $ (iota _).(i,j,h) --- z = snd $ withAccum \acc. +-- z = yieldAccum \acc. -- for l. -- for i. -- xil = (broadcastVector x.i.l) @@ -685,7 +689,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > [0, 2, 4, 6] :p - f = fst $ withState 1 \ref. + f = withState 1 \ref. x = get ref ref := 3 + x y = get ref @@ -694,12 +698,12 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > 415 :p - (f, w) = withAccum \ref. + (f, w) = runAccum \ref. ref += 2.0 w = 2 \z. z + w (f 5, w) -> (7, 2.0) +> (7, 2.) -- def add (n : Type) ?-> (a : n=>Float) (b : n=>Float) : n=>Float = -- (tile (\t:(Tile n (Fin VectorWidth)). storeVector $ loadTile t a + loadTile t b) @@ -713,7 +717,7 @@ arr2d.(1@_) > [2, 3] :p - withState (1,2) \ref. + runState (1,2) \ref. r1 = fstRef ref r2 = sndRef ref x = get r1 @@ -744,22 +748,22 @@ easy = [(-2.0), 3.0, 3.0, 0.1, 0.0] hard = [(-1000.0), 1000.0, 1000.0, 0.1, 0.0] :p logsumexp easy - (log $ sum for j. exp easy.j) -> 0.0 +> 0. :p sum $ softmax hard -> 1.0 +> 1. :p all for i. ((softmax hard).i >= 0.0) > True :p sum for i. exp $ (logsoftmax hard).i -> 0.9999709 +> 0.999971 :p all for i. abs ((softmax hard).i - exp (logsoftmax hard).i) < 0.0001 > True :p evalpoly [2.0, 3.0, 4.0] 10.0 -> 234.0 +> 234. str = ['x', 'y'] @@ -769,7 +773,7 @@ str = ['x', 'y'] s1 = "hello world" :p s1 -> "hello world" +> (AsList 11 "hello world") :p codepoint 'a' > 97 @@ -815,9 +819,10 @@ triLit def fromLeftFloat (x:(Float | Int)) : Float = case x of Left x' -> x' - Right _ -> throw + Right _ -> error "this is an error" :p fromLeftFloat $ Right 1 +> this is an error > Runtime error :p fromLeftFloat $ Left 1.2 diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx new file mode 100644 index 000000000..3df621614 --- /dev/null +++ b/tests/exception-tests.dx @@ -0,0 +1,59 @@ + + +def checkFloatInUnitInterval (x:Float) : {Except} Float = + assert $ x >= 0.0 + assert $ x <= 1.0 + x + +:p catch do assert False +> Nothing + +:p catch do assert True +> (Just ()) + +:p catch do checkFloatInUnitInterval 1.2 +> Nothing + +:p catch do checkFloatInUnitInterval (-1.2) +> Nothing + +:p catch do checkFloatInUnitInterval 0.2 +> (Just 0.2) + +:p yieldState 0 \ref. + catch do + ref := 1 + assert False + ref := 2 +> 1 + +:p catch do + for i:(Fin 5). + if ordinal i > 3 + then throw () + else 23 +> Nothing + +:p catch do + for i:(Fin 3). + if ordinal i > 3 + then throw () + else 23 +> (Just [23, 23, 23]) + +-- Is this the result we want? +:p yieldState zero \ref. + catch do + for i:(Fin 6). + if (ordinal i `rem` 2) == 0 + then throw () + else () + ref!i := 1 +> [0, 1, 0, 1, 0, 1] + +:p catch do + runState 0 \ref. + ref := 1 + assert False + ref := 2 +> Nothing diff --git a/examples/gpu-tests.dx b/tests/gpu-tests.dx similarity index 93% rename from examples/gpu-tests.dx rename to tests/gpu-tests.dx index 7116a4f19..9f0ed5831 100644 --- a/examples/gpu-tests.dx +++ b/tests/gpu-tests.dx @@ -27,7 +27,7 @@ testNestedLoops.(4@_).(5@_) -- single GPU thread. It should get lifted to a top-level allocation instead. allocationLiftingTest = for i:(Fin 100). - snd $ withState (for j:(Fin 1000). ordinal i) $ \s. + yieldState (for j:(Fin 1000). ordinal i) $ \s. s!(0@_) := get s!(0@_) + 1 (allocationLiftingTest.(4@_).(0@_), allocationLiftingTest.(4@_).(1@_)) > (5, 4) diff --git a/tests/io-tests.dx b/tests/io-tests.dx new file mode 100644 index 000000000..896e2d98c --- /dev/null +++ b/tests/io-tests.dx @@ -0,0 +1,74 @@ + +:p unsafeIO \(). + withTempFile \fname. + withFile fname WriteMode \stream. + fwrite stream "lorem ipsum\n" + fwrite stream "dolor sit amet\n" + readFile fname +> (AsList 27 "lorem ipsum +> dolor sit amet +> ") + +:p unsafeIO \(). + withAlloc 4 \ptr:(Ptr Int). + for i:(Fin 4). store (ptr +>> ordinal i) (ordinal i) + tabFromPtr (Fin 4) ptr +> [0, 1, 2, 3] + +unsafeIO \(). + print "testing log" +> testing log +> () + +unsafeIO \(). + for i':(Fin 10). + i = ordinal i' + if rem i 2 == 0 + then print $ show i <> " is even" + else print $ show i <> " is odd" +> 0 is even +> 1 is odd +> 2 is even +> 3 is odd +> 4 is even +> 5 is odd +> 6 is even +> 7 is odd +> 8 is even +> 9 is odd +> [(), (), (), (), (), (), (), (), (), ()] + +:p storageSize Int +> 4 + +:p unsafeIO \(). + withAlloc 1 \ptr:(Ptr Int). + store ptr 3 + load ptr +> 3 + +:p unsafeIO \(). + withDynamicBuffer \buf. + extendDynBuffer buf $ toList for i:(Fin 1000). ordinal i + extendDynBuffer buf $ toList for i:(Fin 1000). ordinal i + (AsList _ xs) = loadDynBuffer buf + sum xs +> 999000 + +:p unsafeIO \(). + s = for i:(Fin 10000). IToW8 $ FToI $ 128.0 * rand (ixkey (newKey 0) i) + withTempFile \fname. + withFile fname WriteMode \stream. + fwrite stream $ AsList _ s + (AsList _ s') = readFile fname + sum (for i. W8ToI s.i) == sum (for i. W8ToI s'.i) +> True + +:p unsafeIO do getEnv "NOT_AN_ENV_VAR" +> Nothing + +:p unsafeIO do getEnv "DEX_TEST_MODE" +> (Just (AsList 1 "t")) + +:p dex_test_mode () +> True diff --git a/examples/linear-tests.dx b/tests/linear-tests.dx similarity index 100% rename from examples/linear-tests.dx rename to tests/linear-tests.dx diff --git a/examples/loopy-ad-tests.dx b/tests/loopy-ad-tests.dx similarity index 100% rename from examples/loopy-ad-tests.dx rename to tests/loopy-ad-tests.dx diff --git a/examples/monad-tests.dx b/tests/monad-tests.dx similarity index 76% rename from examples/monad-tests.dx rename to tests/monad-tests.dx index 4612954c0..8f9994252 100644 --- a/examples/monad-tests.dx +++ b/tests/monad-tests.dx @@ -1,12 +1,12 @@ :p def m (h:Type) ?-> (ref:Ref h Int) : {State h} Int = get ref - withState 2 m + runState 2 m > (2, 2) :p def m (h:Type) ?-> (ref:Ref h Int) : {State h} Unit = ref := 3 - withState 0 m + runState 0 m > ((), 3) :p @@ -21,8 +21,8 @@ z = get ref ref := (z * 3.0) - withState 1.0 stateAction -> ((), 9.0) + runState 1.0 stateAction +> ((), 9.) :p def rwsAction @@ -37,10 +37,10 @@ r + 2 withReader 2 \r. - withState True \s. - withAccum \w. + runState True \s. + runAccum \w. rwsAction r w s -> ((4, 6.0), False) +> ((4, 6.), False) :p def m (h:Type) ?-> (s:Ref h (Fin 3=>Int)) : {State h} Unit = @@ -48,7 +48,7 @@ s!(fromOrdinal _ 2) := 20 x = get (s!(fromOrdinal _ 0)) s!(fromOrdinal _ 1) := x - withState [0,0,0] m + runState [0,0,0] m > ((), [10, 10, 20]) :p withReader [1,2,3] \r . ask r!(fromOrdinal _ 1) @@ -60,15 +60,15 @@ : {Accum wh, State sh} Unit = x = get s w += x - withState 1.0 \s. withAccum \w . m w s -> (((), 1.0), 1.0) + runState 1.0 \s. runAccum \w . m w s +> (((), 1.), 1.) def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = x = ask r w += x w += 2.0 -:p withReader 1.5 \r. withAccum \w. myAction w r +:p withReader 1.5 \r. runAccum \w. myAction w r > ((), 3.5) :p @@ -78,14 +78,14 @@ def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = w1 += 1.0 w2 += 3.0 w1 += 1.0 - withAccum \w1. withAccum \w2. m w1 w2 -> (((), 3.0), 2.0) + runAccum \w1. runAccum \w2. m w1 w2 +> (((), 3.), 2.) def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = s!(fromOrdinal _ 0) := 1 s!(fromOrdinal _ 2) := 2 -:p withState [0,0,0] foom +:p runState [0,0,0] foom > ((), [1, 0, 2]) -- TODO: handle effects returning functions @@ -102,7 +102,7 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- :p -- foo : Float -> (Float, Float) -- foo x = --- (f, ans) = withState x \s. +-- (f, ans) = runState x \s. -- y = get s -- \z. 100.0 * x + 10.0 * y + z -- (f 1.0, ans) @@ -113,7 +113,7 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- :p -- foo : Float -> (Float, Float) -- foo x = --- (f, ans) = withAccumulator \s. +-- (f, ans) = runAccumulator \s. -- s += x -- \y. 10.0 * x + y -- (f 1.0, ans) @@ -121,23 +121,23 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- foo 3.0 -- > (31.0, 3.0) --- TODO: some way to explicitly give type to `withAccum` +-- TODO: some way to explicitly give type to `runAccum` -- (maybe just explicit implicit args) :p withReader 2.0 \r. - withAccum \w. - withAccum \w'. - withState 3 \s. + runAccum \w. + runAccum \w'. + runState 3 \s. x = ask r y = get s w += x w' += x + x s := 4 (x, y) -> ((((2.0, 3), 4), 4.0), 2.0) +> ((((2., 3), 4), 4.), 2.) def symmetrizeInPlace (mat:n=>n=>Float) : n=>n=>Float = - snd $ withState mat \ref. + yieldState mat \ref. for i j. x = get ref!i!j y = get ref!j!i @@ -146,24 +146,31 @@ def symmetrizeInPlace (mat:n=>n=>Float) : n=>n=>Float = ref!j!i := avg symmetrizeInPlace [[1.,2.],[3.,4.]] -> [[1.0, 2.5], [2.5, 4.0]] +> [[1., 2.5], [2.5, 4.]] :p withReader 5 \r. () > () -:p snd $ withAccum \w. +:p yieldAccum \w. for i:(Fin 2). w += 1.0 w += 1.0 -> 4.0 +> 4. -:p snd $ withAccum \w. +:p yieldAccum \w. for i:(Fin 2). w += 1.0 w += 1.0 -> 3.0 +> 3. -:p snd $ withAccum \ref. +:p yieldAccum \ref. ref += [1.,2.,3.] ref += [2.,4.,5.] -> [3.0, 6.0, 8.0] +> [3., 6., 8.] + +def effectsAtZero (eff:Effects)?-> (f: Int ->{|eff} Unit) : {|eff} Unit = + f 0 + () + +:p runState 0 \ref. effectsAtZero \_. ref := 1 +> ((), 1) diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx new file mode 100644 index 000000000..983a4da83 --- /dev/null +++ b/tests/parser-combinator-tests.dx @@ -0,0 +1,54 @@ + +include "parser.dx" + +parseABC : Parser Unit = MkParser \h. + parse h $ pChar 'A' + parse h $ pChar 'B' + parse h $ pChar 'C' + +:p runParser "AAA" parseABC +> Nothing + +:p runParser "ABCABC" parseABC +> Nothing + +:p runParser "AB" parseABC +> Nothing + +:p runParser "ABC" parseABC +> (Just ()) + +def parseT : Parser Bool = MkParser \h. + parse h $ pChar 'T' + True + +def parseF : Parser Bool = MkParser \h. + parse h $ pChar 'F' + False + +def parseTF : Parser Bool = + parseT <|> parseF + +def parserTFTriple : Parser (Fin 3=>Bool) = MkParser \h. + for i. parse h parseTF + +:p runParser "TTF" parserTFTriple +> (Just [True, True, False]) + +:p runParser "TTFX" parserTFTriple +> Nothing + +:p runParser "TTFFTT" $ parseMany parseTF +> (Just (AsList 6 [True, True, False, False, True, True])) + +:p runParser "1021389" $ parseMany parseDigit +> (Just (AsList 7 [1, 0, 2, 1, 3, 8, 9])) + +:p runParser "1389" $ parseInt +> (Just 1389) + +:p runParser "01389" $ parseInt +> (Just 1389) + +:p runParser "-1389" $ parseInt +> (Just -1389) diff --git a/examples/parser-tests.dx b/tests/parser-tests.dx similarity index 83% rename from examples/parser-tests.dx rename to tests/parser-tests.dx index 81e8acc72..93d216354 100644 --- a/examples/parser-tests.dx +++ b/tests/parser-tests.dx @@ -1,28 +1,28 @@ 'For now, arithmetic is not sensitive to whitespace: :p 1.0+1.0 -> 2.0 +> 2. :p 1.0 +1.0 -> 2.0 +> 2. :p 1.0+ 1.0 -> 2.0 +> 2. :p 1.0 + 1.0 -> 2.0 +> 2. :p 1.0-1.0 -> 0.0 +> 0. :p 1.0 -1.0 -> 0.0 +> 0. :p 1.0- 1.0 -> 0.0 +> 0. :p 1.0 - 1.0 -> 0.0 +> 0. 'Applying a function to a negative literal thus requires parentheses. @@ -37,7 +37,7 @@ f = \x. x + 10. > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ :p f (-1.0) -> 9.0 +> 9. 'Lambdas can have specific arrow annotations. @@ -94,7 +94,7 @@ lam4 = \n m ?-> (0@n, 0@m) > [1, 0, 0] :p - withState 5 \ref. + runState 5 \ref. n = get ref for_ i:(Fin n). ref := get ref + 1 @@ -111,3 +111,20 @@ def myInt : {State h} Int = 1 > 107 | def myInt : {State h} Int = 1 > | ^ > Nullary def can't have effects + +:p + yieldAccum \ref. + x = if True then 1. else 3. + if True then ref += x + + if True then + ref += 1. + ref += 2. + + if False then ref += 100. else + ref += 1. + ref += 2. + + if True + then ref += 2. +> 9. diff --git a/examples/record-variant-tests.dx b/tests/record-variant-tests.dx similarity index 96% rename from examples/record-variant-tests.dx rename to tests/record-variant-tests.dx index 871fa2941..1e0c259ea 100644 --- a/examples/record-variant-tests.dx +++ b/tests/record-variant-tests.dx @@ -48,7 +48,7 @@ Syntax for records, variants, and their types. x = {a=5.0, b=2} y : {a:Int & b:Int & ..._} = {a=3, a=4, ...x} y -> {a = 3, a = 4, a = 5.0, b = 2} +> {a = 3, a = 4, a = 5., b = 2} 'Variant (enum) types @@ -72,7 +72,7 @@ Syntax for records, variants, and their types. > {| a = 3 |} :p {| a | a = 3.0 |} : {a:Int | a:Float | a:Int} -> {|a| a = 3.0 |} +> {|a| a = 3. |} :t {| a | a = 3.0 |} : {a:Int | a:Float | a:Int} > {a: Int32 | a: Float32 | a: Int32} @@ -148,7 +148,7 @@ def getTwoFoosAndABar (rest : Fields)?-> (f1, f2, b) :p getTwoFoosAndABar {foo=1, bar=2, foo=0.0, foo=4, baz=3.0, bar=7} -> (1, (0.0, 2)) +> (1, (0., 2)) :p ({b=b, a=a1, a=a2}) = {a=1, b=2} @@ -183,7 +183,7 @@ x : {a:Int | a:Float | a:Int} = {| a | a = 3.0 |} foo = 1 bar = 2.0 {foo, bar} -> {bar = 2.0, foo = 1} +> {bar = 2., foo = 1} :p ({foo, ...}) = {foo=1, bar=2.0} @@ -207,7 +207,7 @@ x : {a:Int | a:Float | a:Int} = {| a | a = 3.0 |} {| a = x |} -> IToF x {| a | a = x |} -> x {| b = x |} -> IToF x -> 3.0 +> 3. 'Table values and imp lowering @@ -215,23 +215,23 @@ myRecordTable : (Fin 2)=>{a:Int & b:Float} = [{a=1, b=2.0}, {a=3, b=4.0}] :p myRecordTable -> [{a = 1, b = 2.0}, {a = 3, b = 4.0}] +> [{a = 1, b = 2.}, {a = 3, b = 4.}] :p for i:(Fin 2). ({a=a, b=b}) = myRecordTable.i {a=b, b=a} -> [{a = 2.0, b = 1}, {a = 4.0, b = 3}] +> [{a = 2., b = 1}, {a = 4., b = 3}] myVariantTable : (Fin 2)=>{a:Int | b:Float} = [{| a=1 |}, {| b=2.0 |}] :p myVariantTable -> [{| a = 1 |}, {| b = 2.0 |}] +> [{| a = 1 |}, {| b = 2. |}] :p for i:(Fin 2). v : {a:_ | b:_} = case myVariantTable.i of {| a=a |} -> {| b=a |} {| b=b |} -> {| a=b |} v -> [{| b = 1 |}, {| a = 2.0 |}] +> [{| b = 1 |}, {| a = 2. |}] -- Known variant, unused tail pattern :p @@ -240,7 +240,7 @@ myVariantTable : (Fin 2)=>{a:Int | b:Float} = [{| a=1 |}, {| b=2.0 |}] {| a = x |} -> 1.0 {| a | a = x |} -> x {|a|a| ..._ |} -> 5.0 -> 3.0 +> 3. -- Known variant, missing pattern :p @@ -248,7 +248,7 @@ myVariantTable : (Fin 2)=>{a:Int | b:Float} = [{| a=1 |}, {| b=2.0 |}] case x of {| a = x |} -> 1.0 {| a | a = x |} -> x -> 3.0 +> 3. -- Known variant, used tail pattern myVal = diff --git a/examples/repl-multiline-test-expected-output b/tests/repl-multiline-test-expected-output similarity index 70% rename from examples/repl-multiline-test-expected-output rename to tests/repl-multiline-test-expected-output index 6ff96b0fe..129d35875 100644 --- a/examples/repl-multiline-test-expected-output +++ b/tests/repl-multiline-test-expected-output @@ -3,7 +3,7 @@ >=> >=> >=> ->=> ... ... ... ... 30.0 +>=> ... ... ... ... 30. >=> >=> ... ... >=> @@ -11,5 +11,5 @@ >=> >=> >=> ->=> 3.0 +>=> 3. >=> \ No newline at end of file diff --git a/examples/repl-multiline-test.dx b/tests/repl-multiline-test.dx similarity index 100% rename from examples/repl-multiline-test.dx rename to tests/repl-multiline-test.dx diff --git a/examples/serialize-tests.dx b/tests/serialize-tests.dx similarity index 78% rename from examples/serialize-tests.dx rename to tests/serialize-tests.dx index ccbb34b6c..c4f27e80e 100644 --- a/examples/serialize-tests.dx +++ b/tests/serialize-tests.dx @@ -2,13 +2,13 @@ > 1 :p 1.0 -> 1.0 +> 1. :p [1, 2, 3] > [1, 2, 3] :p [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -> [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] +> [[1., 2., 3.], [4., 5., 6.]] :p fromOrdinal (Fin 10) 7 > (7@Fin 10) @@ -19,7 +19,7 @@ :p () > () -x = "ab" +x = ['a', 'b'] :p for (i,j). [x.i, x.j] > ["aa", "ab", "ba", "bb"]@(Fin 2 & Fin 2) @@ -29,13 +29,13 @@ x = "ab" > {a = 1, b = 2} :p {a="1234", b=[1, 2, 3]} -> {a = "1234", b = [1, 2, 3]} +> {a = (AsList 4 "1234"), b = [1, 2, 3]} :p [{| a=1 |}, {| b=2.0 |}] : (Fin 2) => {a:Int | b:Float} -> [{| a = 1 |}, {| b = 2.0 |}] +> [{| a = 1 |}, {| b = 2. |}] :p {table = [{| a=1 |}, {| b=2.0 |}]} : {table: (Fin 2) => {a:Int | b:Float}} -> {table = [{| a = 1 |}, {| b = 2.0 |}]} +> {table = [{| a = 1 |}, {| b = 2. |}]} 'Values without a pretty-printer (currently shows warning message): diff --git a/examples/shadow-tests.dx b/tests/shadow-tests.dx similarity index 100% rename from examples/shadow-tests.dx rename to tests/shadow-tests.dx diff --git a/examples/show-tests.dx b/tests/show-tests.dx similarity index 90% rename from examples/show-tests.dx rename to tests/show-tests.dx index 0f3e26171..1d0d24098 100644 --- a/examples/show-tests.dx +++ b/tests/show-tests.dx @@ -1,4 +1,8 @@ '# `Show` instances +-- String + +:p show "abc" +> (AsList 3 "abc") -- Int32 @@ -16,10 +20,8 @@ :p show (IToI64 1234: Int64) > (AsList 4 "1234") --- FIXME(https://github.com/google-research/dex-lang/issues/317): --- Unexpected zext from type conversion of negative Int32 to Int64. :p show (IToI64 (-1234): Int64) -> (AsList 10 "4294966062") +> (AsList 5 "-1234") -- Float32 diff --git a/examples/trig-tests.dx b/tests/trig-tests.dx similarity index 82% rename from examples/trig-tests.dx rename to tests/trig-tests.dx index c8bd9dcf9..abd6776ae 100644 --- a/examples/trig-tests.dx +++ b/tests/trig-tests.dx @@ -20,6 +20,14 @@ > True :p atan2 (-sin (-0.44)) (cos (-0.44)) ~~ (0.44) > True +:p atan2 (-1.0) (-1.0) ~~ (-3.0/4.0*pi) +> True + +-- Test all the way around the circle. +angles = linspace (Fin 11) (-pi + 0.001) (pi) +:p all for i:(Fin 11). + angles.i ~~ atan2 (sin angles.i) (cos angles.i) +> True :p (atan2 infinity 1.0) ~~ ( pi / 2.0) > True diff --git a/examples/type-tests.dx b/tests/type-tests.dx similarity index 98% rename from examples/type-tests.dx rename to tests/type-tests.dx index 396540e69..2261e6878 100644 --- a/examples/type-tests.dx +++ b/tests/type-tests.dx @@ -152,17 +152,17 @@ MyPair : Type -> Type = ((1, 2), (1.0, 2.0)) pairs -> ((1, 2), (1.0, 2.0)) +> ((1, 2), (1., 2.)) -- TODO: put source annotation on effect for a better message here fEff : Unit -> {| a} a = todo > Type error: -> Expected: EffKind -> Actual: Type +> Expected: Type +> Actual: EffKind > > fEff : Unit -> {| a} a = todo -> ^^^^^^^^^ +> ^^ :p for i:(Fin 7). sum for j:(Fin unboundName). 1.0 @@ -367,7 +367,7 @@ def triRefIndex (ref:Ref h (i':n=>(..i')=>Float)) (i:n) : Ref h ((..i)=>Float) = -- There was a time when this wasn't possible, because checking mode would unify the -- input type with a non-dependent function type, leading to a later unification errors. id (for i:(Fin 2). for j:(..i). 1.0) -> [[1.0]@(..(0@Fin 2)), [1.0, 1.0]@(..(1@Fin 2))] +> [[1.]@(..(0@Fin 2)), [1., 1.]@(..(1@Fin 2))] def weakerInferenceReduction (l : i:n=>(..i)=>Float) (j:n): Unit = for i:(..j). diff --git a/tests/typeclass-tests.dx b/tests/typeclass-tests.dx new file mode 100644 index 000000000..7970fbfcd --- /dev/null +++ b/tests/typeclass-tests.dx @@ -0,0 +1,43 @@ + + +interface InterfaceTest1 a + InterfaceTest1 : a +> Error: variable already defined: InterfaceTest1 + +interface InterfaceTest3 a + foo : a -> Int + foo : a -> Int +> Error: variable already defined: foo + +interface InterfaceTest4 a + foo : Int + bar : a -> Int + +instance InterfaceTest4 Float + foo = 1 + bar = \_. 1 + foo = 1 +> Type error:Duplicate method: foo + +instance InterfaceTest4 Float + foo = 1 +> Type error:Missing method: bar + +instance InterfaceTest4 Float + baz = 1 +> Type error:baz is not a method of InterfaceTest4 + +instance InterfaceTest4 Float + foo = 1 + bar = \_. 'x' +> Type error: +> Expected: Int32 +> Actual: Word8 +> +> bar = \_. 'x' +> ^^^ + +instance InterfaceTest4 Float + foo = 1 + bar = \_. 1 + diff --git a/examples/uexpr-tests.dx b/tests/uexpr-tests.dx similarity index 95% rename from examples/uexpr-tests.dx rename to tests/uexpr-tests.dx index 7c76018b3..eb13abb1f 100644 --- a/examples/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -13,12 +13,12 @@ def returnFirstArg (a:Type) (b:Type) (x:a) (y:b) : a = x > 1 :p 1.0 + 2.0 -> 3.0 +> 3. def triple (x:Float) : Float = x + x + x :p triple 1.0 -> 3.0 +> 3. def idExplicit (a:Type) (x:a) : a = x @@ -39,7 +39,7 @@ idImplicit2 : (a:Type ?-> a -> a) = \x. x > 1 :p (\x y. x + y) 1.0 2.0 -> 3.0 +> 3. :p 1.0 + 1 > Type error: @@ -100,13 +100,13 @@ myPair = (1, 2.3) > 1 :p - snd $ withState 2 \s. + yieldState 2 \s. x = get s s := x + 3 > 5 :p - snd $ withState 1 \s. + yieldState 1 \s. for i:(Fin 10). x = get s s := x + x @@ -137,8 +137,8 @@ myPair = (1, 2.3) > [1, 2, 3, 4]@(Fin 2 & Fin 2) -:p sin 1.0 -> 0.84147096 +:p sin 1.01 +> 0.846832 :p (x,y) = (1,2) @@ -178,7 +178,7 @@ myPair = (1, 2.3) > ^^^ :p - snd $ withState [1,2,3] \xsRef. + yieldState [1,2,3] \xsRef. for i:(Fin 3). xsRef!i := ordinal i > [0, 1, 2] @@ -186,19 +186,19 @@ myPair = (1, 2.3) def passthrough (eff:Effects) ?-> (f:(a -> {|eff} b)) (x:a) : {|eff} b = f x :p - snd $ withState 1 \ref. + yieldState 1 \ref. passthrough (\(). ref := 10) () > 10 :p - withState 0 \r1. - withState 0 \r2. + runState 0 \r1. + runState 0 \r2. r1 := 1 r2 := 2 > (((), 2), 1) :p (\f x y. f x y) (+) 1.0 2.0 -> 3.0 +> 3. :p myId = fst (\x. x, 2) @@ -220,23 +220,23 @@ def myOtherFst ((x, _):(a&b)) : a = x > 1 :p sum [1.,2.] -> 3.0 +> 3. :p xs = fanout _ 1.0 for i:(Fin 3). xs.i + xs.i -> [2.0, 2.0, 2.0] +> [2., 2., 2.] :p f = \x. x * x * x jvp f 2.0 1.5 -> 18.0 +> 18. :p f : Float --o Float = \x. 2.0 * (x + x) transposeLinear f 1.0 -> 4.0 +> 4. -- FIXME: This fails due to shadowing! --def transpose' (x:n=>m=>Float) --o : m=>n=>Float = for i j. x.j.i @@ -248,7 +248,7 @@ def myOtherFst ((x, _):(a&b)) : a = x f : Float --o (Fin 3=>Float) = \x. for i. x * 2.0 transposeLinear f [1.0, 2.0, 3.0] -> 12.0 +> 12. id'' : b -> b = id