Skip to content

Commit

Permalink
Merge pull request #353 from google-research/dev
Browse files Browse the repository at this point in the history
Merging dev back into main
  • Loading branch information
dougalm authored Dec 18, 2020
2 parents 673b5fb + 5e3953d commit f8f9282
Show file tree
Hide file tree
Showing 80 changed files with 6,755 additions and 3,587 deletions.
15 changes: 10 additions & 5 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Tests

on:
push:
branches: [ main ]
branches: [ main, dev ]
pull_request:
branches: [ main ]
branches: [ main, dev ]

jobs:
build:
Expand All @@ -13,10 +13,10 @@ jobs:
os: [ubuntu-18.04, macos-latest]
include:
- os: macos-latest
install_deps: brew install llvm@9
install_deps: brew install llvm@9 pkg-config
path_extension: $(brew --prefix llvm@9)/bin
- os: ubuntu-18.04
install_deps: sudo apt-get install llvm-9-tools llvm-9-dev
install_deps: sudo apt-get install llvm-9-tools llvm-9-dev pkg-config
path_extension: /usr/lib/llvm-9/bin

runs-on: ${{ matrix.os }}
Expand All @@ -35,7 +35,7 @@ jobs:
- name: Install system dependencies
run: |
${{ matrix.install_deps }}
echo "::add-path::${{ matrix.path_extension }}"
echo "${{ matrix.path_extension }}" >> $GITHUB_PATH
- name: Cache
uses: actions/cache@v2
Expand All @@ -46,6 +46,11 @@ jobs:
key: ${{ runner.os }}-${{ hashFiles('**/*.cabal', 'stack*.yaml') }}
restore-keys: ${{ runner.os }}-

# See https://github.com/actions/cache/issues/445
- name: Remove cached Setup executables
run: rm -rf ~/.stack/setup-exe-cache
if: runner.os == 'macOS'

- name: Build
run: make

Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ garbage.hs
stack.yaml.lock
*.pyc
doc/
scratch.dx
scratch*
scratch/
test-scratch/
dist-newstyle/
Expand All @@ -22,3 +22,5 @@ dist-newstyle/
benchmarks/exe/**/*
benchmarks/parboil/data
benchmarks/rodinia/rodinia
examples/export/scalar
examples/export/array
25 changes: 17 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ or these example programs:
* [Estimating pi](https://google-research.github.io/dex-lang/pi.html)
* [Hamiltonian Monte Carlo](https://google-research.github.io/dex-lang/mcmc.html)
* [ODE integrator](https://google-research.github.io/dex-lang/ode-integrator.html)
* [Sierpinsky triangle](https://google-research.github.io/dex-lang/sierpinsky.html)
* [Sierpinski triangle](https://google-research.github.io/dex-lang/sierpinski.html)
* [Basis function regression](https://google-research.github.io/dex-lang/regression.html)
* [Brownian bridge](https://google-research.github.io/dex-lang/brownian_motion.html)

Expand All @@ -29,21 +29,30 @@ development. Expect monstrous bugs and razor-sharp edges. Contributions welcome!
* Install [stack](https://www.haskellstack.org)
* Install LLVM 9
* `apt-get install llvm-9-dev` on Ubuntu/Debian,
* `brew install llvm@9` on macOS.
* `brew install llvm@9` on macOS, and ensure it is on your `PATH` e.g. via `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` before building.
* Install libpng (often included by default in *nix)

## Building

* Build Dex in development mode: `make`
* Run tests in development mode: `make tests`
* Install a release version of Dex: `make install`

The default installation directory is `$HOME/.local/bin` so make sure to add that
directory to `$PATH` once you install Dex. If you'd like to install it somewhere else
make sure to have the `PREFIX` environment variable set when you run `make install`.
For example `PREFIX=$HOME make install` would install `dex` in `$HOME/bin`.
The default installation directory is `$HOME/.local/bin`, so make sure to add
that directory to `$PATH` after installing Dex. To install Dex somewhere else,
set the `PREFIX` environment variable before running `make install`. For
example, `PREFIX=$HOME make install` installs `dex` in `$HOME/bin`.

While working in development mode, it is convenient to set up a `dex` alias
(e.g. in .bashrc): `alias dex="stack exec dex --"`.
It is convenient to set up a `dex` alias (e.g. in `.bashrc`) for running Dex in
development mode:

```console
# Linux:
alias dex="stack exec dex --"

# macOS:
alias dex="stack exec --stack-yaml=stack-macos.yaml dex --"
```

## Running

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/prepare-executables.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def prepare_rodinia_backprop():
exe_path.mkdir(parents=True, exist_ok=True)
exe_path_ad = RODINIA_EXE_ROOT / 'backprop-ad'
exe_path_ad.mkdir(parents=True, exist_ok=True)
in_features = [512, 123]
in_features = [128, 1048576]

for inf, use_ad in product(in_features, (False, True)):
outf = 1
Expand Down
20 changes: 12 additions & 8 deletions dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@ library dex-resources
build-depends: base, bytestring, file-embed
hs-source-dirs: src/resources
default-language: Haskell2010
default-extensions: CPP

library
exposed-modules: Env, Syntax, Type, Inference, JIT, LLVMExec,
Parser, Util, Imp, PPrint, Array, Algebra,
Actor, Cat, Flops, Embed, Serialize, Optimize,
RenderHtml, Plot, LiveOutput, Simplify, TopLevel,
Autodiff, Interpreter, Logging, PipeRPC, CUDA
Parser, Util, Imp, Imp.Embed, Imp.Optimize,
PPrint, Algebra, Parallelize, Optimize, Serialize
Actor, Cat, Flops, Embed,
RenderHtml, LiveOutput, Simplify, TopLevel,
Autodiff, Interpreter, Logging, PipeRPC, CUDA,
LLVM.JIT, LLVM.Shims
build-depends: base, containers, mtl, binary, bytestring,
time, tf-random, llvm-hs-pure ==9.*, llvm-hs ==9.*,
aeson, megaparsec >=8.0, warp, wai, filepath,
parser-combinators, http-types, prettyprinter, text,
blaze-html, cmark, diagrams-lib, ansi-terminal,
diagrams-rasterific, JuicyPixels, transformers,
base64-bytestring, vector, directory, mmap, unix,
transformers, directory, mmap, unix,
process, primitive, store, dex-resources, temporary,
if !os(darwin)
exposed-modules: Resources
Expand All @@ -54,6 +56,7 @@ library
cxx-options: -std=c++11 -fPIC
default-extensions: CPP, DeriveTraversable, TypeApplications, OverloadedStrings,
TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms
pkgconfig-depends: libpng
if flag(cuda)
include-dirs: /usr/local/cuda/include
extra-libraries: cuda
Expand All @@ -73,7 +76,7 @@ executable dex
build-depends: dex-resources
default-language: Haskell2010
hs-source-dirs: src
default-extensions: CPP
default-extensions: CPP, LambdaCase
if flag(optimized)
ghc-options: -O3
else
Expand All @@ -85,7 +88,8 @@ foreign-library Dex
build-depends: base, dex, dex-resources, mtl
hs-source-dirs: src/foreign
c-sources: src/foreign/rts.c
ghc-options: -Wall
cc-options: -std=c11 -fPIC
ghc-options: -Wall -fPIC
default-language: Haskell2010
default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase
if flag(optimized)
Expand Down
116 changes: 113 additions & 3 deletions examples/adt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ data MyPair a:Type b:Type =
z = MkMyPair 1 2.3

:p z
> MkMyPair 1 2.3
> (MkMyPair 1 2.3)

:t z
> (MyPair Int32 Float32)
Expand All @@ -33,7 +33,7 @@ data Dual a:Type =
> (1, 2)

:p for i:(Fin 3). MkMyPair (ordinal i) (ordinal i + 1)
> [MkMyPair 0 1, MkMyPair 1 2, MkMyPair 2 3]
> [(MkMyPair 0 1), (MkMyPair 1 2), (MkMyPair 2 3)]

zz = MkMyPair 1 (MkMyPair True 2.3)

Expand All @@ -57,7 +57,7 @@ data MyEither a:Type b:Type =
x : MyEither Int Float = MyLeft 1

:p x
> MyLeft 1
> (MyLeft 1)

:p
(MyLeft x') = x
Expand Down Expand Up @@ -178,3 +178,113 @@ xsList = AsList _ [1,2,3]
(AsList _ ans) = xs <> ys
sum ans
> 15

:p
(MkMyPair x y) = case 3 < 2 of
True -> MkMyPair 1 2
False -> MkMyPair 3 4
(x, y)
> (3, 4)

def catLists (xs:List a) (ys:List a) : List a =
(AsList nx xs') = xs
(AsList ny ys') = ys
nz = nx + ny
zs = for i:(Fin nz).
i' = ordinal i
case i' < nx of
True -> xs'.(fromOrdinal _ i')
False -> ys'.(fromOrdinal _ (i' - nx))
AsList _ zs

:p
(AsList _ xs) = catLists (AsList _ [1,2,3]) (AsList _ [4,5])
sum xs
> 15

:p catLists (AsList _ [1,2,3]) (AsList _ [4,5])
> (AsList 5 [1, 2, 3, 4, 5])

:p
n = 1 + 4
AsList _ (for i:(Fin n). ordinal i)
> (AsList 5 [0, 1, 2, 3, 4])



def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs

:t listToTable
> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a)

:p
l = AsList _ [1, 2, 3]
sum $ listToTable l
> 6

def listToTable2 (l: List a) : (Fin (listLength l))=>a =
(AsList _ xs) = l
xs

:t listToTable2
> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a)

:p
l = AsList _ [1, 2, 3]
sum $ listToTable2 l
> 6

l2 = AsList _ [1, 2, 3]
:p sum $ listToTable2 l2
> 6

def zerosLikeList (l : List a) : (Fin (listLength l))=>Float =
for i:(Fin $ listLength l). 0.0

:p zerosLikeList l2
> [0.0, 0.0, 0.0]

data Graph a:Type =
MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n))

def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool =
init = for i j. False
snd $ withState init \mRef.
for i:m.
(from, to) = edges.i
mRef!from!to := True

:t graphToAdjacencyMatrix
> ((a:Type)
> ?-> (pat:(Graph a))
> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool)

:p
g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)]
graphToAdjacencyMatrix g
> [[False, True, True], [False, True, False], [True, False, False]]

-- Test how (nested) projections are handled and pretty-printed.

def pairUnpack ((v, _):(Int & Float)) : Int = v
:p pairUnpack
> \pat:(Int32 & Float32). (\(a, _). a) pat

def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v
:p adtUnpack
> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat

def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v
:p recordUnpack
> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat

def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int =
(MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x
y

:p nestedUnpack
> \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)).
> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x

:p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6))
> 4
4 changes: 2 additions & 2 deletions examples/brownian_motion.dx
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ sampleBM : Key -> UnitInterval -> Float =
xs = linspace (Fin 1000) 0.0 1.0
ys = map (sampleBM (newKey 0)) xs

:plot zip xs ys
> <graphical output>
-- :plot zip xs ys
-- > <html output>
57 changes: 57 additions & 0 deletions examples/complex-tests.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
:p complex_floor $ MkComplex 0.3 0.6
> (MkComplex 0.0 0.0)
:p complex_floor $ MkComplex 0.6 0.8
> (MkComplex 0.0 1.0)
:p complex_floor $ MkComplex 0.8 0.6
> (MkComplex 1.0 0.0)
:p complex_floor $ MkComplex 0.6 0.3
> (MkComplex 0.0 0.0)

a = MkComplex 2.1 0.4
b = MkComplex (-1.1) 1.3
:p (a + b - a) ~~ b
> True
:p (a * b) ~~ (b * a)
> True
:p divide (a * b) a ~~ b
> True
-- This next test can be added once we parameterize the field in the VSpace typeclass.
--:p ((a * b) / a) ~~ b
--> True
:p a == b
> False
:p a == a
> True
:p log (exp a) ~~ a
> True
:p exp (log a) ~~ a
> True
:p log2 (exp2 a) ~~ a
> True
:p exp2 (log2 a) ~~ a
> True
:p sqrt (sq a) ~~ a
> True
:p log ((MkComplex 1.0 0.0) + a) ~~ log1p a
> True
:p sin (-a) ~~ (-(sin a))
> True
:p cos (-a) ~~ cos a
> True
:p tan (-a) ~~ (- (tan a))
> True
:p exp (pi .* (MkComplex 0.0 1.0)) ~~ (MkComplex (-1.0) 0.0) -- Euler's identity
> True
:p ((sq (sin a)) + (sq (cos a))) ~~ (MkComplex 1.0 0.0)
> True
:p complex_abs b > 0.0
> True

:p sinh (MkComplex 1.2 3.2)
> (MkComplex -1.5068874 -0.10569556)
:p cosh (MkComplex 1.2 3.2)
> (MkComplex -1.807568 8.811359e-2)
:p tanh (MkComplex 1.1 0.1)
> (MkComplex 0.80337524 3.5809334e-2)
:p tan (MkComplex 1.2 3.2)
> (MkComplex 2.2501666e-3 1.002451)
2 changes: 1 addition & 1 deletion examples/ctc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ logits = for i:time j:vocab. (randn $ ixkey2 (newKey 0) i j)
-- Create random labels
labels = for i:position. randIdxNoZero vocab (newKey (ordinal i))
:p labels
> [1@(Fin 6), 5@(Fin 6), 5@(Fin 6)]
> [(1@Fin 6), (5@Fin 6), (5@Fin 6)]

-- Evaluate marginal probability of labels given logits
:p exp $ ctc blank logits labels
Expand Down
Loading

0 comments on commit f8f9282

Please sign in to comment.