Skip to content

Commit

Permalink
more bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhairya Gandhi committed Sep 28, 2020
2 parents 59dae8d + 7660286 commit b901662
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 19 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: CompatHelper

on:
schedule:
- cron: '00 00 * * *'

jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
13 changes: 13 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ jobs:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1.3'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -23,5 +25,16 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- name: Cache artifacts
uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
16 changes: 10 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
name = "UNet"
uuid = "0d73aaa9-994a-4556-95d0-da67cb772a03"
authors = ["dhairyagandhi <[email protected]>"]
version = "0.1.0"
version = "0.2.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Distributions = "0.20"
Flux = "0.10"
Images = "0.22"
Reexport = "0.2.2"
Distributions = "0.20, 0.21, 0.22, 0.23"
FileIO = "1"
Flux = "0.10, 0.11"
ImageCore = "0.8"
ImageTransformations = "0.8"
Reexport = "0.2"
StatsBase = "0.30"
julia = "1.3"

Expand Down
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,21 @@ UNet:
UNetUpBlock(256, 64)
```

To default input channel dimension is expected to be `1` ie. grayscale. To support different channel images, you can pass the `channels` to `Unet`.
The default input channel dimension is expected to be `1` ie. grayscale. To support different channel images, you can pass the `channels` to `Unet`.

```julia
julia> u = Unet(3) # for RGB images
```

The input size can be any power of two sized batch. Something like `(256,256, channels, batch_size)`.

The default output channel dimension is the input channel dimension. So, `1` for a `Unet()` and e.g. `3` for a `Unet(3)`.
The output channel dimension can be set by supplying a second argument:

```julia
julia> u = Unet(3, 5) # 3 input channels, 5 output channels.
```

## GPU Support

To train the model on UNet, it is as simple as calling `gpu` on the model.
Expand Down
4 changes: 3 additions & 1 deletion src/UNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ using Reexport
using Flux
using Flux: @functor

using Images
using ImageCore
using ImageTransformations: imresize
using FileIO
using Distributions: Normal

@reexport using Statistics
Expand Down
14 changes: 7 additions & 7 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ using StatsBase: sample, shuffle
function load_img(base::String, stub::String; rsize = (256,256))
im = joinpath(base, stub * "_im.tif")
cr = joinpath(base, stub * "_cr.tif")
x, y = Images.load(im), Images.load(cr)
x = Images.imresize(x, rsize...)
y = Images.imresize(y, rsize...)
x, y = load(im), load(cr)
x = imresize(x, rsize...)
y = imresize(y, rsize...)
x, y = channelview(x), channelview(y)
x = reshape(x, rsize..., 1, 1)
y = reshape(y, rsize..., 1, 1)
Expand Down Expand Up @@ -55,10 +55,10 @@ function load_batch(base::String, name_template = "im",
x = zeros(Float32, rsize..., channels, n) # []
y = zeros(Float32, rsize..., channels, n) # []
for (i,(img, mask)) in enumerate(batch)
img = Images.load(joinpath(base, dir, img))
mask = Images.load(joinpath(base, dir, mask))
img = Images.imresize(img, rsize...)
mask = Images.imresize(mask, rsize...)
img = load(joinpath(base, dir, img))
mask = load(joinpath(base, dir, mask))
img = imresize(img, rsize...)
mask = imresize(mask, rsize...)
img = channelview(img)
mask = channelview(mask)
img = reshape(img, rsize..., channels)
Expand Down
9 changes: 5 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ function (u::UNetUpBlock)(x, bridge)
end

"""
Unet(channels::Int = 1)
Unet(channels::Int = 1, labels::Int = channels)
Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of channels, typically equal to the number of channels in the input images.
Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of `channels`, typically equal to the number of channels in the input images.
`labels`, equal to the number of input channels by default, specifies the number of output channels.
"""
struct Unet
conv_down_blocks
Expand All @@ -45,7 +46,7 @@ end

@functor Unet

function Unet(channels::Int = 1)
function Unet(channels::Int = 1, labels::Int = channels)
conv_down_blocks = Chain(ConvDown(64,64),
ConvDown(128,128),
ConvDown(256,256),
Expand All @@ -64,7 +65,7 @@ function Unet(channels::Int = 1)
UNetUpBlock(512, 128),
UNetUpBlock(256, 64,p = 0.0f0),
Chain(x->leakyrelu.(x,0.2f0),
Conv((1, 1), 128=>channels;init=_random_normal)))
Conv((1, 1), 128=>labels;init=_random_normal)))
Unet(conv_down_blocks, conv_blocks, up_blocks)
end

Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ using UNet.Flux, UNet.Flux.Zygote

@test size(u(ip)) == size(ip)
end

u = Unet(2,5)
ip = rand(Float32, 256, 256, 2, 1)
@test size(u(ip)) == (256, 256, 5, 1)
end

@testset "Variable Sizes" begin
Expand Down

7 comments on commit b901662

@DhairyaLGandhi
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegsitrator register

@DilumAluthge
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DhairyaLGandhi typo I think

@DilumAluthge
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Register Failed
@DilumAluthge, it looks like you don't have collaborator status on this repository.

@DilumAluthge
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DhairyaLGandhi Can you register?

@DhairyaLGandhi
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/21791

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" b90166297bc04257b36f28311206dedb68eeffec
git push origin v0.2.0

Please sign in to comment.