diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000..ce8d353 --- /dev/null +++ b/.github/workflows/CompatHelper.yml @@ -0,0 +1,16 @@ +name: CompatHelper + +on: + schedule: + - cron: '00 00 * * *' + +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3f71db4..0774533 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,9 +7,11 @@ jobs: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: version: - '1.3' + - '1' - 'nightly' os: - ubuntu-latest @@ -23,5 +25,16 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} + - name: Cache artifacts + uses: actions/cache@v1 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest diff --git a/Project.toml b/Project.toml index 47bbf4c..969d788 100644 --- a/Project.toml +++ b/Project.toml @@ -1,21 +1,25 @@ name = "UNet" uuid = "0d73aaa9-994a-4556-95d0-da67cb772a03" authors = ["dhairyagandhi "] -version = "0.1.0" +version = "0.2.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" +ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" +ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] -Distributions = "0.20" -Flux = "0.10" -Images = "0.22" -Reexport = "0.2.2" +Distributions = "0.20, 0.21, 0.22, 0.23" +FileIO = "1" +Flux = "0.10, 0.11" +ImageCore = "0.8" +ImageTransformations = "0.8" +Reexport = "0.2" StatsBase = "0.30" julia = "1.3" diff --git a/README.md b/README.md index 82a6437..bc04987 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ UNet: UNetUpBlock(256, 64) ``` -To default input channel dimension is expected to be `1` ie. grayscale. To support different channel images, you can pass the `channels` to `Unet`. +The default input channel dimension is expected to be `1` ie. grayscale. To support different channel images, you can pass the `channels` to `Unet`. ```julia julia> u = Unet(3) # for RGB images @@ -38,6 +38,13 @@ julia> u = Unet(3) # for RGB images The input size can be any power of two sized batch. Something like `(256,256, channels, batch_size)`. +The default output channel dimension is the input channel dimension. So, `1` for a `Unet()` and e.g. `3` for a `Unet(3)`. +The output channel dimension can be set by supplying a second argument: + +```julia +julia> u = Unet(3, 5) # 3 input channels, 5 output channels. +``` + ## GPU Support To train the model on UNet, it is as simple as calling `gpu` on the model. diff --git a/src/UNet.jl b/src/UNet.jl index 0aead27..f339cd7 100644 --- a/src/UNet.jl +++ b/src/UNet.jl @@ -7,7 +7,9 @@ using Reexport using Flux using Flux: @functor -using Images +using ImageCore +using ImageTransformations: imresize +using FileIO using Distributions: Normal @reexport using Statistics diff --git a/src/dataloader.jl b/src/dataloader.jl index 8d18e7f..bd27aeb 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -6,9 +6,9 @@ using StatsBase: sample, shuffle function load_img(base::String, stub::String; rsize = (256,256)) im = joinpath(base, stub * "_im.tif") cr = joinpath(base, stub * "_cr.tif") - x, y = Images.load(im), Images.load(cr) - x = Images.imresize(x, rsize...) - y = Images.imresize(y, rsize...) + x, y = load(im), load(cr) + x = imresize(x, rsize...) + y = imresize(y, rsize...) x, y = channelview(x), channelview(y) x = reshape(x, rsize..., 1, 1) y = reshape(y, rsize..., 1, 1) @@ -55,10 +55,10 @@ function load_batch(base::String, name_template = "im", x = zeros(Float32, rsize..., channels, n) # [] y = zeros(Float32, rsize..., channels, n) # [] for (i,(img, mask)) in enumerate(batch) - img = Images.load(joinpath(base, dir, img)) - mask = Images.load(joinpath(base, dir, mask)) - img = Images.imresize(img, rsize...) - mask = Images.imresize(mask, rsize...) + img = load(joinpath(base, dir, img)) + mask = load(joinpath(base, dir, mask)) + img = imresize(img, rsize...) + mask = imresize(mask, rsize...) img = channelview(img) mask = channelview(mask) img = reshape(img, rsize..., channels) diff --git a/src/model.jl b/src/model.jl index b672f6f..76c7de1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -33,9 +33,10 @@ function (u::UNetUpBlock)(x, bridge) end """ - Unet(channels::Int = 1) + Unet(channels::Int = 1, labels::Int = channels) - Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of channels, typically equal to the number of channels in the input images. + Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of `channels`, typically equal to the number of channels in the input images. + `labels`, equal to the number of input channels by default, specifies the number of output channels. """ struct Unet conv_down_blocks @@ -45,7 +46,7 @@ end @functor Unet -function Unet(channels::Int = 1) +function Unet(channels::Int = 1, labels::Int = channels) conv_down_blocks = Chain(ConvDown(64,64), ConvDown(128,128), ConvDown(256,256), @@ -64,7 +65,7 @@ function Unet(channels::Int = 1) UNetUpBlock(512, 128), UNetUpBlock(256, 64,p = 0.0f0), Chain(x->leakyrelu.(x,0.2f0), - Conv((1, 1), 128=>channels;init=_random_normal))) + Conv((1, 1), 128=>labels;init=_random_normal))) Unet(conv_down_blocks, conv_blocks, up_blocks) end diff --git a/test/runtests.jl b/test/runtests.jl index f3b3bdc..ca1bf78 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,10 @@ using UNet.Flux, UNet.Flux.Zygote @test size(u(ip)) == size(ip) end + + u = Unet(2,5) + ip = rand(Float32, 256, 256, 2, 1) + @test size(u(ip)) == (256, 256, 5, 1) end @testset "Variable Sizes" begin