
Generic UNet implementation written in pure Julia, based on Flux.jl

DhairyaLGandhi/UNet.jl

UNet.jl

This package provides a generic UNet implemented in Julia.

The package is built on top of Flux.jl, and can therefore be extended as needed.

julia> u = Unet()
UNet:
  ConvDown(64, 64)
  ConvDown(128, 128)
  ConvDown(256, 256)
  ConvDown(512, 512)


  UNetConvBlock(1, 3)
  UNetConvBlock(3, 64)
  UNetConvBlock(64, 128)
  UNetConvBlock(128, 256)
  UNetConvBlock(256, 512)
  UNetConvBlock(512, 1024)
  UNetConvBlock(1024, 1024)


  UNetUpBlock(1024, 512)
  UNetUpBlock(1024, 256)
  UNetUpBlock(512, 128)
  UNetUpBlock(256, 64)
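The widths in this printout follow the usual U-Net pattern. Each encoder block doubles the channel count. The decoder pairs suggest that each up-block's output is concatenated with the matching encoder feature map, which doubles the channels fed to the next up-block; that concatenation is an assumption inferred from the numbers, not something stated here. The plain-Julia sketch below uses a hypothetical helper, upblock_channels (not part of UNet.jl), and reproduces the UNetUpBlock pairs from the encoder widths under that assumption:

```julia
# Hypothetical helper (not part of UNet.jl): derives the (in, out) channel pairs
# of the decoder up-blocks from the encoder widths, ASSUMING each up-block's
# output is concatenated with the matching encoder skip connection.
function upblock_channels(widths::Vector{Int})
    pairs = Tuple{Int,Int}[]
    in_ch = widths[end]                     # bottleneck width, e.g. 1024
    for skip in reverse(widths[1:end-1])    # 512, 256, 128, 64
        push!(pairs, (in_ch, skip))         # up-block maps in_ch -> skip width
        in_ch = 2 * skip                    # concatenating the skip doubles channels
    end
    return pairs
end

upblock_channels([64, 128, 256, 512, 1024])
# → [(1024, 512), (1024, 256), (512, 128), (256, 64)]
```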

By default, the input is expected to have 1 channel, i.e. grayscale images. To support images with a different number of channels, pass the number of channels to Unet:

julia> u = Unet(3) # for RGB images
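Flux convolution layers expect a batch laid out with the spatial dimensions first, then channels, then the batch index. If your RGB images arrive as channels-first (3, H, W) arrays (an assumption about your loading pipeline), a minimal conversion might look like the following. to_flux_batch is a hypothetical helper, not part of UNet.jl or Flux:

```julia
# Hypothetical helper: turns a channels-first (C, H, W) image array into the
# (H, W, C, 1) layout Flux convolutions expect (spatial dims, channels, batch).
function to_flux_batch(img::AbstractArray{<:Real,3})
    spatial_first = permutedims(img, (2, 3, 1))        # (C, H, W) -> (H, W, C)
    return reshape(Float32.(spatial_first), size(spatial_first)..., 1)  # add batch dim
end

rgb = rand(Float32, 3, 256, 256)   # stand-in for a loaded RGB image
x = to_flux_batch(rgb)
size(x)   # → (256, 256, 3, 1)
```

The result can then be fed to a model built with Unet(3).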

The input can be a batch of any power-of-two spatial size, for example an array of size (256, 256, channels, batch_size).
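A small sketch of that size rule in plain Julia: the hypothetical check_unet_input rejects spatial sizes that are not powers of two and lists the spatial size at each resolution level. It assumes four 2× downsampling stages, one per ConvDown in the printout at the top; it illustrates the rule and is not a check the package itself performs.

```julia
# Hypothetical check (not part of UNet.jl). ASSUMES four 2x downsampling
# stages, matching the four ConvDown layers in the printed model.
function check_unet_input(sz::NTuple{4,Int}; levels::Int = 4)
    h, w, _, _ = sz   # (height/width, width/height, channels, batch)
    ispow2(h) && ispow2(w) ||
        throw(ArgumentError("spatial dims must be powers of two, got ($h, $w)"))
    min(h, w) >= 2^levels ||
        throw(ArgumentError("spatial dims must be at least $(2^levels)"))
    return [(h ÷ 2^k, w ÷ 2^k) for k in 0:levels]   # size at each level
end

check_unet_input((256, 256, 1, 8))
# → [(256, 256), (128, 128), (64, 64), (32, 32), (16, 16)]
```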

By default, the output channel dimension equals the input channel dimension: 1 for Unet() and 3 for Unet(3), for example. The output channel dimension can be set by supplying a second argument:

julia> u = Unet(3, 5) # 3 input channels, 5 output channels.
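If the output channels are used as per-class scores for multi-class segmentation (an assumption; the package does not prescribe how outputs are interpreted), a per-pixel label map can be obtained by picking the highest-scoring channel. The hypothetical pixel_classes below does this with Base's argmax, and random data stands in for real model output:

```julia
# Hypothetical post-processing: treats each of the K output channels as a
# per-class score and returns, for every pixel, the index of the best channel.
# `y` stands in for a model output of size (H, W, K, N).
function pixel_classes(y::AbstractArray{<:Real,4})
    idx = argmax(y; dims = 3)                       # CartesianIndex per (h, w, 1, n)
    return dropdims(map(I -> I[3], idx); dims = 3)  # keep channel index -> (H, W, N)
end

y = rand(Float32, 256, 256, 5, 2)   # stand-in for the output of Unet(3, 5)
labels = pixel_classes(y)
size(labels)   # → (256, 256, 2)
```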

GPU Support

To train the model on a GPU, simply call gpu on the model, and move the input data to the GPU as well.

julia> u = Unet();

julia> u = gpu(u);

julia> r = gpu(rand(Float32, 256, 256, 1, 1));

julia> size(u(r))
(256, 256, 1, 1)

Training

Training a UNet is straightforward as well.

You can define your own loss function, or use Flux's binary cross-entropy implementation.
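As one illustration of a custom loss, here is a soft Dice loss, a common alternative to binary cross-entropy for segmentation masks. This is a hypothetical sketch, not part of UNet.jl; it assumes predictions are already probabilities in [0, 1] (for example after the same clamp used in the training snippet) and that prediction and target have the same size:

```julia
# Hypothetical custom loss: soft Dice loss. `ŷ` holds predicted probabilities
# in [0, 1], `y` the target mask; ϵ avoids division by zero on empty masks.
function dice_loss(ŷ::AbstractArray, y::AbstractArray; ϵ = 1f-6)
    intersection = sum(ŷ .* y)
    return 1 - (2 * intersection + ϵ) / (sum(ŷ) + sum(y) + ϵ)
end

y = Float32[1 0; 0 1]
dice_loss(y, y)        # perfect overlap → ≈ 0
dice_loss(1 .- y, y)   # no overlap → ≈ 1
```

It could then replace the body of loss, e.g. dice_loss(clamp.(u(x), 0.001f0, 1.f0), y).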

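using UNet, Flux, Base.Iterators
import Flux.Losses.binarycrossentropy

device = gpu  # use cpu to train without a GPU

u = Unet() |> device

# Clamp predictions away from 0 before the logarithms in the loss
function loss(x, y)
    op = clamp.(u(x), 0.001f0, 1.f0)
    binarycrossentropy(op, y)
end

# One random input/target pair, repeated 10 times as the training data
w = rand(Float32, 256, 256, 1, 1) |> device
w′ = rand(Float32, 256, 256, 1, 1) |> device
rep = Iterators.repeated((w, w′), 10)

opt = ADAM()  # newer Flux releases name this optimiser Adam

# Implicit-parameters training API; the callback prints the current loss
Flux.train!(loss, Flux.params(u), rep, opt, cb = () -> @show(loss(w, w′)))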
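To make the clamp in loss less opaque, the plain-Julia sketch below spells out approximately what it computes: the mean element-wise binary cross-entropy of the clamped predictions. clamped_bce is a hypothetical name and this is not Flux's source code; Flux's binarycrossentropy likewise guards its logarithms with a small epsilon, which the sketch imitates with eps(Float32).

```julia
# Hypothetical sketch of the clamped binary cross-entropy used in `loss`.
function clamped_bce(ŷ::AbstractArray, y::AbstractArray;
                     lo = 0.001f0, hi = 1f0, ϵ = eps(Float32))
    p = clamp.(ŷ, lo, hi)   # same clamp as in `loss`
    # ϵ keeps log(1 - p) finite when p == 1
    terms = @. -(y * log(p + ϵ) + (1 - y) * log(1 - p + ϵ))
    return sum(terms) / length(terms)   # mean aggregation
end

clamped_bce(Float32[1, 0], Float32[1, 0])   # confident and correct: small
clamped_bce(Float32[0, 1], Float32[1, 0])   # confident and wrong: large, finite
```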

Further Reading

The package is an implementation of the paper, and all credit for the model itself goes to the respective authors.
