Skip to content

Commit

Permalink
comment surgery helpers for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
shelhamer committed Sep 10, 2016
1 parent 03c1dae commit 0b07df9
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
import numpy as np

def transplant(new_net, net, suffix=''):
"""
Transfer weights by copying matching parameters, coercing parameters of
incompatible shape, and dropping unmatched parameters.
The coercion is useful to convert fully connected layers to their
equivalent convolutional layers, since the weights are the same and only
the shapes are different. In particular, equivalent fully connected and
convolution layers have shapes O x I and O x I x H x W respectively for O
outputs channels, I input channels, H kernel height, and W kernel width.
Both `net` to `new_net` arguments must be instantiated `caffe.Net`s.
"""
for p in net.params:
p_new = p + suffix
if p_new not in new_net.params:
Expand All @@ -18,12 +30,10 @@ def transplant(new_net, net, suffix=''):
print 'copying', p, ' -> ', p_new, i
new_net.params[p_new][i].data.flat = net.params[p][i].data.flat

def expand_score(new_net, new_layer, net, layer):
old_cl = net.params[layer][0].num
new_net.params[new_layer][0].data[:old_cl][...] = net.params[layer][0].data
new_net.params[new_layer][1].data[0,0,0,:old_cl][...] = net.params[layer][1].data

def upsample_filt(size):
"""
Make a 2D bilinear kernel suitable for upsampling of the given (h, w) size.
"""
factor = (size + 1) // 2
if size % 2 == 1:
center = factor - 1
Expand All @@ -34,6 +44,9 @@ def upsample_filt(size):
(1 - abs(og[1] - center) / factor)

def interp(net, layers):
"""
Set weights of each layer in layers to bilinear kernels for interpolation.
"""
for l in layers:
m, k, h, w = net.params[l][0].data.shape
if m != k and k != 1:
Expand All @@ -44,3 +57,12 @@ def interp(net, layers):
raise
filt = upsample_filt(h)
net.params[l][0].data[range(m), range(k), :, :] = filt

def expand_score(new_net, new_layer, net, layer):
"""
Transplant an old score layer's parameters, with k < k' classes, into a new
score layer with k classes s.t. the first k' are the old classes.
"""
old_cl = net.params[layer][0].num
new_net.params[new_layer][0].data[:old_cl][...] = net.params[layer][0].data
new_net.params[new_layer][1].data[0,0,0,:old_cl][...] = net.params[layer][1].data

0 comments on commit 0b07df9

Please sign in to comment.